From 00b572eca997a75cedd90196dfcc062b80196dd4 Mon Sep 17 00:00:00 2001 From: zhangyuqin1998 Date: Fri, 7 Mar 2025 17:06:15 +0800 Subject: [PATCH 1/2] [Distribution] Support DualPipeV for DeepSeek --- llm/config/deepseek-v3/pretrain_argument.json | 13 +++-- .../transformers/deepseek_v2/configuration.py | 2 + .../transformers/deepseek_v2/modeling.py | 27 +++++---- .../transformers/deepseek_v2/modeling_pp.py | 57 +++++++++++++++++-- paddlenlp/transformers/gpt/modeling.py | 2 + paddlenlp/transformers/gpt/modeling_pp.py | 5 +- paddlenlp/transformers/moe_layer.py | 2 +- 7 files changed, 84 insertions(+), 24 deletions(-) diff --git a/llm/config/deepseek-v3/pretrain_argument.json b/llm/config/deepseek-v3/pretrain_argument.json index 3c312379acb9..c7ff267cb3f9 100644 --- a/llm/config/deepseek-v3/pretrain_argument.json +++ b/llm/config/deepseek-v3/pretrain_argument.json @@ -4,11 +4,11 @@ "input_dir": "./data", "output_dir": "./checkpoints/pretrain_ckpts", "per_device_train_batch_size": 1, - "gradient_accumulation_steps": 4, + "gradient_accumulation_steps": 10, "per_device_eval_batch_size": 1, "tensor_parallel_degree": 1, - "pipeline_parallel_degree": 1, - "sharding_parallel_degree": 8, + "pipeline_parallel_degree": 2, + "sharding_parallel_degree": 4, "expert_parallel_degree": 4, "sharding": "stage1", "virtual_pp_degree": 1, @@ -33,9 +33,10 @@ "do_eval": true, "do_predict": false, "disable_tqdm": true, - "recompute": true, + "recompute": false, "distributed_dataloader": 1, "recompute_granularity": "full", "unified_checkpoint": true, - "save_total_limit": 2 -} \ No newline at end of file + "save_total_limit": 2, + "pipeline_parallel_config": "use_dualpipev" +} diff --git a/paddlenlp/transformers/deepseek_v2/configuration.py b/paddlenlp/transformers/deepseek_v2/configuration.py index d21afc20780f..f7c2929a76d0 100644 --- a/paddlenlp/transformers/deepseek_v2/configuration.py +++ b/paddlenlp/transformers/deepseek_v2/configuration.py @@ -179,6 +179,7 @@ def __init__( attention_dropout=0.0, speculate_model_type=False, using_flex_token=False, + use_dualpipev=False, **kwargs, ): self.vocab_size = vocab_size @@ -227,6 +228,7 @@ def __init__( self.speculate_model_type = speculate_model_type self.use_fp8 = False self.using_flex_token = using_flex_token + self.use_dualpipev = use_dualpipev super().__init__( pad_token_id=pad_token_id, diff --git a/paddlenlp/transformers/deepseek_v2/modeling.py b/paddlenlp/transformers/deepseek_v2/modeling.py index a166c088b028..e9b769a4bc7c 100644 --- a/paddlenlp/transformers/deepseek_v2/modeling.py +++ b/paddlenlp/transformers/deepseek_v2/modeling.py @@ -130,7 +130,7 @@ def assign_kv_heads(num_kv_heads: int, num_gpus: int): return assignment_list -def parallel_matmul(x: Tensor, y: Tensor, tensor_parallel_output=True): +def parallel_matmul(x: Tensor, y: Tensor, transpose_y=False, tensor_parallel_output=True): is_fleet_init = True tensor_parallel_degree = 1 try: @@ -148,7 +148,7 @@ def parallel_matmul(x: Tensor, y: Tensor, tensor_parallel_output=True): if is_fleet_init and tensor_parallel_degree > 1 and y_is_distributed: # if not running under distributed.launch, it will raise AttributeError: 'Fleet' object has no attribute '_hcg' input_parallel = paddle.distributed.collective._c_identity(x, group=model_parallel_group) - logits = paddle.matmul(input_parallel, y, transpose_y=False) + logits = paddle.matmul(input_parallel, y, transpose_y=transpose_y) if tensor_parallel_output: return logits @@ -156,7 +156,7 @@ def parallel_matmul(x: Tensor, y: Tensor, tensor_parallel_output=True): 
return paddle.distributed.collective._c_concat(logits, group=model_parallel_group) else: - logits = paddle.matmul(x, y, transpose_y=False) + logits = paddle.matmul(x, y, transpose_y=transpose_y) return logits @@ -1892,7 +1892,7 @@ def add_loss(main_loss, loss): class DeepseekV2LMHead(nn.Layer): - def __init__(self, config: DeepseekV2Config): + def __init__(self, config: DeepseekV2Config, embedding_weight=None): super(DeepseekV2LMHead, self).__init__() self.config = config @@ -1906,11 +1906,16 @@ def __init__(self, config: DeepseekV2Config): else: vocab_size = config.vocab_size - self.weight = self.create_parameter( - shape=[config.hidden_size, vocab_size], - dtype=paddle.get_default_dtype(), - default_initializer=nn.initializer.XavierNormal(1.0), - ) + if embedding_weight is not None: + self.transpose_y = True + self.weight = embedding_weight + else: + self.transpose_y = False + self.weight = self.create_parameter( + shape=[config.hidden_size, vocab_size], + dtype=paddle.get_default_dtype(), + default_initializer=nn.initializer.XavierNormal(1.0), + ) # Must set distributed attr for Tensor Parallel ! self.weight.is_distributed = True if (vocab_size != config.vocab_size) else False @@ -1922,7 +1927,9 @@ def forward(self, hidden_states, tensor_parallel_output=None): if tensor_parallel_output is None: tensor_parallel_output = self.config.tensor_parallel_output - logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output) + logits = parallel_matmul( + hidden_states, self.weight, transpose_y=self.transpose_y, tensor_parallel_output=tensor_parallel_output + ) return logits def extra_repr(self): diff --git a/paddlenlp/transformers/deepseek_v2/modeling_pp.py b/paddlenlp/transformers/deepseek_v2/modeling_pp.py index 091d44c8c8f2..70c288cdb1a4 100644 --- a/paddlenlp/transformers/deepseek_v2/modeling_pp.py +++ b/paddlenlp/transformers/deepseek_v2/modeling_pp.py @@ -39,13 +39,18 @@ DeepseekV2RMSNorm, ) +try: + from paddle.distributed.fleet.meta_parallel import LocalSharedLayerDesc +except: + LocalSharedLayerDesc = None + __all__ = [ "DeepseekV2ForCausalLMPipe", ] def parse_args(args): - if isinstance(args, tuple): + if isinstance(args, (tuple, list)): if len(args) == 4: hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = args @@ -55,6 +60,9 @@ def parse_args(args): elif len(args) == 2: hidden_states, attention_mask = args attn_mask_startend_row_indices, position_ids = None, None + else: # len(args) == 1: + hidden_states = args[0] + attention_mask, attn_mask_startend_row_indices, position_ids = None, None, None else: hidden_states = args attention_mask, attn_mask_startend_row_indices, position_ids = None, None, None @@ -321,8 +329,8 @@ def forward(self, args): class DeepseekV2LMHeadPipe(DeepseekV2LMHead): - def __init__(self, config): - super(DeepseekV2LMHeadPipe, self).__init__(config) + def __init__(self, config, embedding_weight=None): + super(DeepseekV2LMHeadPipe, self).__init__(config, embedding_weight=embedding_weight) @property def embedding_weight(self): @@ -406,6 +414,10 @@ def __init__(self, config: DeepseekV2Config): assert len(self.no_recompute_layers) == 0, "for pp with full recompute, no_recompute_layers is not support" virtual_pp_degree = getattr(self.config, "virtual_pp_degree", 1) + use_dualpipev = getattr(self.config, "use_dualpipev", False) + if use_dualpipev: + assert LocalSharedLayerDesc is not None, "LocalSharedLayerDesc is None, please update your paddle." 
+ shared_class = LocalSharedLayerDesc if use_dualpipev else SharedLayerDesc def get_hcg(): return fleet.get_hybrid_communicate_group() @@ -420,7 +432,7 @@ def get_hcg(): if config.tie_word_embeddings: self.add_sequential_layer( - SharedLayerDesc( + shared_class( "DeepseekV2_shared_weight", DeepseekV2EmbeddingPipe, shared_weight_attr="embedding_weight", @@ -453,12 +465,11 @@ def get_hcg(): if config.tie_word_embeddings: self.add_sequential_layer( - SharedLayerDesc( + shared_class( "DeepseekV2_shared_weight", DeepseekV2LMHeadPipe, shared_weight_attr="embedding_weight", config=config, - **{"transpose_y": True}, ), "lm_head", ) @@ -489,6 +500,7 @@ def get_hcg(): "partition": False, }, num_virtual_pipeline_stages=virtual_pp_degree, + use_dualpipev=use_dualpipev, ) # You should call init here, since there is a diamond inheritance problem self.apply(self._init_weights) @@ -497,3 +509,36 @@ def get_hcg(): def get_loss_fn(self, config): return DeepseekV2PretrainingCriterionPipe(config) + + def overlapped_forward_backward( + self, + module0, # the module of the forward chunk + inputs0, + criterion0, + labels0, + module1, # the module of the backward chunk, maybe not used + loss1, + outputs1, + output_grads1, + scaler, + ): + outputs0 = module0(inputs0) + outputs0 = [outputs0] if isinstance(outputs0, paddle.Tensor) else outputs0 + + if labels0 is not None: + loss0 = criterion0(outputs0, labels0) + else: + loss0 = None + + if loss1 is not None: + if scaler: + paddle.autograd.backward(scaler.scale(loss1)) + else: + paddle.autograd.backward(loss1) + else: + paddle.autograd.backward( + tensors=outputs1, + grad_tensors=output_grads1, + ) + + return outputs0, loss0 diff --git a/paddlenlp/transformers/gpt/modeling.py b/paddlenlp/transformers/gpt/modeling.py index 16cf34aaa24d..2435d793bd10 100644 --- a/paddlenlp/transformers/gpt/modeling.py +++ b/paddlenlp/transformers/gpt/modeling.py @@ -1350,6 +1350,8 @@ def forward(self, prediction_scores, masked_lm_labels, loss_mask=None): Tensor: The pretraining loss. Its data type should be float32 and its shape is [1]. 
""" + if isinstance(prediction_scores, list): + prediction_scores = prediction_scores[0] with paddle.amp.auto_cast(False): masked_lm_loss = self.loss_func(prediction_scores.astype("float32"), masked_lm_labels.unsqueeze(2)) # skip ignore_index which loss == 0 diff --git a/paddlenlp/transformers/gpt/modeling_pp.py b/paddlenlp/transformers/gpt/modeling_pp.py index d6b0dcc7f73d..c9d59a1c63f0 100644 --- a/paddlenlp/transformers/gpt/modeling_pp.py +++ b/paddlenlp/transformers/gpt/modeling_pp.py @@ -60,12 +60,15 @@ def get_attr(layer, name): def parse_args(args): - if isinstance(args, tuple): + if isinstance(args, (tuple, list)): if len(args) == 3: hidden_states, attention_mask, position_ids = args elif len(args) == 2: hidden_states, attention_mask = args position_ids = None + else: + hidden_states = args[0] + attention_mask, position_ids = None, None else: hidden_states = args attention_mask, position_ids = None, None diff --git a/paddlenlp/transformers/moe_layer.py b/paddlenlp/transformers/moe_layer.py index 63d3b7330219..005e7f98a7fc 100644 --- a/paddlenlp/transformers/moe_layer.py +++ b/paddlenlp/transformers/moe_layer.py @@ -305,7 +305,7 @@ def expert_forward(self, dispatched_input, tokens_per_expert): chunks = paddle.split(dispatched_input, num_or_sections=tokens_per_expert, axis=0) for chunk, expert in zip(chunks, self.experts): chunk = chunk.contiguous() - assert chunk.shape[0] != 0, "Cannot dispatch empty input" + # assert chunk.shape[0] != 0, "Cannot dispatch empty input" outputs += [expert(chunk)] return paddle.concat(outputs, axis=0) From c38b4b498774fec13b104b239c3c1ebbdbfc32c3 Mon Sep 17 00:00:00 2001 From: zhangyuqin1998 Date: Sun, 9 Mar 2025 12:16:00 +0800 Subject: [PATCH 2/2] fix overlap --- .../transformers/deepseek_v2/modeling.py | 108 ++++++++++++++++-- .../transformers/deepseek_v2/modeling_pp.py | 5 +- paddlenlp/transformers/moe_layer.py | 21 +++- 3 files changed, 118 insertions(+), 16 deletions(-) diff --git a/paddlenlp/transformers/deepseek_v2/modeling.py b/paddlenlp/transformers/deepseek_v2/modeling.py index e9b769a4bc7c..52c907bac0f4 100644 --- a/paddlenlp/transformers/deepseek_v2/modeling.py +++ b/paddlenlp/transformers/deepseek_v2/modeling.py @@ -826,6 +826,10 @@ def __init__(self, config: DeepseekV2Config): def forward(self, hidden_states): final_hidden_states, l_aux, l_zloss = super().forward(hidden_states) + final_hidden_states = self.auxilibaryloss_and_shared_expert_compute(hidden_states, final_hidden_states, l_aux) + return final_hidden_states + + def auxilibaryloss_and_shared_expert_compute(self, hidden_states, final_hidden_states, l_aux): if self.training and self.alpha > 0.0: l_aux = l_aux * self.alpha final_hidden_states = AddAuxiliaryLoss.apply(final_hidden_states, l_aux) @@ -1145,6 +1149,48 @@ def __init__(self, config: DeepseekV2Config, layer_idx: int, layerwise_recompute self.input_layernorm = DeepseekV2RMSNorm(config) self.post_attention_layernorm = DeepseekV2RMSNorm(config) + def self_attn_and_gate_compute( + self, + hidden_states: paddle.Tensor, + position_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = False, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + use_cache: Optional[bool] = False, + attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, + **kwargs, + ): + hidden_states, residual = self.self_attn_compute( + hidden_states, + position_ids=position_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + 
past_key_value=past_key_value, + use_cache=use_cache, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + **kwargs, + ) + probs, routing_map, l_aux, l_zloss = self.mlp.gate_compute(hidden_states) + return probs, routing_map, l_aux, l_zloss + + def auxilibaryloss_and_shared_expert_compute(self, residual, hidden_states, expert_output, l_aux): + hidden_states = self.mlp.auxilibaryloss_and_shared_expert_compute(hidden_states, expert_output, l_aux) + hidden_states = residual + hidden_states + + def post_process_output(self, hidden_states, output_attentions, use_cache, self_attn_weights, present_key_value): + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + if type(outputs) is tuple and len(outputs) == 1: + outputs = outputs[0] + + return outputs + def forward( self, hidden_states: paddle.Tensor, @@ -1170,10 +1216,6 @@ def forward( (see `past_key_values`). past_key_value (`Tuple(paddle.Tensor)`, *optional*): cached past key and value projection states """ - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -1216,18 +1258,60 @@ def forward( hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - outputs = (hidden_states,) + return self.post_process_output( + hidden_states, output_attentions, use_cache, self_attn_weights, present_key_value + ) - if output_attentions: - outputs += (self_attn_weights,) + def self_attn_compute( + self, + hidden_states: paddle.Tensor, + position_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = False, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + use_cache: Optional[bool] = False, + attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, + **kwargs + ): + residual = hidden_states - if use_cache: - outputs += (present_key_value,) + hidden_states = self.input_layernorm(hidden_states) - if type(outputs) is tuple and len(outputs) == 1: - outputs = outputs[0] + # Self Attention + has_gradient = not hidden_states.stop_gradient + if ( + self.enable_recompute + and self.layerwise_recompute + and has_gradient + and self.recompute_granularity == "full_attn" + ): + hidden_states, self_attn_weights, present_key_value = recompute( + self.self_attn, + hidden_states=hidden_states, + position_ids=position_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + past_key_value=past_key_value, + use_cache=use_cache, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + **kwargs, + ) + else: + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + position_ids=position_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + past_key_value=past_key_value, + use_cache=use_cache, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + **kwargs, + ) + hidden_states = residual + hidden_states - return outputs + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + return hidden_states, residual class DeepseekV2MTPLayer(DeepseekV2DecoderLayer): diff --git a/paddlenlp/transformers/deepseek_v2/modeling_pp.py b/paddlenlp/transformers/deepseek_v2/modeling_pp.py index 70c288cdb1a4..dd917b93c946 100644 --- 
a/paddlenlp/transformers/deepseek_v2/modeling_pp.py +++ b/paddlenlp/transformers/deepseek_v2/modeling_pp.py @@ -522,7 +522,10 @@ def overlapped_forward_backward( output_grads1, scaler, ): - outputs0 = module0(inputs0) + outputs0 = inputs0 + for layer in module0: + outputs0 = layer(outputs0) + outputs0 = [outputs0] if isinstance(outputs0, paddle.Tensor) else outputs0 if labels0 is not None: diff --git a/paddlenlp/transformers/moe_layer.py b/paddlenlp/transformers/moe_layer.py index 005e7f98a7fc..dfd61cf1a666 100644 --- a/paddlenlp/transformers/moe_layer.py +++ b/paddlenlp/transformers/moe_layer.py @@ -311,12 +311,27 @@ def expert_forward(self, dispatched_input, tokens_per_expert): return paddle.concat(outputs, axis=0) def forward(self, hidden_states: paddle.Tensor): + probs, routing_map, l_aux, l_zloss = self.gate_compute(hidden_states) + dispatched_input, tokens_per_expert = self.dispatch_comm(hidden_states, probs, routing_map) + expert_output = self.mlp_compute(dispatched_input, tokens_per_expert) + output = self.combine_comm(expert_output) + return output, l_aux, l_zloss + + def gate_compute(self, hidden_states): _, _, d_model = hidden_states.shape # reshaped_input = hidden_states.reshape([-1, d_model]) probs, routing_map, l_aux, l_zloss = self.router(hidden_states) - (dispatched_input, tokens_per_expert) = self.token_dispatcher.token_permutation( + return probs, routing_map, l_aux, l_zloss + + def dispatch_comm(self, hidden_states, probs, routing_map): + dispatched_input, tokens_per_expert = self.token_dispatcher.token_permutation( hidden_states, probs, routing_map ) - expert_output = self.expert_forward(dispatched_input, tokens_per_expert) + return dispatched_input, tokens_per_expert + + def mlp_compute(self, dispatched_input, tokens_per_expert): + return self.expert_forward(dispatched_input, tokens_per_expert) + + def combine_comm(self, expert_output): output, _ = self.token_dispatcher.token_unpermutation(expert_output, None) - return output, l_aux, l_zloss + return output
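
The core of the DualPipeV support in this patch is the overlapped_forward_backward hook on DeepseekV2ForCausalLMPipe: the scheduler hands it the layers of one virtual chunk to run forward (plus an optional criterion and labels), together with either the loss or the outputs/output grads of another chunk to run backward, so both directions are interleaved inside a single call. Below is a minimal, framework-free sketch of that control flow; ToyLayer, toy_criterion, and run_backward1 are illustrative stand-ins, not Paddle or PaddleNLP APIs.

# Illustrative sketch only: mirrors the control flow of the
# overlapped_forward_backward hook added in this patch, using toy
# stand-ins (ToyLayer, toy_criterion, run_backward1) that are NOT part
# of Paddle or PaddleNLP.

class ToyLayer:
    """Stand-in for one pipeline layer: scale the input by a weight."""

    def __init__(self, weight):
        self.weight = weight

    def __call__(self, x):
        return x * self.weight


def toy_criterion(outputs, labels):
    """Stand-in loss: squared error between the single output and the label."""
    return (outputs[0] - labels) ** 2


def overlapped_forward_backward(module0, inputs0, criterion0, labels0,
                                run_backward1):
    """Forward the forward chunk layer by layer, then trigger the other chunk's backward.

    In the real hook, run_backward1 corresponds to calling
    paddle.autograd.backward on either loss1 (last stage, optionally
    scaled) or on outputs1 with output_grads1 (intermediate stages).
    """
    outputs0 = inputs0
    for layer in module0:                  # forward of the forward chunk
        outputs0 = layer(outputs0)
    outputs0 = [outputs0]                  # normalize to a list, like the hook

    loss0 = criterion0(outputs0, labels0) if labels0 is not None else None

    run_backward1()                        # backward of the backward chunk
    return outputs0, loss0


if __name__ == "__main__":
    forward_chunk = [ToyLayer(2.0), ToyLayer(0.5)]
    outs, loss = overlapped_forward_backward(
        forward_chunk,
        inputs0=3.0,
        criterion0=toy_criterion,
        labels0=3.0,
        run_backward1=lambda: print("chunk-1 backward would run here"),
    )
    print(outs, loss)  # [3.0] 0.0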
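
The second commit also splits MoELayer.forward into four named stages (gate_compute, dispatch_comm, mlp_compute, combine_comm), which is what lets a DualPipeV-style scheduler overlap the all-to-all communication stages of one micro-batch with the compute stages of another. The following is a minimal staging sketch under that assumption; ToyMoE and its helpers are hypothetical stand-ins, not the PaddleNLP MoELayer or its token dispatcher.

# Illustrative staging sketch: the four-stage MoE forward the patch
# factors out (gate -> dispatch -> expert compute -> combine). ToyMoE is
# a stand-in, NOT the PaddleNLP MoELayer or token dispatcher.

class ToyMoE:
    def __init__(self, num_experts):
        # each "expert" just adds a distinct constant so routing is visible
        self.experts = [lambda xs, k=k: [x + k for x in xs] for k in range(num_experts)]

    def gate_compute(self, tokens):
        # route token i to expert (i % num_experts); no aux loss in the toy
        return [i % len(self.experts) for i in range(len(tokens))]

    def dispatch_comm(self, tokens, routing):
        # stand-in for the all-to-all dispatch: group tokens per expert
        buckets = [[] for _ in self.experts]
        for tok, e in zip(tokens, routing):
            buckets[e].append(tok)
        return buckets

    def mlp_compute(self, buckets):
        # per-expert computation on the tokens dispatched to each expert
        return [expert(bucket) for expert, bucket in zip(self.experts, buckets)]

    def combine_comm(self, expert_outputs, routing):
        # stand-in for the all-to-all combine: restore the original token order
        iters = [iter(outs) for outs in expert_outputs]
        return [next(iters[e]) for e in routing]

    def forward(self, tokens):
        routing = self.gate_compute(tokens)
        buckets = self.dispatch_comm(tokens, routing)
        expert_outputs = self.mlp_compute(buckets)
        return self.combine_comm(expert_outputs, routing)


if __name__ == "__main__":
    moe = ToyMoE(num_experts=2)
    print(moe.forward([10, 20, 30, 40]))  # [10, 21, 30, 41]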