diff --git a/README.md b/README.md
index b49c511316ee..60b58b3b7de8 100644
--- a/README.md
+++ b/README.md
@@ -95,6 +95,7 @@
 | [ChatGLM3](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/chatglm2) | THUDM/chatglm3-6b |
 | [DeepSeekV2](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/config/deepseek-v2) | deepseek-ai/DeepSeek-V2, deepseek-ai/DeepSeek-V2-Chat, deepseek-ai/DeepSeek-V2-Lite, deepseek-ai/DeepSeek-V2-Lite-Chat, deepseek-ai/DeepSeek-Coder-V2-Base, deepseek-ai/DeepSeek-Coder-V2-Instruct, deepseek-ai/DeepSeek-Coder-V2-Lite-Base, deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct |
 | [DeepSeekV3](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/config/deepseek-v2) | deepseek-ai/DeepSeek-V3, deepseek-ai/DeepSeek-V3-Base |
+| [DeepSeek-R1](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/config/deepseek-v2) | deepseek-ai/DeepSeek-R1, deepseek-ai/DeepSeek-R1-Zero, deepseek-ai/DeepSeek-R1-Distill-Llama-70B, deepseek-ai/DeepSeek-R1-Distill-Llama-8B, deepseek-ai/DeepSeek-R1-Distill-Qwen-14B, deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B, deepseek-ai/DeepSeek-R1-Distill-Qwen-32B, deepseek-ai/DeepSeek-R1-Distill-Qwen-7B |
 | [Gemma](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/gemma) | google/gemma-7b, google/gemma-7b-it, google/gemma-2b, google/gemma-2b-it |
 | [Mistral](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/mistral) | mistralai/Mistral-7B-Instruct-v0.3, mistralai/Mistral-7B-v0.1 |
 | [Mixtral](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/config/mixtral) | mistralai/Mixtral-8x7B-Instruct-v0.1 |
diff --git a/paddlenlp/transformers/conversion_utils.py b/paddlenlp/transformers/conversion_utils.py
index ed8333d1c53f..d4258baa1c34 100644
--- a/paddlenlp/transformers/conversion_utils.py
+++ b/paddlenlp/transformers/conversion_utils.py
@@ -1311,15 +1311,15 @@ def _get_tensor_parallel_mappings(cls, config: PretrainedConfig, is_split=True)
     def _resolve_prefix_keys(state_keys_base, state_keys_real, ignore_error=False):
         # state_keys_map base to real
         state_keys_map = {}
-
-        state_keys_base = set(state_keys_base)
+        # Sort by length and match from long to short, so overlapping keys (A.key, B.key, ...) resolve to the most specific name first
+        state_keys_base = sorted(state_keys_base, key=lambda x: len(x), reverse=True)
         state_keys_real = set(state_keys_real)
 
         for key in state_keys_base:
             for x in state_keys_real:
                 if x.endswith(key):
                     state_keys_map[key] = x
-                    # break # remove break for math A.key B.key ...
+                    break
             if key not in state_keys_map:
                 if not ignore_error:
                     logger.debug(f"tensor parallel conversion: could not find name {key} in loaded state dict!")
diff --git a/paddlenlp/transformers/deepseek_v2/modeling.py b/paddlenlp/transformers/deepseek_v2/modeling.py
index 64ad9384af43..9ff75b988257 100644
--- a/paddlenlp/transformers/deepseek_v2/modeling.py
+++ b/paddlenlp/transformers/deepseek_v2/modeling.py
@@ -567,7 +567,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
 
 
 class DeepseekV2MLP(nn.Layer):
-    def __init__(self, config: DeepseekV2Config, hidden_size=None, intermediate_size=None):
+    def __init__(self, config: DeepseekV2Config, hidden_size=None, intermediate_size=None, is_moe=False):
         super().__init__()
         self.config = config
         self.hidden_size = config.hidden_size if hidden_size is None else hidden_size
@@ -580,7 +580,7 @@ def __init__(self, config: DeepseekV2Config, hidden_size=None, intermediate_size
         ColumnParallelLinear = linear_utils.ColumnParallelLinear
         RowParallelLinear = linear_utils.RowParallelLinear
 
-        if config.tensor_parallel_degree > 1:
+        if config.tensor_parallel_degree > 1 and not is_moe:
             self.gate_proj = ColumnParallelLinear(
                 self.hidden_size,
                 self.intermediate_size,
@@ -753,14 +753,14 @@ def __init__(self, config):
         self.ep_rank = 0
         self.experts = nn.LayerList(
             [
-                DeepseekV2MLP(config, intermediate_size=config.moe_intermediate_size)
+                DeepseekV2MLP(config, intermediate_size=config.moe_intermediate_size, is_moe=True)
                 for i in range(config.n_routed_experts)
             ]
         )
         self.gate = MoEGate(config)
         if config.n_shared_experts is not None:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
-            self.shared_experts = DeepseekV2MLP(config=config, intermediate_size=intermediate_size)
+            self.shared_experts = DeepseekV2MLP(config=config, intermediate_size=intermediate_size, is_moe=True)
 
     def forward(self, hidden_states):
         identity = hidden_states
@@ -1158,7 +1158,8 @@ def _get_name_mappings(cls, config: DeepseekV2Config) -> list[StateDictNameMappi
             ["embed_tokens.weight"],
             ["norm.weight"],
         ]
-        for layer_index in range(config.num_hidden_layers):
+        # the trailing num_nextn_predict_layers layers hold the MTP (eagle) parameters used for inference
+        for layer_index in range(config.num_hidden_layers + config.num_nextn_predict_layers):
             layer_mappings = [
                 [f"layers.{layer_index}.self_attn.q_proj.weight", None, "transpose"],
                 [f"layers.{layer_index}.self_attn.q_a_proj.weight", None, "transpose"],
@@ -1178,6 +1179,7 @@
 
             # MoE parameters
             model_mappings.append([f"layers.{layer_index}.mlp.gate.weight", None, "transpose"])
+            model_mappings.append([f"layers.{layer_index}.mlp.gate.e_score_correction_bias"])
             for expert_idx in range(config.n_routed_experts):
                 expert_mappings = [
                     [f"layers.{layer_index}.mlp.experts.{expert_idx}.gate_proj.weight", None, "transpose"],
@@ -1189,6 +1191,15 @@
             model_mappings.append([f"layers.{layer_index}.mlp.shared_experts.up_proj.weight", None, "transpose"])
             model_mappings.append([f"layers.{layer_index}.mlp.shared_experts.down_proj.weight", None, "transpose"])
 
+            # MTP (eagle) parameters for inference
+            if layer_index >= config.num_hidden_layers:
+                model_mappings.append([f"layers.{layer_index}.embed_tokens.weight"])
+                model_mappings.append([f"layers.{layer_index}.enorm.weight"])
+                model_mappings.append([f"layers.{layer_index}.hnorm.weight"])
+                model_mappings.append([f"layers.{layer_index}.eh_proj.weight", None, "transpose"])
+                model_mappings.append([f"layers.{layer_index}.shared_head.norm.weight"])
+                model_mappings.append([f"layers.{layer_index}.shared_head.head.weight", None, "transpose"])
+
         init_name_mappings(mappings=model_mappings)
         if cls.base_model_class.__name__ not in config.architectures:
             for mapping in model_mappings:
@@ -1251,6 +1262,21 @@ def get_tensor_parallel_split_mappings(num_layers):
                         final_actions[key.replace("layers.0.", f"layers.{i}.")] = action
                 final_actions[key] = action
 
+            # TP split actions for the MTP (eagle) parameters used for inference
+            base_actions.pop("embed_tokens.weight")
+            base_actions.pop("lm_head.weight")
+            base_actions["layers.0.embed_tokens.weight"] = partial(fn, is_column=False)
+            base_actions["layers.0.eh_proj.weight"] = partial(fn, is_column=True)
+            base_actions["layers.0.shared_head.head.weight"] = partial(fn, is_column=True)
+            for key, action in base_actions.items():
+                if "layers.0." in key:
+                    for i in range(
+                        config.num_hidden_layers, config.num_hidden_layers + config.num_nextn_predict_layers
+                    ):
+                        final_actions[key.replace("layers.0.", f"layers.{i}.")] = action
+                else:
+                    final_actions[key] = action
+
             return final_actions
 
         mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers)
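
A note on the `_resolve_prefix_keys` change: the sketch below is a minimal, standalone illustration of the matching strategy (hypothetical key names, simplified to use a list instead of a set and without the error logging); base keys are processed longest-first and the first real key whose name ends with the base key wins, so a specific name such as `shared_head.norm.weight` is claimed before its shorter suffix sibling `norm.weight`.

```python
# Standalone sketch of suffix matching with longest-first ordering.
# Key names below are illustrative; the real logic lives in
# paddlenlp/transformers/conversion_utils.py (_resolve_prefix_keys).
def resolve_prefix_keys(state_keys_base, state_keys_real):
    state_keys_map = {}
    # Longest base keys first, so "layers.61.shared_head.norm.weight" is resolved
    # before the plain "norm.weight" that is also one of its suffixes.
    for key in sorted(state_keys_base, key=len, reverse=True):
        for real in state_keys_real:
            if real.endswith(key):
                state_keys_map[key] = real
                break  # first match wins, as in the patched code
    return state_keys_map


base = ["norm.weight", "layers.61.shared_head.norm.weight"]
real = ["deepseek_v2.norm.weight", "deepseek_v2.layers.61.shared_head.norm.weight"]
print(resolve_prefix_keys(base, real))
# {'layers.61.shared_head.norm.weight': 'deepseek_v2.layers.61.shared_head.norm.weight',
#  'norm.weight': 'deepseek_v2.norm.weight'}
```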
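For the extended layer range in `_get_name_mappings`, the following sketch shows which extra checkpoint keys the trailing MTP (eagle) layers contribute; the config values are assumptions chosen for illustration (DeepSeek-V3/R1-style), while the key names mirror the mappings added above.

```python
# Sketch of the extra MTP (eagle) keys covered by the extended layer range.
num_hidden_layers = 61          # assumed value, for illustration only
num_nextn_predict_layers = 1    # assumed value, for illustration only

mtp_keys = []
for layer_index in range(num_hidden_layers, num_hidden_layers + num_nextn_predict_layers):
    mtp_keys += [
        f"layers.{layer_index}.embed_tokens.weight",
        f"layers.{layer_index}.enorm.weight",
        f"layers.{layer_index}.hnorm.weight",
        f"layers.{layer_index}.eh_proj.weight",
        f"layers.{layer_index}.shared_head.norm.weight",
        f"layers.{layer_index}.shared_head.head.weight",
    ]

print(mtp_keys[0], mtp_keys[-1])
# layers.61.embed_tokens.weight layers.61.shared_head.head.weight
```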