From 619dc652e1993ac7f513c66023b0d8a79fcea10f Mon Sep 17 00:00:00 2001
From: manasMauryax
Date: Wed, 25 Sep 2024 13:55:50 +0200
Subject: [PATCH] feat: add tracking of distinct losses

---
 src/modalities/loss_functions.py | 24 +++++++++++++++++++++++-
 src/modalities/trainer.py        | 21 ++++++++++++++++++++-
 2 files changed, 43 insertions(+), 2 deletions(-)

diff --git a/src/modalities/loss_functions.py b/src/modalities/loss_functions.py
index 6e7f8501..c90dde33 100644
--- a/src/modalities/loss_functions.py
+++ b/src/modalities/loss_functions.py
@@ -54,14 +54,36 @@ def __init__(
 
         self.groups = [(loss_func, weight) for loss_func, weight in zip(losses, corrsp_weights)]
 
+        self.cumulated_individual_losses = None
+        # Tensor that tracks each loss separately,
+        # summed over the local batches seen since
+        # the last logging step.
+
+        self.reset_cumulated_individual_losses()
+
     def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor:
         device = forward_batch.predictions[list(forward_batch.predictions.keys())[0]].device
         total_loss = torch.tensor(0, dtype=torch.float, device=device)
-        for loss_func, weight in self.groups:
+        for ind, (loss_func, weight) in enumerate(self.groups):
             loss = loss_func(forward_batch)
+            self.cumulated_individual_losses[ind] += loss.detach()  # detach so tracking retains no autograd graph
             total_loss += weight * loss
         return total_loss
 
+    def reset_cumulated_individual_losses(
+        self,
+    ) -> None:
+        """Initializes and resets the tensor that
+        accumulates each loss separately.
+
+        Called once when the class is initialized, and then
+        after every logging step in trainer.py.
+        """
+        if torch.cuda.is_available():
+            self.cumulated_individual_losses = torch.zeros(len(self.groups), device="cuda")
+        else:
+            self.cumulated_individual_losses = torch.zeros(len(self.groups), device="cpu")
+
 
 class CLMCrossEntropyLoss(Loss):
     def __init__(self, target_key: str, prediction_key: str, tag: str = "CLMCrossEntropyLoss"):
diff --git a/src/modalities/trainer.py b/src/modalities/trainer.py
index e15e47c0..798b569a 100644
--- a/src/modalities/trainer.py
+++ b/src/modalities/trainer.py
@@ -12,7 +12,7 @@
 from modalities.dataloader.dataloader import LLMDataLoader
 from modalities.logging_broker.messages import ExperimentStatus, MessageTypes, ProgressUpdate
 from modalities.logging_broker.publisher import MessagePublisher
-from modalities.loss_functions import Loss
+from modalities.loss_functions import Loss, MultipleFunctionsLoss
 from modalities.models.model import model_predict_batch
 from modalities.running_env.fsdp.reducer import Reducer
 from modalities.training.gradient_clipping.gradient_clipper import GradientClipperIF
@@ -259,6 +259,25 @@ def train(
                 "train loss last": ResultItem(train_loss_last_batch, decimal_places=2),
             }
 
+            # If multiple loss functions are in use, compute and
+            # log each individual loss, averaged over the
+            # global batch size.
+            if isinstance(loss_fun, MultipleFunctionsLoss):
+                global_batch_size = Reducer.reduce(
+                    tensor=cumulated_losses[-1], operation=dist.ReduceOp.SUM, post_processing_fun=None
+                )
+                reduced_individual_losses = Reducer.reduce(
+                    tensor=loss_fun.cumulated_individual_losses,
+                    operation=dist.ReduceOp.SUM,
+                    post_processing_fun=lambda t: torch.stack(
+                        [t[ind] / global_batch_size for ind in range(len(t))]
+                    ),
+                )
+                for ind, (loss, _) in enumerate(loss_fun.groups):
+                    losses[f"train {loss.tag} avg"] = ResultItem(reduced_individual_losses[ind], decimal_places=2)
+
+                loss_fun.reset_cumulated_individual_losses()
+
             consumed_tokens = torch.Tensor([training_progress.num_seen_tokens_total])
             metrics = {
                 "consumed tokens": ResultItem(consumed_tokens, 0),
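
Note: the accumulate-and-reset pattern above is easy to exercise outside the
trainer. The following is a minimal single-process sketch, not the modalities
API: CombinedLoss, ToyMSELoss, and ToyL1Loss are hypothetical stand-ins (the
real losses consume an InferenceResultBatch, and the real averaging
all-reduces both numerator and denominator across ranks via Reducer.reduce):

import torch
import torch.nn.functional as F


# Hypothetical stand-ins for the modalities Loss interface; they take a
# plain dict of tensors instead of an InferenceResultBatch.
class ToyMSELoss:
    tag = "ToyMSELoss"

    def __call__(self, batch: dict) -> torch.Tensor:
        return F.mse_loss(batch["pred"], batch["target"], reduction="sum")


class ToyL1Loss:
    tag = "ToyL1Loss"

    def __call__(self, batch: dict) -> torch.Tensor:
        return F.l1_loss(batch["pred"], batch["target"], reduction="sum")


class CombinedLoss:
    """Single-process sketch of the MultipleFunctionsLoss tracking pattern."""

    def __init__(self, losses, weights):
        self.groups = list(zip(losses, weights))
        self.reset_cumulated_individual_losses()

    def __call__(self, batch: dict) -> torch.Tensor:
        total_loss = torch.tensor(0.0)
        for ind, (loss_func, weight) in enumerate(self.groups):
            loss = loss_func(batch)
            # Track each loss separately, detached so no graph is retained.
            self.cumulated_individual_losses[ind] += loss.detach()
            total_loss = total_loss + weight * loss
        return total_loss

    def reset_cumulated_individual_losses(self) -> None:
        # CPU only in this sketch; the patch places the tensor on CUDA
        # when available.
        self.cumulated_individual_losses = torch.zeros(len(self.groups))


loss_fun = CombinedLoss([ToyMSELoss(), ToyL1Loss()], [1.0, 0.5])
num_samples = 0
for _ in range(4):  # stand-in for the batches between two logging steps
    batch = {"pred": torch.randn(8, 4), "target": torch.randn(8, 4)}
    loss_fun(batch)
    num_samples += batch["pred"].shape[0]

# At the logging step: average each loss over the samples seen, then reset.
for ind, (loss, _) in enumerate(loss_fun.groups):
    avg = loss_fun.cumulated_individual_losses[ind] / num_samples
    print(f"train {loss.tag} avg: {avg.item():.2f}")
loss_fun.reset_cumulated_individual_losses()

Detaching the tracked losses means the logging buffer never holds onto the
autograd graph between logging steps; only the weighted total participates
in the backward pass.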