diff --git a/dolomite_engine/hf_models/mixins/dense_TP/main.py b/dolomite_engine/hf_models/mixins/dense_TP/main.py
index 52d4599..3d7059c 100644
--- a/dolomite_engine/hf_models/mixins/dense_TP/main.py
+++ b/dolomite_engine/hf_models/mixins/dense_TP/main.py
@@ -212,7 +212,7 @@ def load_from_safetensors_weights_manager(self, safetensors_weights_manager: Saf
 
     def get_dummy_input_tensor(
         self, micro_batch_size: int, sequence_length: int, intermediate_dtype: torch.dtype
-    ) -> tuple[int]:
+    ) -> tuple[torch.Tensor] | torch.Tensor:
         if self.is_first_stage:
             # 1 is added to sequence length since megatron's dataloader gives an extra token and for good reason
             tensor = torch.empty(
@@ -231,7 +231,7 @@ def get_dummy_output_tensor(
         sequence_length: int,
         intermediate_dtype: torch.dtype,
         output_parallel_lm_logits_if_possible: bool,
-    ) -> tuple[int]:
+    ) -> tuple[torch.Tensor] | torch.Tensor:
         if self.is_last_stage:
             vocab_size = self.config.vocab_size
             if self.tensor_parallel_word_embeddings and output_parallel_lm_logits_if_possible:
@@ -261,7 +261,7 @@ def get_dummy_output_tensor(
 
     def _get_dummy_intermediate_tensor(
         self, micro_batch_size: int, sequence_length: int, intermediate_dtype: torch.dtype
-    ) -> tuple[int]:
+    ) -> tuple[torch.Tensor] | torch.Tensor:
         sharded_sequence_length = (
             divide_if_divisible(sequence_length, ProcessGroupManager.get_tensor_parallel_world_size(), "")
             if self.sequence_parallel
diff --git a/dolomite_engine/hf_models/mixins/moe_TP/base.py b/dolomite_engine/hf_models/mixins/moe_TP/base.py
index fddbe56..7382c44 100644
--- a/dolomite_engine/hf_models/mixins/moe_TP/base.py
+++ b/dolomite_engine/hf_models/mixins/moe_TP/base.py
@@ -1,6 +1,6 @@
 import torch.nn as nn
 
-from ....utils import ProcessGroupManager
+from ....utils import ProcessGroupManager, divide_if_divisible
 from ...config import CommonConfig
 from ...enums import AttentionHeadType, PositionEmbeddingType
 from ...modeling_utils_TP import Dropout_TP, Embedding_TP, get_normalization_function_TP
@@ -26,27 +26,37 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None:
         self.initializer_range = config.initializer_range
         self.head_dim = self.embed_dim // self.num_heads
 
-        self.wte = Embedding_TP(
-            config.vocab_size,
-            self.embed_dim,
-            std=self.initializer_range,
-            tensor_parallel_word_embeddings=self.tensor_parallel_word_embeddings,
-            use_padding_free_transformer=self._use_padding_free_transformer,
-            sequence_parallel=self.sequence_parallel,
+        self.layers_per_stage = divide_if_divisible(
+            config.n_layer, self.num_pipeline_stages, "layers should be divisible by num_pipeline_stages"
         )
 
-        self.drop = (
-            nn.Identity()
-            if config.embd_pdrop == 0
-            else Dropout_TP(
-                config.embd_pdrop,
+        self.layer_start_id = self.layers_per_stage * self.pipeline_stage_id
+        self.layer_end_id = self.layers_per_stage * (self.pipeline_stage_id + 1)
+
+        if self.is_first_stage:
+            self.wte = Embedding_TP(
+                config.vocab_size,
+                self.embed_dim,
+                std=self.initializer_range,
+                tensor_parallel_word_embeddings=self.tensor_parallel_word_embeddings,
                 use_padding_free_transformer=self._use_padding_free_transformer,
                 sequence_parallel=self.sequence_parallel,
             )
-        )
-        self.h = nn.ModuleList(
-            [
-                self.layer_class(
+
+            self.drop = (
+                nn.Identity()
+                if config.embd_pdrop == 0
+                else Dropout_TP(
+                    config.embd_pdrop,
+                    use_padding_free_transformer=self._use_padding_free_transformer,
+                    sequence_parallel=self.sequence_parallel,
+                )
+            )
+
+        self.h = nn.ModuleDict(
+            {
+                str(i): self.layer_class(
                     config,
                     normalization_implementation=self.normalization_implementation,
                     attention_implementation=self.attention_implementation,
@@ -56,17 +66,19 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None:
                     sequence_parallel=self.sequence_parallel,
                 )
                 for i in range(config.num_hidden_layers)
-            ]
-        )
-        self.ln_f = get_normalization_function_TP(
-            config.normalization_function,
-            self.embed_dim,
-            eps=config.layer_norm_epsilon,
-            normalization_implementation=self.normalization_implementation,
-            use_padding_free_transformer=self._use_padding_free_transformer,
-            sequence_parallel=self.sequence_parallel,
+            }
         )
 
+        if self.is_last_stage:
+            self.ln_f = get_normalization_function_TP(
+                config.normalization_function,
+                self.embed_dim,
+                eps=config.layer_norm_epsilon,
+                normalization_implementation=self.normalization_implementation,
+                use_padding_free_transformer=self._use_padding_free_transformer,
+                sequence_parallel=self.sequence_parallel,
+            )
+
         self.position_embedding_type = PositionEmbeddingType(config.position_embedding_type)
 
         self._setup_positional_encoding()
diff --git a/dolomite_engine/hf_models/mixins/moe_TP/main.py b/dolomite_engine/hf_models/mixins/moe_TP/main.py
index 71d926e..6be305d 100644
--- a/dolomite_engine/hf_models/mixins/moe_TP/main.py
+++ b/dolomite_engine/hf_models/mixins/moe_TP/main.py
@@ -27,20 +27,23 @@ def forward(
         max_seqlen: torch.Tensor | None = None,
         output_router_logits: bool | None = None,
     ) -> tuple | MoeCausalLMOutputWithPast:
-        input_ids, position_ids, token_type_ids, labels, cu_seqlens, max_seqlen = self.prepare_inputs_for_model(
-            input_ids=input_ids,
-            inputs_embeds=inputs_embeds,
-            position_ids=position_ids,
-            token_type_ids=token_type_ids,
-            labels=labels,
-            cu_seqlens=cu_seqlens,
-            max_seqlen=max_seqlen,
-            past_key_values=past_key_values,
-            attention_mask=attention_mask,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-        )
+        # Note: on non-first pipeline stages, `past_key_values` carries the running aux_loss
+        if not self.is_pipeline_parallel_enabled or self.is_first_stage:
+            input_ids, position_ids, token_type_ids, labels, cu_seqlens, max_seqlen = self.prepare_inputs_for_model(
+                input_ids=input_ids,
+                inputs_embeds=inputs_embeds,
+                position_ids=position_ids,
+                token_type_ids=token_type_ids,
+                labels=labels,
+                cu_seqlens=cu_seqlens,
+                max_seqlen=max_seqlen,
+                past_key_values=past_key_values,
+                attention_mask=attention_mask,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+            )
+
         transformer_outputs: MoeModelOutputWithPastAndAuxLoss = self.transformer(
             input_ids,
             past_key_values=past_key_values,
@@ -55,35 +58,76 @@ def forward(
             output_router_logits=output_router_logits,
         )
 
-        lm_logits = self.get_lm_logits(transformer_outputs.last_hidden_state)
+        if not self.is_pipeline_parallel_enabled or self.is_last_stage:
+            lm_logits = self.get_lm_logits(transformer_outputs.last_hidden_state)
+
+            if self.m_width is not None:
+                lm_logits = lm_logits / self.m_width
 
-        if self.m_width is not None:
-            lm_logits = lm_logits / self.m_width
+        if not self.is_pipeline_parallel_enabled:
+            lm_loss = self.get_autoregressive_language_modeling_loss(lm_logits, labels, cu_seqlens)
 
-        lm_loss = self.get_autoregressive_language_modeling_loss(lm_logits, labels, cu_seqlens)
         aux_loss = tensor_to_dtensor(
             transformer_outputs.aux_loss, device_mesh=self.tp_mesh, current_placement=Replicate()
         )
 
+        if self.is_pipeline_parallel_enabled and not self.is_first_stage:
+            # non-first stages receive the accumulated aux_loss through `past_key_values`
+            # (that is how the pipeline passes it between stages), so add the local aux_loss to it
+            prev_aux_loss = past_key_values
+            aux_loss = aux_loss + prev_aux_loss
 
-        if lm_loss is None:
-            loss = None
+        if not self.is_pipeline_parallel_enabled or self.is_last_stage:
+            if output_parallel_lm_logits:
+                assert self.tensor_parallel_word_embeddings
+            else:
+                if self.tensor_parallel_word_embeddings:
+                    # all gather
+                    lm_logits = tensor_to_dtensor(lm_logits, device_mesh=self.tp_mesh, current_placement=Shard(-1))
+                    lm_logits = dtensor_to_tensor(lm_logits, device_mesh=self.tp_mesh, desired_placement=Replicate())
+
+        if not self.is_pipeline_parallel_enabled:
+            if lm_loss is None:
+                loss = None
+            else:
+                loss = lm_loss + self.router_aux_loss_coef * aux_loss
+
+            output = MoeCausalLMOutputWithPast(
+                loss=loss,
+                aux_loss=aux_loss,
+                logits=lm_logits,
+                past_key_values=transformer_outputs.past_key_values,
+                hidden_states=transformer_outputs.hidden_states,
+                attentions=transformer_outputs.attentions,
+                router_logits=transformer_outputs.router_logits,
+            )
+        elif self.is_last_stage:
+            output = (lm_logits, aux_loss)
         else:
-            loss = lm_loss + self.router_aux_loss_coef * aux_loss
+            output = (transformer_outputs.last_hidden_state, aux_loss)
+
+        return output
+
+    def get_dummy_input_tensor(
+        self, micro_batch_size: int, sequence_length: int, intermediate_dtype: torch.dtype
+    ) -> tuple[torch.Tensor] | torch.Tensor:
+        dummy_input = super().get_dummy_input_tensor(micro_batch_size, sequence_length, intermediate_dtype)
 
-        if output_parallel_lm_logits:
-            assert self.tensor_parallel_word_embeddings
+        if self.is_first_stage:
+            return dummy_input
         else:
-            if self.tensor_parallel_word_embeddings:
-                # all gather
-                lm_logits = tensor_to_dtensor(lm_logits, device_mesh=self.tp_mesh, current_placement=Shard(-1))
-                lm_logits = dtensor_to_tensor(lm_logits, device_mesh=self.tp_mesh, desired_placement=Replicate())
-
-        return MoeCausalLMOutputWithPast(
-            loss=loss,
-            aux_loss=aux_loss,
-            logits=lm_logits,
-            past_key_values=transformer_outputs.past_key_values,
-            hidden_states=transformer_outputs.hidden_states,
-            attentions=transformer_outputs.attentions,
-            router_logits=transformer_outputs.router_logits,
-        )
+            aux_loss_dummy = torch.tensor(0.0, device=torch.cuda.current_device(), dtype=intermediate_dtype)
+            if isinstance(dummy_input, tuple):
+                return dummy_input + (aux_loss_dummy,)
+            else:
+                return (dummy_input, aux_loss_dummy)
+
+    def get_dummy_output_tensor(
+        self,
+        micro_batch_size: int,
+        sequence_length: int,
+        intermediate_dtype: torch.dtype,
+        output_parallel_lm_logits_if_possible: bool,
+    ) -> tuple[torch.Tensor] | torch.Tensor:
+        dummy_output = super().get_dummy_output_tensor(
+            micro_batch_size, sequence_length, intermediate_dtype, output_parallel_lm_logits_if_possible
+        )
+
+        dummy_aux_loss = torch.tensor(0.0, device=torch.cuda.current_device(), dtype=intermediate_dtype)
+        if isinstance(dummy_output, tuple):
+            return dummy_output + (dummy_aux_loss,)
+        else:
+            return (dummy_output, dummy_aux_loss)
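
For readers unfamiliar with the partitioning scheme that the new `layers_per_stage` / `layer_start_id` / `layer_end_id` fields support, below is a minimal, self-contained sketch of contiguous layer partitioning with an `nn.ModuleDict` keyed by the global layer index. `ToyBlock`, `PipelineStageStack`, and the local `divide_if_divisible` helper are illustrative stand-ins rather than the classes touched by this patch, and the sketch assumes each stage instantiates only its own slice of layers.

```python
# Minimal sketch of contiguous layer partitioning for pipeline parallelism.
# All names here are hypothetical; only the partitioning arithmetic mirrors the patch above.
import torch
import torch.nn as nn


def divide_if_divisible(numerator: int, denominator: int, msg: str) -> int:
    # integer division that asserts exact divisibility, like the helper imported in the patch
    assert numerator % denominator == 0, msg
    return numerator // denominator


class ToyBlock(nn.Module):
    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.proj(x)


class PipelineStageStack(nn.Module):
    def __init__(self, n_layer: int, hidden_size: int, num_stages: int, stage_id: int) -> None:
        super().__init__()
        layers_per_stage = divide_if_divisible(
            n_layer, num_stages, "layers should be divisible by num_pipeline_stages"
        )
        self.layer_start_id = layers_per_stage * stage_id
        self.layer_end_id = layers_per_stage * (stage_id + 1)

        # keyed by the *global* layer index so parameter names stay stable across stage counts
        self.h = nn.ModuleDict(
            {str(i): ToyBlock(hidden_size) for i in range(self.layer_start_id, self.layer_end_id)}
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for i in range(self.layer_start_id, self.layer_end_id):
            hidden_states = self.h[str(i)](hidden_states)
        return hidden_states


if __name__ == "__main__":
    stage = PipelineStageStack(n_layer=8, hidden_size=16, num_stages=4, stage_id=2)
    print(sorted(stage.h.keys()))       # ['4', '5'] -> global indices owned by stage 2
    print(stage(torch.randn(2, 3, 16)).shape)  # torch.Size([2, 3, 16])
```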
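The `forward` changes thread the MoE auxiliary loss through the pipeline by returning `(hidden_states_or_logits, aux_loss)` from intermediate stages and reading the incoming value out of the `past_key_values` slot. The toy loop below is not the engine's pipeline scheduler; under that assumption, a hypothetical `stage_forward` adds its local router loss to the running total and passes it forward, which is the accumulation pattern the hunk implements.

```python
# Toy illustration of a scalar aux loss riding along the pipeline: each stage returns
# (activation, running_aux_loss) and the next stage receives the running loss as its second input.
import torch


def stage_forward(stage_id: int, activation: torch.Tensor, incoming_aux_loss: torch.Tensor | None):
    # pretend each stage's MoE routers contribute a local load-balancing loss
    local_aux_loss = torch.tensor(0.01 * (stage_id + 1))
    running_aux_loss = local_aux_loss if incoming_aux_loss is None else incoming_aux_loss + local_aux_loss

    activation = activation * 2  # stand-in for the stage's transformer layers
    return activation, running_aux_loss


def run_pipeline(num_stages: int = 4) -> None:
    activation = torch.ones(2, 3)
    aux_loss = None
    for stage_id in range(num_stages):
        activation, aux_loss = stage_forward(stage_id, activation, aux_loss)
    # on the last stage the accumulated aux loss is scaled and added to the LM loss
    print(round(float(aux_loss), 4))  # ~0.1 = 0.01 + 0.02 + 0.03 + 0.04


if __name__ == "__main__":
    run_pipeline()
```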
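The new `get_dummy_input_tensor` / `get_dummy_output_tensor` overrides append a scalar placeholder so the pipeline runtime can size the extra aux-loss buffer on non-first stages. A rough sketch of that pattern follows, with made-up base shapes and CPU tensors standing in for the real `dense_TP` dummies:

```python
# Sketch of the "append a scalar aux-loss placeholder" pattern used by the overrides above.
# The base shapes and the hidden size 64 are invented for illustration only.
import torch


def base_dummy_input(micro_batch_size: int, sequence_length: int, dtype: torch.dtype, is_first_stage: bool):
    if is_first_stage:
        # the first stage consumes token ids (sequence_length + 1 mirrors the dataloader's extra token)
        return torch.empty(micro_batch_size, sequence_length + 1, dtype=torch.long)
    # later stages consume hidden states of some intermediate width
    return torch.empty(micro_batch_size, sequence_length, 64, dtype=dtype)


def moe_dummy_input(micro_batch_size: int, sequence_length: int, dtype: torch.dtype, is_first_stage: bool):
    dummy = base_dummy_input(micro_batch_size, sequence_length, dtype, is_first_stage)
    if is_first_stage:
        return dummy

    # non-first stages also receive the running aux loss, so the dummy must include a scalar
    aux_loss_dummy = torch.tensor(0.0, dtype=dtype)
    if isinstance(dummy, tuple):
        return dummy + (aux_loss_dummy,)
    return (dummy, aux_loss_dummy)


if __name__ == "__main__":
    first = moe_dummy_input(2, 8, torch.bfloat16, is_first_stage=True)
    later = moe_dummy_input(2, 8, torch.bfloat16, is_first_stage=False)
    print(first.shape)                 # token-id tensor only: torch.Size([2, 9])
    print([t.shape for t in later])    # hidden states plus scalar aux loss
```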