feat: add tracking of distinct losses
manasMauryax committed Sep 26, 2024
1 parent ec67ed0 commit 619dc65
Showing 2 changed files with 43 additions and 2 deletions.
24 changes: 23 additions & 1 deletion src/modalities/loss_functions.py
@@ -54,14 +54,36 @@ def __init__(

        self.groups = [(loss_func, weight) for loss_func, weight in zip(losses, corrsp_weights)]

        # Stores each individual loss separately, summed over local batches.
        self.cumulated_individual_losses = None

        self.reset_cumulated_individual_losses()

    def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor:
        device = forward_batch.predictions[list(forward_batch.predictions.keys())[0]].device
        total_loss = torch.tensor(0, dtype=torch.float, device=device)
        for ind, (loss_func, weight) in enumerate(self.groups):
            loss = loss_func(forward_batch)
            self.cumulated_individual_losses[ind] += loss
            total_loss += weight * loss
        return total_loss

    def reset_cumulated_individual_losses(self) -> None:
        """Initializes or resets the tensor that accumulates each loss separately.

        Called once when the class is initialized, and then after every
        logging step in trainer.py.
        """
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.cumulated_individual_losses = torch.zeros(len(self.groups), device=device)


class CLMCrossEntropyLoss(Loss):
    def __init__(self, target_key: str, prediction_key: str, tag: str = "CLMCrossEntropyLoss"):
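To see what the new accumulator tracks, here is a minimal, self-contained sketch of the accumulation pattern. ToyLoss and its values are illustrative stand-ins, not part of the repository; only the pieces MultipleFunctionsLoss relies on here (a callable with a tag) are mimicked:

```python
import torch

# Hypothetical stand-in for the repository's Loss subclasses.
class ToyLoss:
    def __init__(self, tag: str, value: float):
        self.tag = tag
        self.value = value

    def __call__(self, forward_batch) -> torch.Tensor:
        return torch.tensor(self.value)

# (loss_func, weight) pairs, as built in MultipleFunctionsLoss.__init__.
groups = [(ToyLoss("clm", 2.0), 1.0), (ToyLoss("aux", 0.5), 0.1)]

# Mirrors MultipleFunctionsLoss.__call__: each loss accumulates in its own
# slot for logging, while only the weighted sum is returned for backprop.
cumulated_individual_losses = torch.zeros(len(groups))
for _ in range(4):  # four local batches before a logging step
    total_loss = torch.tensor(0.0)
    for ind, (loss_func, weight) in enumerate(groups):
        loss = loss_func(None)
        cumulated_individual_losses[ind] += loss
        total_loss += weight * loss

print(cumulated_individual_losses)  # tensor([8., 2.]): per-loss sums over four batches
```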
21 changes: 20 additions & 1 deletion src/modalities/trainer.py
@@ -12,7 +12,7 @@
from modalities.dataloader.dataloader import LLMDataLoader
from modalities.logging_broker.messages import ExperimentStatus, MessageTypes, ProgressUpdate
from modalities.logging_broker.publisher import MessagePublisher
from modalities.loss_functions import Loss, MultipleFunctionsLoss
from modalities.models.model import model_predict_batch
from modalities.running_env.fsdp.reducer import Reducer
from modalities.training.gradient_clipping.gradient_clipper import GradientClipperIF
@@ -259,6 +259,25 @@ def train(
"train loss last": ResultItem(train_loss_last_batch, decimal_places=2),
}

# If there are multiple loss functions being used,
# this block computes and logs all the individual
# losses, averaged over the global batch size.
if isinstance(loss_fun, MultipleFunctionsLoss):
global_batch_size = Reducer.reduce(
tensor=cumulated_losses[-1], operation=dist.ReduceOp.SUM, post_processing_fun=None
)
reduced_individual_losses = Reducer.reduce(
tensor=loss_fun.cumulated_individual_losses,
operation=dist.ReduceOp.SUM,
post_processing_fun=lambda t: torch.stack(
[t[ind] / global_batch_size for ind in range(len(t))]
),
)
for ind, (loss, _) in enumerate(loss_fun.groups):
losses[f"train {loss.tag} avg"] = ResultItem(reduced_individual_losses[ind], decimal_places=2)

loss_fun.reset_cumulated_individual_losses()

consumed_tokens = torch.Tensor([training_progress.num_seen_tokens_total])
metrics = {
"consumed tokens": ResultItem(consumed_tokens, 0),
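The reduction above can be pictured without the FSDP machinery: each rank holds a per-loss accumulator summed over its local batches, the accumulators are summed across ranks, and the result is divided by the global batch size. Below is a single-process sketch of that arithmetic, assuming (as the comment in the diff suggests) that cumulated_losses[-1] carries the local sample count; all numbers are invented for illustration:

```python
import torch

# Hypothetical per-rank state after several local batches
# (two ranks, two individual losses).
rank_accumulators = [torch.tensor([8.0, 2.0]), torch.tensor([6.0, 1.5])]
rank_sample_counts = [torch.tensor(64.0), torch.tensor(64.0)]

# What dist.ReduceOp.SUM produces across ranks, done locally here:
global_batch_size = torch.stack(rank_sample_counts).sum()  # 128 samples
summed_losses = torch.stack(rank_accumulators).sum(dim=0)  # tensor([14.0, 3.5])

# The post-processing step: average each individual loss over the
# global batch size, as the lambda passed to Reducer.reduce does.
avg_individual_losses = summed_losses / global_batch_size
print(avg_individual_losses)  # tensor([0.1094, 0.0273])
```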
