Begin MoE pipeline. #57

Open · wants to merge 3 commits into base: main
6 changes: 3 additions & 3 deletions dolomite_engine/hf_models/mixins/dense_TP/main.py
@@ -212,7 +212,7 @@ def load_from_safetensors_weights_manager(self, safetensors_weights_manager: Saf

def get_dummy_input_tensor(
self, micro_batch_size: int, sequence_length: int, intermediate_dtype: torch.dtype
) -> tuple[int]:
) -> tuple[torch.Tensor] | torch.Tensor:
if self.is_first_stage:
# 1 is added to the sequence length since megatron's dataloader intentionally provides an extra token
tensor = torch.empty(
@@ -231,7 +231,7 @@ def get_dummy_output_tensor(
sequence_length: int,
intermediate_dtype: torch.dtype,
output_parallel_lm_logits_if_possible: bool,
) -> tuple[int]:
) -> tuple[torch.Tensor] | torch.Tensor:
if self.is_last_stage:
vocab_size = self.config.vocab_size
if self.tensor_parallel_word_embeddings and output_parallel_lm_logits_if_possible:
@@ -261,7 +261,7 @@ def get_dummy_output_tensor(

def _get_dummy_intermediate_tensor(
self, micro_batch_size: int, sequence_length: int, intermediate_dtype: torch.dtype
) -> tuple[int]:
) -> tuple[torch.Tensor] | torch.Tensor:
sharded_sequence_length = (
divide_if_divisible(sequence_length, ProcessGroupManager.get_tensor_parallel_world_size(), "")
if self.sequence_parallel
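The widened annotations above say these dummy-tensor helpers hand the pipeline either a single tensor or a tuple of tensors, depending on the stage. A standalone sketch of how a caller might normalize the two forms — illustrative only, assuming nothing about dolomite_engine's pipeline schedule beyond the annotation itself:

import torch


def as_tuple(dummy: torch.Tensor | tuple[torch.Tensor, ...]) -> tuple[torch.Tensor, ...]:
    # Normalize either return form so downstream code can iterate over the tensors uniformly.
    return dummy if isinstance(dummy, tuple) else (dummy,)


single = torch.empty(2, 9, dtype=torch.long)        # e.g. a dummy (batch, sequence + 1) input_ids tensor
pair = (torch.empty(2, 8, 16), torch.tensor(0.0))   # e.g. hidden states plus an extra scalar
print(len(as_tuple(single)), len(as_tuple(pair)))   # 1 2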
64 changes: 38 additions & 26 deletions dolomite_engine/hf_models/mixins/moe_TP/base.py
@@ -1,6 +1,6 @@
import torch.nn as nn

from ....utils import ProcessGroupManager
from ....utils import ProcessGroupManager, divide_if_divisible
from ...config import CommonConfig
from ...enums import AttentionHeadType, PositionEmbeddingType
from ...modeling_utils_TP import Dropout_TP, Embedding_TP, get_normalization_function_TP
@@ -26,27 +26,37 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None:
self.initializer_range = config.initializer_range
self.head_dim = self.embed_dim // self.num_heads

self.wte = Embedding_TP(
config.vocab_size,
self.embed_dim,
std=self.initializer_range,
tensor_parallel_word_embeddings=self.tensor_parallel_word_embeddings,
use_padding_free_transformer=self._use_padding_free_transformer,
sequence_parallel=self.sequence_parallel,
self.layers_per_stage = divide_if_divisible(
config.n_layer, self.num_pipeline_stages, "layers should be divisible by num_pipeline_stages"
)

self.drop = (
nn.Identity()
if config.embd_pdrop == 0
else Dropout_TP(
config.embd_pdrop,
self.layer_start_id = self.layers_per_stage * self.pipeline_stage_id
self.layer_end_id = self.layers_per_stage * (self.pipeline_stage_id + 1)


if self.is_first_stage:
self.wte = Embedding_TP(
config.vocab_size,
self.embed_dim,
std=self.initializer_range,
tensor_parallel_word_embeddings=self.tensor_parallel_word_embeddings,
use_padding_free_transformer=self._use_padding_free_transformer,
sequence_parallel=self.sequence_parallel,
)
)
self.h = nn.ModuleList(
[
self.layer_class(

self.drop = (
nn.Identity()
if config.embd_pdrop == 0
else Dropout_TP(
config.embd_pdrop,
use_padding_free_transformer=self._use_padding_free_transformer,
sequence_parallel=self.sequence_parallel,
)
)

self.h = nn.ModuleDict(
{
str(i): self.layer_class(
config,
normalization_implementation=self.normalization_implementation,
attention_implementation=self.attention_implementation,
@@ -56,17 +66,19 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None:
sequence_parallel=self.sequence_parallel,
)
for i in range(config.num_hidden_layers)
]
)
self.ln_f = get_normalization_function_TP(
config.normalization_function,
self.embed_dim,
eps=config.layer_norm_epsilon,
normalization_implementation=self.normalization_implementation,
use_padding_free_transformer=self._use_padding_free_transformer,
sequence_parallel=self.sequence_parallel,
}
)

if self.is_last_stage:
self.ln_f = get_normalization_function_TP(
config.normalization_function,
self.embed_dim,
eps=config.layer_norm_epsilon,
normalization_implementation=self.normalization_implementation,
use_padding_free_transformer=self._use_padding_free_transformer,
sequence_parallel=self.sequence_parallel,
)

self.position_embedding_type = PositionEmbeddingType(config.position_embedding_type)
self._setup_positional_encoding()

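A toy illustration of the per-stage partitioning computed above, assuming divide_if_divisible simply performs integer division after checking divisibility (the helper's real implementation lives in dolomite_engine.utils and is not shown in this diff):

def divide_if_divisible(numerator: int, denominator: int, message: str = "") -> int:
    # Assumed semantics: fail loudly if the split is uneven, otherwise return the quotient.
    assert numerator % denominator == 0, message
    return numerator // denominator


n_layer, num_pipeline_stages = 8, 4
layers_per_stage = divide_if_divisible(n_layer, num_pipeline_stages, "layers should be divisible by num_pipeline_stages")

for pipeline_stage_id in range(num_pipeline_stages):
    layer_start_id = layers_per_stage * pipeline_stage_id
    layer_end_id = layers_per_stage * (pipeline_stage_id + 1)
    print(pipeline_stage_id, list(range(layer_start_id, layer_end_id)))
# stage 0 -> [0, 1], stage 1 -> [2, 3], stage 2 -> [4, 5], stage 3 -> [6, 7]

Keying self.h by the global layer index as a string also keeps parameter names aligned with the non-pipeline ModuleList layout, which presumably simplifies loading safetensors checkpoints per stage.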
116 changes: 80 additions & 36 deletions dolomite_engine/hf_models/mixins/moe_TP/main.py
@@ -27,20 +27,23 @@ def forward(
max_seqlen: torch.Tensor | None = None,
output_router_logits: bool | None = None,
) -> tuple | MoeCausalLMOutputWithPast:
input_ids, position_ids, token_type_ids, labels, cu_seqlens, max_seqlen = self.prepare_inputs_for_model(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
position_ids=position_ids,
token_type_ids=token_type_ids,
labels=labels,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
past_key_values=past_key_values,
attention_mask=attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
# Note: with pipeline parallelism, `past_key_values` carries the accumulated aux_loss on every stage except the first

if not self.is_pipeline_parallel_enabled or self.is_first_stage:
input_ids, position_ids, token_type_ids, labels, cu_seqlens, max_seqlen = self.prepare_inputs_for_model(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
position_ids=position_ids,
token_type_ids=token_type_ids,
labels=labels,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
past_key_values=past_key_values,
attention_mask=attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)

transformer_outputs: MoeModelOutputWithPastAndAuxLoss = self.transformer(
input_ids,
past_key_values=past_key_values,
@@ -55,35 +58,76 @@ def forward(
output_router_logits=output_router_logits,
)

lm_logits = self.get_lm_logits(transformer_outputs.last_hidden_state)
if not self.is_pipeline_parallel_enabled or self.is_last_stage:
lm_logits = self.get_lm_logits(transformer_outputs.last_hidden_state)

if self.m_width is not None:
lm_logits = lm_logits / self.m_width

if self.m_width is not None:
lm_logits = lm_logits / self.m_width
if not self.is_pipeline_parallel_enabled:
lm_loss = self.get_autoregressive_language_modeling_loss(lm_logits, labels, cu_seqlens)

lm_loss = self.get_autoregressive_language_modeling_loss(lm_logits, labels, cu_seqlens)
aux_loss = tensor_to_dtensor(
transformer_outputs.aux_loss, device_mesh=self.tp_mesh, current_placement=Replicate()
)
if self.is_pipeline_parallel_enabled and not self.is_first_stage:
# With pipeline parallelism, stages after the first receive the running aux loss in `past_key_values`
# (that is how the pipeline forwards it between stages), so add it to this stage's aux loss
prev_aux_loss = past_key_values
aux_loss = aux_loss + prev_aux_loss

if lm_loss is None:
loss = None
if not self.is_pipeline_parallel_enabled or self.is_last_stage:
if output_parallel_lm_logits:
assert self.tensor_parallel_word_embeddings
else:
if self.tensor_parallel_word_embeddings:
# all gather
lm_logits = tensor_to_dtensor(lm_logits, device_mesh=self.tp_mesh, current_placement=Shard(-1))
lm_logits = dtensor_to_tensor(lm_logits, device_mesh=self.tp_mesh, desired_placement=Replicate())

if not self.is_pipeline_parallel_enabled:
if lm_loss is None:
loss = None
else:
loss = lm_loss + self.router_aux_loss_coef * aux_loss
output = MoeCausalLMOutputWithPast(
loss=loss,
aux_loss=aux_loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
router_logits=transformer_outputs.router_logits,
)
elif self.is_last_stage:
output = (lm_logits, aux_loss)
else:
loss = lm_loss + self.router_aux_loss_coef * aux_loss
output = (transformer_outputs.last_hidden_state, aux_loss)
return output
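A standalone sketch of how the auxiliary loss travels with this forward under pipeline parallelism: each stage after the first receives the running total where past_key_values would normally be, adds its own router loss, and passes the sum along. The numbers below are made up, and the final lm_loss combination is shown only for illustration — in this diff it happens in the non-pipeline branch, while the last stage just returns (lm_logits, aux_loss).

import torch

router_aux_loss_coef = 0.01
per_stage_router_loss = [torch.tensor(0.10), torch.tensor(0.20), torch.tensor(0.30)]  # hypothetical values

carried = None  # the first stage has no incoming aux loss
for stage_id, stage_aux_loss in enumerate(per_stage_router_loss):
    aux_loss = stage_aux_loss
    if stage_id > 0:
        # later stages receive the running total in place of `past_key_values` and accumulate it
        aux_loss = aux_loss + carried
    carried = aux_loss

lm_loss = torch.tensor(2.5)  # hypothetical language-modeling loss from the last stage
loss = lm_loss + router_aux_loss_coef * carried
print(float(loss))  # 2.5 + 0.01 * 0.6 = 2.506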




def get_dummy_input_tensor(self, micro_batch_size: int, sequence_length: int, intermediate_dtype: torch.dtype) -> tuple[torch.Tensor] | torch.Tensor:
dummy_input = super().get_dummy_input_tensor(micro_batch_size, sequence_length, intermediate_dtype)

if output_parallel_lm_logits:
assert self.tensor_parallel_word_embeddings
if self.is_first_stage:
return dummy_input
else:
if self.tensor_parallel_word_embeddings:
# all gather
lm_logits = tensor_to_dtensor(lm_logits, device_mesh=self.tp_mesh, current_placement=Shard(-1))
lm_logits = dtensor_to_tensor(lm_logits, device_mesh=self.tp_mesh, desired_placement=Replicate())

return MoeCausalLMOutputWithPast(
loss=loss,
aux_loss=aux_loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
router_logits=transformer_outputs.router_logits,
)
dummy_aux_loss = torch.tensor(0.0, device=torch.cuda.current_device(), dtype=intermediate_dtype)
if isinstance(dummy_input, tuple):
return dummy_input + (dummy_aux_loss,)
else:
return (dummy_input, dummy_aux_loss)




def get_dummy_output_tensor(self, micro_batch_size: int, sequence_length: int, intermediate_dtype: torch.dtype, output_parallel_lm_logits_if_possible: bool) -> tuple[torch.Tensor] | torch.Tensor:
dummy_output = super().get_dummy_output_tensor(micro_batch_size, sequence_length, intermediate_dtype, output_parallel_lm_logits_if_possible)
dummy_aux_loss = torch.tensor(0., device=torch.cuda.current_device(), dtype=intermediate_dtype)
if isinstance(dummy_output, tuple):
return dummy_output + (dummy_aux_loss,)
else:
return (dummy_output, dummy_aux_loss)
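Both overrides follow the same pattern: take whatever the parent class reports as the inter-stage payload and append a scalar placeholder for the aux loss, so the pipeline presumably knows one extra scalar crosses every stage boundary. A CPU-only sketch of that pattern (the methods above allocate on the current CUDA device instead):

import torch


def append_aux_loss_placeholder(
    dummy: torch.Tensor | tuple[torch.Tensor, ...], dtype: torch.dtype
) -> tuple[torch.Tensor, ...]:
    # Mirror the branching above: extend a tuple, or pair up a lone tensor with the scalar.
    dummy_aux_loss = torch.tensor(0.0, dtype=dtype)
    if isinstance(dummy, tuple):
        return dummy + (dummy_aux_loss,)
    return (dummy, dummy_aux_loss)


hidden = torch.empty(2, 8, 16, dtype=torch.bfloat16)  # hypothetical (batch, sequence, hidden) activation
print([tuple(t.shape) for t in append_aux_loss_placeholder(hidden, torch.bfloat16)])  # [(2, 8, 16), ()]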
