From 619dc652e1993ac7f513c66023b0d8a79fcea10f Mon Sep 17 00:00:00 2001
From: manasMauryax
Date: Wed, 25 Sep 2024 13:55:50 +0200
Subject: [PATCH] feat: add tracking of distinct losses
---
src/modalities/loss_functions.py | 24 +++++++++++++++++++++++-
src/modalities/trainer.py | 21 ++++++++++++++++++++-
2 files changed, 43 insertions(+), 2 deletions(-)
diff --git a/src/modalities/loss_functions.py b/src/modalities/loss_functions.py
index 6e7f8501..c90dde33 100644
--- a/src/modalities/loss_functions.py
+++ b/src/modalities/loss_functions.py
@@ -54,14 +54,36 @@ def __init__(
self.groups = [(loss_func, weight) for loss_func, weight in zip(losses, corrsp_weights)]
+ self.cumulated_individual_losses = None
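+ # Running sum of each individual (unweighted) loss over the
+ # local batches, one entry per loss function in self.groups.
+ # The tensor itself is allocated by the reset call below.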
+
+ self.reset_cumulated_individual_losses()
+
def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor:
device = forward_batch.predictions[list(forward_batch.predictions.keys())[0]].device
total_loss = torch.tensor(0, dtype=torch.float, device=device)
- for loss_func, weight in self.groups:
+ for ind, (loss_func, weight) in enumerate(self.groups):
loss = loss_func(forward_batch)
+ self.cumulated_individual_losses[ind] += loss.detach()  # log-only sum: keep it out of the autograd graph
total_loss += weight * loss
return total_loss
+ def reset_cumulated_individual_losses(self) -> None:
+ """Zero the accumulator that tracks every loss separately.
+
+ The accumulator holds one entry per (loss, weight) pair in
+ ``self.groups``. It lives on CUDA when CUDA is available and
+ on the CPU otherwise.
+
+ Called first when the class is initialized, and then
+ after every logging step in trainer.py.
+ """
+ # Create the zeros directly on the chosen device.
+ target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ self.cumulated_individual_losses = torch.zeros(len(self.groups), device=target_device)
+
class CLMCrossEntropyLoss(Loss):
def __init__(self, target_key: str, prediction_key: str, tag: str = "CLMCrossEntropyLoss"):
diff --git a/src/modalities/trainer.py b/src/modalities/trainer.py
index e15e47c0..798b569a 100644
--- a/src/modalities/trainer.py
+++ b/src/modalities/trainer.py
@@ -12,7 +12,7 @@
from modalities.dataloader.dataloader import LLMDataLoader
from modalities.logging_broker.messages import ExperimentStatus, MessageTypes, ProgressUpdate
from modalities.logging_broker.publisher import MessagePublisher
-from modalities.loss_functions import Loss
+from modalities.loss_functions import Loss, MultipleFunctionsLoss
from modalities.models.model import model_predict_batch
from modalities.running_env.fsdp.reducer import Reducer
from modalities.training.gradient_clipping.gradient_clipper import GradientClipperIF
@@ -259,6 +259,25 @@ def train(
"train loss last": ResultItem(train_loss_last_batch, decimal_places=2),
}
+ # If several loss functions are combined, additionally log each
+ # individual loss: SUM-reduced via Reducer and divided by the
+ # reduced global batch size.
+ if isinstance(loss_fun, MultipleFunctionsLoss):
+ # NOTE(review): confirm cumulated_losses[-1] was not already SUM-reduced
+ # in place earlier in this step; a second reduction would over-count it.
+ global_batch_size = Reducer.reduce(
+ tensor=cumulated_losses[-1], operation=dist.ReduceOp.SUM, post_processing_fun=None
+ )
+ reduced_individual_losses = Reducer.reduce(
+ tensor=loss_fun.cumulated_individual_losses,
+ operation=dist.ReduceOp.SUM,
+ post_processing_fun=lambda t: t / global_batch_size,  # one broadcast division for all entries
+ )
+ for ind, (loss, _) in enumerate(loss_fun.groups):
+ losses[f"train {loss.tag} avg"] = ResultItem(reduced_individual_losses[ind], decimal_places=2)
+
+ loss_fun.reset_cumulated_individual_losses()
+
consumed_tokens = torch.Tensor([training_progress.num_seen_tokens_total])
metrics = {
"consumed tokens": ResultItem(consumed_tokens, 0),