Skip to content

Commit

Permalink
FIX: fixed overlapping of progress-bars with tables
Browse files Browse the repository at this point in the history
  • Loading branch information
Pratyush-exe committed Jan 16, 2024
1 parent 1dcdc8d commit 004ef1c
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 20 deletions.
60 changes: 41 additions & 19 deletions deepeval/callbacks/huggingface/deepeval_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,17 @@
from rich.console import Console
from rich.table import Table
from rich.live import Live
from rich.columns import Columns
from rich.progress import Progress, BarColumn, \
SpinnerColumn, TextColumn

from transformers import TrainerCallback, \
ProgressCallback, Trainer, \
TrainingArguments, TrainerState, TrainerControl

from deepeval.metrics import BaseMetric
from deepeval.dataset import EvaluationDataset
from deepeval.evaluate import execute_test
from deepeval.progress_context import progress_context


class DeepEvalCallback(TrainerCallback):
Expand Down Expand Up @@ -46,6 +49,7 @@ def __init__(
self.aggregation_method = aggregation_method
self.trainer = trainer

self.train_bar_started = False
self.epoch_counter = 0
self.deepeval_metric_history = []
self._initiate_rich_console()
Expand Down Expand Up @@ -126,7 +130,7 @@ def on_epoch_end(self,
"""
Event triggered at the end of each training epoch.
"""
self.progress.update(1)

control.should_log = True


Expand All @@ -139,24 +143,35 @@ def on_log(self,
"""
Event triggered after logging the last logs.
"""

if not self.train_bar_started:
self.progress.start()
self.train_bar_started = True

if (
self.show_table
and (self.epoch_counter % self.show_table_every == 0)
self.show_table
and len(state.log_history) <= self.trainer.args.num_train_epochs
):
with progress_context("Evaluating testcases..."):
self.progress.update(self.progress_task, advance=1)
if self.epoch_counter % self.show_table_every == 0:
self.spinner.reset(self.spinner_task, description="[STATUS] Evaluating test-cases (might take up few minutes) ...")

scores = self._calculate_metric_scores()
self.deepeval_metric_history.append(scores)
self.deepeval_metric_history[-1].update(state.log_history[-1])

def generate_table():
new_table = Table()
for key in self.deepeval_metric_history[-1].keys():
new_table.add_column(key)
for row in self.deepeval_metric_history:
new_table.add_row(*[str(value) for value in row.values()])
return new_table
self.live.update(generate_table(), refresh=True)

self.spinner.reset(self.spinner_task, description="[STATUS] Training in Progress ...")

def generate_table():
new_table = Table()
cols = Columns([new_table, self.spinner, self.progress], equal=True, expand=True)
for key in self.deepeval_metric_history[-1].keys():
new_table.add_column(key)
for row in self.deepeval_metric_history:
new_table.add_row(*[str(value) for value in row.values()])
return cols

self.live.update(generate_table(), refresh=True)

def on_train_end(self,
args: TrainingArguments,
Expand All @@ -167,7 +182,7 @@ def on_train_end(self,
"""
Event triggered at the end of model training.
"""
self.progress.close()
self.progress.stop()

def on_train_begin(self,
args: TrainingArguments,
Expand All @@ -178,9 +193,16 @@ def on_train_begin(self,
"""
Event triggered at the begining of model training.
"""
self.progress = tqdm(
total=self.trainer.args.num_train_epochs,
desc="Epochs"
self.progress = Progress(
TextColumn("{task.description} [progress.percentage][{task.percentage:>3.1f}%]:", justify="right"),
BarColumn(),
TextColumn("[green][ {task.completed}/{task.total} epochs ]", justify="right"),
)
self.progress_task = self.progress.add_task("Train Progress", total=self.trainer.args.num_train_epochs)


self.spinner = Progress(
SpinnerColumn(),
TextColumn("{task.description}", justify="right"),
transient=True
)
self.spinner_task = self.spinner.add_task("[STATUS] Training in Progress ...", total=9999)
4 changes: 3 additions & 1 deletion deepeval/callbacks/huggingface/deepeval_harness_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,6 @@ class DeepEvalHarnessCallback(TrainerCallback):

def __init__(self, experiments: Union[BaseEvaluationExperiment, List[BaseEvaluationExperiment]]):
super().__init__()
self.experiments = experiments
self.experiments = experiments

raise NotImplementedError("DeepEvalHarnessCallback is WIP")

0 comments on commit 004ef1c

Please sign in to comment.