Skip to content

Commit

Permalink
frequency averaging
Browse files Browse the repository at this point in the history
Signed-off-by: Mayank Mishra <[email protected]>
  • Loading branch information
mayank31398 committed Oct 30, 2024
1 parent 214f6f9 commit b5a9410
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions dolomite_engine/hf_models/models/moe_dolomite/moe/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._functional_collectives import all_reduce

from .....utils import ProcessGroupManager
from ....enums import InitMethod
from ....modeling_utils import ParameterizedTransposedLinear, get_activation_function, is_glu
from ..config import MoEDolomiteConfig
Expand Down Expand Up @@ -207,6 +209,9 @@ def _compute_switch_loss(self, logits: torch.Tensor, probs: torch.Tensor, topk_i
acc_probs = probs.sum(0)
freq = torch.bincount(topk_idxs.flatten(), minlength=num_experts).to(dtype=logits.dtype)

if ProcessGroupManager.get_data_parallel_world_size() > 1:
freq = all_reduce(freq, reduceOp="sum", group=ProcessGroupManager.get_data_parallel_group())

switch_loss = num_experts * (F.normalize(acc_probs, p=1, dim=0) * F.normalize(freq, p=1, dim=0)).sum()
z_loss = (torch.logsumexp(logits, dim=-1) ** 2).mean()

Expand Down

0 comments on commit b5a9410

Please sign in to comment.