feat: add tracking of each distinct loss
manasMauryax committed Sep 26, 2024
1 parent ec67ed0 commit 6eceaba
Showing 2 changed files with 29 additions and 2 deletions.
13 changes: 12 additions & 1 deletion src/modalities/loss_functions.py
@@ -53,15 +53,26 @@ def __init__(
             raise ValueError("Number of losses used should be more than 1.")
 
         self.groups = [(loss_func, weight) for loss_func, weight in zip(losses, corrsp_weights)]
+        self.cumulated_individual_losses = None
+        self.reset_cumulated_individual_losses()
 
     def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor:
         device = forward_batch.predictions[list(forward_batch.predictions.keys())[0]].device
         total_loss = torch.tensor(0, dtype=torch.float, device=device)
-        for loss_func, weight in self.groups:
+        for ind, (loss_func, weight) in enumerate(self.groups):
             loss = loss_func(forward_batch)
+            self.cumulated_individual_losses[ind] += loss
             total_loss += weight * loss
         return total_loss
 
+    def reset_cumulated_individual_losses(
+        self,
+    ) -> None:
+        if torch.cuda.is_available():
+            self.cumulated_individual_losses = torch.zeros(len(self.groups)).to(torch.device("cuda"))
+        else:
+            self.cumulated_individual_losses = torch.zeros(len(self.groups)).to("cpu")
+
 
 class CLMCrossEntropyLoss(Loss):
     def __init__(self, target_key: str, prediction_key: str, tag: str = "CLMCrossEntropyLoss"):
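For context, here is a minimal, self-contained sketch of the accumulation pattern introduced above. ToyLoss and the plain-tensor batch are hypothetical stand-ins for this repository's Loss and InferenceResultBatch types, so this is an illustration of the pattern rather than the actual modalities API:

import torch


class ToyLoss:
    """Hypothetical stand-in for a tagged modalities Loss."""

    def __init__(self, tag: str, scale: float):
        self.tag = tag
        self.scale = scale

    def __call__(self, batch: torch.Tensor) -> torch.Tensor:
        # Pretend loss: a scaled mean of the batch values.
        return self.scale * batch.mean()


class ToyMultipleFunctionsLoss:
    """Mirrors the weighted-sum plus per-loss accumulation pattern from the diff."""

    def __init__(self, losses, weights):
        self.groups = list(zip(losses, weights))
        self.cumulated_individual_losses = None
        self.reset_cumulated_individual_losses()

    def __call__(self, batch: torch.Tensor) -> torch.Tensor:
        total_loss = torch.tensor(0.0)
        for ind, (loss_func, weight) in enumerate(self.groups):
            loss = loss_func(batch)
            # Track each distinct loss so it can be logged separately later.
            self.cumulated_individual_losses[ind] += loss
            total_loss += weight * loss
        return total_loss

    def reset_cumulated_individual_losses(self) -> None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.cumulated_individual_losses = torch.zeros(len(self.groups), device=device)


loss_fun = ToyMultipleFunctionsLoss(
    losses=[ToyLoss("lossA", 1.0), ToyLoss("lossB", 2.0)],
    weights=[0.7, 0.3],
)
for batch in [torch.ones(4), 2 * torch.ones(4)]:
    loss_fun(batch)
print(loss_fun.cumulated_individual_losses)  # per-loss sums over both batches: [3.0, 6.0]
loss_fun.reset_cumulated_individual_losses()  # the trainer calls this after each logging step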
18 changes: 17 additions & 1 deletion src/modalities/trainer.py
@@ -12,7 +12,7 @@
 from modalities.dataloader.dataloader import LLMDataLoader
 from modalities.logging_broker.messages import ExperimentStatus, MessageTypes, ProgressUpdate
 from modalities.logging_broker.publisher import MessagePublisher
-from modalities.loss_functions import Loss
+from modalities.loss_functions import Loss, MultipleFunctionsLoss
 from modalities.models.model import model_predict_batch
 from modalities.running_env.fsdp.reducer import Reducer
 from modalities.training.gradient_clipping.gradient_clipper import GradientClipperIF
@@ -259,6 +259,22 @@ def train(
                     "train loss last": ResultItem(train_loss_last_batch, decimal_places=2),
                 }
 
+                if isinstance(loss_fun, MultipleFunctionsLoss):
+                    global_batch_size = Reducer.reduce(
+                        tensor=cumulated_losses[-1], operation=dist.ReduceOp.SUM, post_processing_fun=None
+                    )
+                    reduced_individual_losses = Reducer.reduce(
+                        tensor=loss_fun.cumulated_individual_losses,
+                        operation=dist.ReduceOp.SUM,
+                        post_processing_fun=lambda t: torch.stack(
+                            [t[ind] / global_batch_size for ind in range(len(t))]
+                        ),
+                    )
+                    for ind, (loss, _) in enumerate(loss_fun.groups):
+                        losses[f"train {loss.tag} avg"] = ResultItem(reduced_individual_losses[ind], decimal_places=2)
+
+                loss_fun.reset_cumulated_individual_losses()
+
                 consumed_tokens = torch.Tensor([training_progress.num_seen_tokens_total])
                 metrics = {
                     "consumed tokens": ResultItem(consumed_tokens, 0),
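To see what the new trainer block computes, here is a small emulation of its two SUM reductions under assumed values for two ranks. The rank_losses and rank_batch_counts tensors are hypothetical stand-ins for each rank's loss_fun.cumulated_individual_losses and batch counter; in the real code both reductions go through Reducer.reduce with dist.ReduceOp.SUM across the distributed process group:

import torch

# Hypothetical per-rank state: 2 ranks, 2 distinct losses, 5 batches per rank.
rank_losses = [torch.tensor([4.0, 8.0]), torch.tensor([6.0, 10.0])]
rank_batch_counts = [torch.tensor(5.0), torch.tensor(5.0)]

# Emulate the two all-reduce SUMs from the diff.
global_batch_size = torch.stack(rank_batch_counts).sum()  # 10.0
summed_losses = torch.stack(rank_losses).sum(dim=0)       # tensor([10., 18.])

# Post-processing step from the diff: divide each accumulated loss by the
# global batch count, yielding a per-batch average for each distinct loss.
reduced_individual_losses = torch.stack(
    [summed_losses[ind] / global_batch_size for ind in range(len(summed_losses))]
)
print(reduced_individual_losses)  # tensor([1.0000, 1.8000])

Each entry of the result is what lands in the "train {loss.tag} avg" ResultItem for the corresponding loss function.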
