Skip to content

Commit

Permalink
is_tp_enabled
Browse files: browse the repository at this point in the history
Signed-off-by: Mayank Mishra <[email protected]>
  • Loading branch information
mayank31398 committed Oct 24, 2024
1 parent 9313565 commit 02459bf
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
6 changes: 3 additions & 3 deletions dolomite_engine/hf_models/mixins/dense_TP/base.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -51,8 +51,6 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None:
self.initializer_range = config.initializer_range
self.head_dim = self.embed_dim // self.num_heads

self.tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size()

self.layers_per_stage = divide_if_divisible(
config.n_layer, self.num_stages, "layers should be divisible by num_stages"
)
Expand Down Expand Up @@ -158,7 +156,9 @@ def forward(
query_length = key_length - past_length
else:
key_length = (
hidden_states.size(1) * self.tp_world_size if self.sequence_parallel else hidden_states.size(1)
hidden_states.size(1) * ProcessGroupManager.get_tensor_parallel_world_size()
if self.sequence_parallel
else hidden_states.size(1)
)
query_length = key_length - past_length

Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch

from ....utils import ProcessGroupManager
from ....utils import ProcessGroupManager, divide_if_divisible
from ...modeling_utils import Alibi


Expand All @@ -20,8 +20,7 @@ def reset_parameters(self) -> None:
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)

tp_rank = ProcessGroupManager.get_tensor_parallel_rank()
tp_world_size = ProcessGroupManager.get_tensor_parallel_world_size()
num_heads_tp = self.num_heads // tp_world_size
num_heads_tp = divide_if_divisible(self.num_heads, ProcessGroupManager.get_tensor_parallel_world_size(), "")
slopes = slopes[tp_rank * num_heads_tp : (tp_rank + 1) * num_heads_tp]

self.register_buffer("slopes", slopes, persistent=False)

0 comments on commit 02459bf

Please sign in to comment.