Support DualPipeV for DeepSeek #10030

Open · wants to merge 2 commits into base: develop
13 changes: 7 additions & 6 deletions llm/config/deepseek-v3/pretrain_argument.json
@@ -4,11 +4,11 @@
     "input_dir": "./data",
     "output_dir": "./checkpoints/pretrain_ckpts",
     "per_device_train_batch_size": 1,
-    "gradient_accumulation_steps": 4,
+    "gradient_accumulation_steps": 10,
     "per_device_eval_batch_size": 1,
     "tensor_parallel_degree": 1,
-    "pipeline_parallel_degree": 1,
-    "sharding_parallel_degree": 8,
+    "pipeline_parallel_degree": 2,
+    "sharding_parallel_degree": 4,
     "expert_parallel_degree": 4,
     "sharding": "stage1",
     "virtual_pp_degree": 1,
@@ -33,9 +33,10 @@
     "do_eval": true,
     "do_predict": false,
     "disable_tqdm": true,
-    "recompute": true,
+    "recompute": false,
     "distributed_dataloader": 1,
     "recompute_granularity": "full",
     "unified_checkpoint": true,
-    "save_total_limit": 2
+    "save_total_limit": 2,
+    "pipeline_parallel_config": "use_dualpipev"
 }
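Note: the new layout keeps the same 8-GPU budget as the old one (1 tensor x 2 pipeline x 4 sharding replaces 1 x 1 x 8), and a DualPipeV schedule needs a pipeline dimension of at least 2. A quick sanity check of that arithmetic; this is my own sketch, not code from the PR, and it assumes expert parallelism reuses the sharding/data-parallel ranks:

cfg = {
    "tensor_parallel_degree": 1,
    "pipeline_parallel_degree": 2,
    "sharding_parallel_degree": 4,
    "gradient_accumulation_steps": 10,
}
world_size = (
    cfg["tensor_parallel_degree"]
    * cfg["pipeline_parallel_degree"]
    * cfg["sharding_parallel_degree"]
)
assert world_size == 8  # same card count as the previous 1 x 1 x 8 layout
# more accumulation steps (4 -> 10) means more micro-batches per optimizer
# step, which presumably keeps the V-shaped schedule's pipeline bubbles small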
2 changes: 2 additions & 0 deletions paddlenlp/transformers/deepseek_v2/configuration.py
@@ -179,6 +179,7 @@
         attention_dropout=0.0,
         speculate_model_type=False,
         using_flex_token=False,
+        use_dualpipev=False,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -227,6 +228,7 @@
         self.speculate_model_type = speculate_model_type
         self.use_fp8 = False
         self.using_flex_token = using_flex_token
+        self.use_dualpipev = use_dualpipev

         super().__init__(
             pad_token_id=pad_token_id,
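The flag defaults to False, so existing configs are unaffected; enabling it is a one-liner. A minimal sketch, assuming DeepseekV2Config is re-exported from the package root the way other model configs are:

from paddlenlp.transformers import DeepseekV2Config

config = DeepseekV2Config(use_dualpipev=True)
assert config.use_dualpipev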
135 changes: 113 additions & 22 deletions paddlenlp/transformers/deepseek_v2/modeling.py
@@ -130,7 +130,7 @@
     return assignment_list


-def parallel_matmul(x: Tensor, y: Tensor, tensor_parallel_output=True):
+def parallel_matmul(x: Tensor, y: Tensor, transpose_y=False, tensor_parallel_output=True):
     is_fleet_init = True
     tensor_parallel_degree = 1
     try:
@@ -148,15 +148,15 @@
     if is_fleet_init and tensor_parallel_degree > 1 and y_is_distributed:
         # if not running under distributed.launch, it will raise AttributeError: 'Fleet' object has no attribute '_hcg'
         input_parallel = paddle.distributed.collective._c_identity(x, group=model_parallel_group)
-        logits = paddle.matmul(input_parallel, y, transpose_y=False)
+        logits = paddle.matmul(input_parallel, y, transpose_y=transpose_y)

         if tensor_parallel_output:
             return logits

         return paddle.distributed.collective._c_concat(logits, group=model_parallel_group)

     else:
-        logits = paddle.matmul(x, y, transpose_y=False)
+        logits = paddle.matmul(x, y, transpose_y=transpose_y)
         return logits
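The transpose_y plumbing exists so a caller can pass an embedding table in directly: nn.Embedding stores its weight as [vocab_size, hidden_size], while the LM head's own parameter is created as [hidden_size, vocab_size]. A standalone shape check (illustrative sizes, not taken from this diff):

import paddle

hidden_states = paddle.randn([2, 8, 1024])     # [batch, seq, hidden]
embedding_table = paddle.randn([32000, 1024])  # [vocab, hidden], as nn.Embedding stores it
logits = paddle.matmul(hidden_states, embedding_table, transpose_y=True)
print(logits.shape)                            # [2, 8, 32000]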


@@ -826,6 +826,10 @@

     def forward(self, hidden_states):
         final_hidden_states, l_aux, l_zloss = super().forward(hidden_states)
+        final_hidden_states = self.auxilibaryloss_and_shared_expert_compute(hidden_states, final_hidden_states, l_aux)
+        return final_hidden_states
+
+    def auxilibaryloss_and_shared_expert_compute(self, hidden_states, final_hidden_states, l_aux):
         if self.training and self.alpha > 0.0:
             l_aux = l_aux * self.alpha
             final_hidden_states = AddAuxiliaryLoss.apply(final_hidden_states, l_aux)
@@ -1145,6 +1149,48 @@
         self.input_layernorm = DeepseekV2RMSNorm(config)
         self.post_attention_layernorm = DeepseekV2RMSNorm(config)

+    def self_attn_and_gate_compute(
+        self,
+        hidden_states: paddle.Tensor,
+        position_ids: Optional[paddle.Tensor] = None,
+        attention_mask: Optional[paddle.Tensor] = None,
+        output_attentions: Optional[bool] = False,
+        past_key_value: Optional[Tuple[paddle.Tensor]] = None,
+        use_cache: Optional[bool] = False,
+        attn_mask_startend_row_indices: Optional[paddle.Tensor] = None,
+        **kwargs,
+    ):
+        hidden_states, residual = self.self_attn_compute(
+            hidden_states,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            past_key_value=past_key_value,
+            use_cache=use_cache,
+            attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+            **kwargs,
+        )
+        probs, routing_map, l_aux, l_zloss = self.mlp.gate_compute(hidden_states)
+        return probs, routing_map, l_aux, l_zloss
+
+    def auxilibaryloss_and_shared_expert_compute(self, residual, hidden_states, expert_output, l_aux):
+        hidden_states = self.mlp.auxilibaryloss_and_shared_expert_compute(hidden_states, expert_output, l_aux)
+        hidden_states = residual + hidden_states
+
+    def post_process_output(self, hidden_states, output_attentions, use_cache, self_attn_weights, present_key_value):
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        if type(outputs) is tuple and len(outputs) == 1:
+            outputs = outputs[0]
+
+        return outputs
+
     def forward(
         self,
         hidden_states: paddle.Tensor,
@@ -1170,10 +1216,6 @@
             (see `past_key_values`).
             past_key_value (`Tuple(paddle.Tensor)`, *optional*): cached past key and value projection states
         """
-        if "padding_mask" in kwargs:
-            warnings.warn(
-                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
-            )
         residual = hidden_states

         hidden_states = self.input_layernorm(hidden_states)
@@ -1216,18 +1258,60 @@
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states

-        outputs = (hidden_states,)
+        return self.post_process_output(
+            hidden_states, output_attentions, use_cache, self_attn_weights, present_key_value
+        )

-        if output_attentions:
-            outputs += (self_attn_weights,)
+    def self_attn_compute(
+        self,
+        hidden_states: paddle.Tensor,
+        position_ids: Optional[paddle.Tensor] = None,
+        attention_mask: Optional[paddle.Tensor] = None,
+        output_attentions: Optional[bool] = False,
+        past_key_value: Optional[Tuple[paddle.Tensor]] = None,
+        use_cache: Optional[bool] = False,
+        attn_mask_startend_row_indices: Optional[paddle.Tensor] = None,
+        **kwargs
+    ):
+        residual = hidden_states

-        if use_cache:
-            outputs += (present_key_value,)
+        hidden_states = self.input_layernorm(hidden_states)

-        if type(outputs) is tuple and len(outputs) == 1:
-            outputs = outputs[0]
+        # Self Attention
+        has_gradient = not hidden_states.stop_gradient
+        if (
+            self.enable_recompute
+            and self.layerwise_recompute
+            and has_gradient
+            and self.recompute_granularity == "full_attn"
+        ):
+            hidden_states, self_attn_weights, present_key_value = recompute(
+                self.self_attn,
+                hidden_states=hidden_states,
+                position_ids=position_ids,
+                attention_mask=attention_mask,
+                output_attentions=output_attentions,
+                past_key_value=past_key_value,
+                use_cache=use_cache,
+                attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+                **kwargs,
+            )
+        else:
+            hidden_states, self_attn_weights, present_key_value = self.self_attn(
+                hidden_states=hidden_states,
+                position_ids=position_ids,
+                attention_mask=attention_mask,
+                output_attentions=output_attentions,
+                past_key_value=past_key_value,
+                use_cache=use_cache,
+                attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+                **kwargs,
+            )
+        hidden_states = residual + hidden_states

-        return outputs
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        return hidden_states, residual


 class DeepseekV2MTPLayer(DeepseekV2DecoderLayer):
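Taken together, self_attn_compute, gate_compute, and auxilibaryloss_and_shared_expert_compute split one decoder step at the points where MoE communication happens, which is what lets a DualPipeV scheduler overlap the expert all-to-all of one chunk with the compute of another. A rough sketch of how the pieces fit back together; the dispatch/combine step in the middle is elided, and run_experts is a placeholder of mine, not an API in this diff:

def split_forward(layer, hidden_states):
    # attention plus both layer norms; returns the MLP input and its residual
    mlp_input, residual = layer.self_attn_compute(hidden_states)
    # router only; token dispatch can now overlap with other chunks' work
    probs, routing_map, l_aux, l_zloss = layer.mlp.gate_compute(mlp_input)
    expert_output = run_experts(layer.mlp, mlp_input, probs, routing_map)  # placeholder
    # epilogue: shared experts, auxiliary loss, residual add
    layer.auxilibaryloss_and_shared_expert_compute(residual, mlp_input, expert_output, l_aux)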
@@ -1892,7 +1976,7 @@


 class DeepseekV2LMHead(nn.Layer):
-    def __init__(self, config: DeepseekV2Config):
+    def __init__(self, config: DeepseekV2Config, embedding_weight=None):
         super(DeepseekV2LMHead, self).__init__()
         self.config = config

@@ -1906,11 +1990,16 @@
         else:
             vocab_size = config.vocab_size

-        self.weight = self.create_parameter(
-            shape=[config.hidden_size, vocab_size],
-            dtype=paddle.get_default_dtype(),
-            default_initializer=nn.initializer.XavierNormal(1.0),
-        )
+        if embedding_weight is not None:
+            self.transpose_y = True
+            self.weight = embedding_weight
+        else:
+            self.transpose_y = False
+            self.weight = self.create_parameter(
+                shape=[config.hidden_size, vocab_size],
+                dtype=paddle.get_default_dtype(),
+                default_initializer=nn.initializer.XavierNormal(1.0),
+            )
         # Must set distributed attr for Tensor Parallel !
         self.weight.is_distributed = True if (vocab_size != config.vocab_size) else False

@@ -1922,7 +2011,9 @@
         if tensor_parallel_output is None:
             tensor_parallel_output = self.config.tensor_parallel_output

-        logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output)
+        logits = parallel_matmul(
+            hidden_states, self.weight, transpose_y=self.transpose_y, tensor_parallel_output=tensor_parallel_output
+        )
         return logits

     def extra_repr(self):
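For reference, a minimal sketch of the new tied-weight path on a single card; the import paths follow the file locations in this diff, and the default config is assumed to select the non-tensor-parallel branch:

import paddle
import paddle.nn as nn
from paddlenlp.transformers import DeepseekV2Config
from paddlenlp.transformers.deepseek_v2.modeling import DeepseekV2LMHead

config = DeepseekV2Config(vocab_size=32000, hidden_size=1024)
embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)

# embedding_weight is [vocab, hidden], so the head sets transpose_y=True and
# reuses the table instead of allocating its own [hidden, vocab] parameter
lm_head = DeepseekV2LMHead(config, embedding_weight=embed_tokens.weight)
logits = lm_head(paddle.randn([2, 8, config.hidden_size]))  # -> [2, 8, 32000]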
60 changes: 54 additions & 6 deletions paddlenlp/transformers/deepseek_v2/modeling_pp.py
@@ -39,13 +39,18 @@
     DeepseekV2RMSNorm,
 )

+try:
+    from paddle.distributed.fleet.meta_parallel import LocalSharedLayerDesc
+except:
+    LocalSharedLayerDesc = None
+
 __all__ = [
     "DeepseekV2ForCausalLMPipe",
 ]


 def parse_args(args):
-    if isinstance(args, tuple):
+    if isinstance(args, (tuple, list)):
         if len(args) == 4:
             hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = args

@@ -55,6 +60,9 @@
         elif len(args) == 2:
             hidden_states, attention_mask = args
             attn_mask_startend_row_indices, position_ids = None, None
+        else:  # len(args) == 1:
+            hidden_states = args[0]
+            attention_mask, attn_mask_startend_row_indices, position_ids = None, None, None
     else:
         hidden_states = args
         attention_mask, attn_mask_startend_row_indices, position_ids = None, None, None
@@ -321,8 +329,8 @@


 class DeepseekV2LMHeadPipe(DeepseekV2LMHead):
-    def __init__(self, config):
-        super(DeepseekV2LMHeadPipe, self).__init__(config)
+    def __init__(self, config, embedding_weight=None):
+        super(DeepseekV2LMHeadPipe, self).__init__(config, embedding_weight=embedding_weight)

     @property
     def embedding_weight(self):
@@ -406,6 +414,10 @@
         assert len(self.no_recompute_layers) == 0, "for pp with full recompute, no_recompute_layers is not support"

         virtual_pp_degree = getattr(self.config, "virtual_pp_degree", 1)
+        use_dualpipev = getattr(self.config, "use_dualpipev", False)
+        if use_dualpipev:
+            assert LocalSharedLayerDesc is not None, "LocalSharedLayerDesc is None, please update your paddle."
+        shared_class = LocalSharedLayerDesc if use_dualpipev else SharedLayerDesc

         def get_hcg():
             return fleet.get_hybrid_communicate_group()

@@ -420,7 +432,7 @@

         if config.tie_word_embeddings:
             self.add_sequential_layer(
-                SharedLayerDesc(
+                shared_class(
                     "DeepseekV2_shared_weight",
                     DeepseekV2EmbeddingPipe,
                     shared_weight_attr="embedding_weight",

@@ -453,12 +465,11 @@

         if config.tie_word_embeddings:
             self.add_sequential_layer(
-                SharedLayerDesc(
+                shared_class(
                     "DeepseekV2_shared_weight",
                     DeepseekV2LMHeadPipe,
                     shared_weight_attr="embedding_weight",
                     config=config,
-                    **{"transpose_y": True},
                 ),
                 "lm_head",
             )

@@ -489,6 +500,7 @@
                 "partition": False,
             },
             num_virtual_pipeline_stages=virtual_pp_degree,
+            use_dualpipev=use_dualpipev,
         )
         # You should call init here, since there is a diamond inheritance problem
         self.apply(self._init_weights)
@@ -497,3 +509,39 @@

     def get_loss_fn(self, config):
         return DeepseekV2PretrainingCriterionPipe(config)
+
+    def overlapped_forward_backward(
+        self,
+        module0,  # the module of the forward chunk
+        inputs0,
+        criterion0,
+        labels0,
+        module1,  # the module of the backward chunk, maybe not used
+        loss1,
+        outputs1,
+        output_grads1,
+        scaler,
+    ):
+        outputs0 = inputs0
+        for layer in module0:
+            outputs0 = layer(outputs0)
+
+        outputs0 = [outputs0] if isinstance(outputs0, paddle.Tensor) else outputs0
+
+        if labels0 is not None:
+            loss0 = criterion0(outputs0, labels0)
+        else:
+            loss0 = None
+
+        if loss1 is not None:
+            if scaler:
+                paddle.autograd.backward(scaler.scale(loss1))
+            else:
+                paddle.autograd.backward(loss1)
+        else:
+            paddle.autograd.backward(
+                tensors=outputs1,
+                grad_tensors=output_grads1,
+            )
+
+        return outputs0, loss0
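overlapped_forward_backward is the hook a DualPipeV runtime calls to run one chunk's forward while firing another chunk's backward in the same scheduler step: module0/inputs0/criterion0/labels0 drive the forward, and the backward side either backpropagates a ready loss (loss1, optionally AMP-scaled) or pushes received output gradients through saved activations (outputs1/output_grads1). A self-contained toy mirroring that control flow with two stand-in layers; everything here is illustrative, not the real pipeline engine:

import paddle
import paddle.nn as nn

layer_a = nn.Linear(4, 4)  # stands in for module0, the chunk running forward
layer_b = nn.Linear(4, 4)  # stands in for the chunk whose backward is due

loss_b = layer_b(paddle.randn([2, 4])).mean()  # loss produced earlier (loss1)

out_a = layer_a(paddle.randn([2, 4]))          # the module0 loop: forward only
loss_a = out_a.mean()                          # criterion0(outputs0, labels0)

loss_b.backward()                              # the loss1 branch, no scaler

print(float(loss_a), layer_b.weight.grad is not None)  # grads exist after backward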
2 changes: 2 additions & 0 deletions paddlenlp/transformers/gpt/modeling.py
@@ -1350,6 +1350,8 @@
             Tensor: The pretraining loss. Its data type should be float32 and its shape is [1].

         """
+        if isinstance(prediction_scores, list):
+            prediction_scores = prediction_scores[0]
         with paddle.amp.auto_cast(False):
             masked_lm_loss = self.loss_func(prediction_scores.astype("float32"), masked_lm_labels.unsqueeze(2))
             # skip ignore_index which loss == 0