From c02003925b4313ba2da46884e6ac0b3fbc067e47 Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Mon, 29 Jul 2024 16:10:23 +0800
Subject: [PATCH] add mlp for gemma2 (#11678)

---
 .../llm/src/ipex_llm/transformers/convert.py    |  4 +++-
 .../src/ipex_llm/transformers/models/common.py  | 18 ++++++++++++++++++
 .../src/ipex_llm/transformers/models/gemma2.py  |  7 ++++++-
 3 files changed, 27 insertions(+), 2 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index c12a78bd866..274ce467dd4 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1513,11 +1513,13 @@ def _optimize_post(model, lightweight_bmm=False):
         from ipex_llm.transformers.models.gemma import gemma_rms_norm_forward
         from ipex_llm.transformers.models.gemma2 import gemma2_attention_forward
         from ipex_llm.transformers.models.gemma2 import gemma2_model_forward
+        from ipex_llm.transformers.models.gemma2 import gemma2_mlp_forward
         from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm, Gemma2Attention
-        from transformers.models.gemma2.modeling_gemma2 import Gemma2Model
+        from transformers.models.gemma2.modeling_gemma2 import Gemma2Model, Gemma2MLP
         convert_forward(model, Gemma2RMSNorm, gemma_rms_norm_forward)
         convert_forward(model, Gemma2Attention, gemma2_attention_forward)
         convert_forward(model, Gemma2Model, gemma2_model_forward)
+        convert_forward(model, Gemma2MLP, gemma2_mlp_forward)
     elif model.config.model_type == "Yi":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
diff --git a/python/llm/src/ipex_llm/transformers/models/common.py b/python/llm/src/ipex_llm/transformers/models/common.py
index d32e2ce46a2..e1522c4e957 100644
--- a/python/llm/src/ipex_llm/transformers/models/common.py
+++ b/python/llm/src/ipex_llm/transformers/models/common.py
@@ -41,3 +41,21 @@ def merge_qkv_base(module: torch.nn.Module, attention_class):
     ])
     module.qkv_proj = qkv_proj
     del module.q_proj, module.k_proj, module.v_proj
+
+
+def fuse_mlp_base(module: torch.nn.Module, act: int, x: torch.Tensor):
+    from ipex_llm.transformers.models.utils import mlp_fusion_check
+    x_2d = x.view(-1, x.size(-1))
+    qtype = getattr(module.gate_proj, "qtype", None)
+    if mlp_fusion_check(x_2d, qtype, module.training):
+        import xe_linear
+        x_2d = x_2d.contiguous()
+        return module.down_proj(
+            xe_linear.mlp_forward_xpu(
+                x_2d, module.gate_proj.weight.data, module.up_proj.weight.data,
+                x_2d.size(0), x_2d.size(1), module.gate_proj.out_len,
+                act, qtype
+            )
+        )
+    else:
+        return module.down_proj(module.act_fn(module.gate_proj(x)) * module.up_proj(x))
diff --git a/python/llm/src/ipex_llm/transformers/models/gemma2.py b/python/llm/src/ipex_llm/transformers/models/gemma2.py
index d6c3af5291a..33201864223 100644
--- a/python/llm/src/ipex_llm/transformers/models/gemma2.py
+++ b/python/llm/src/ipex_llm/transformers/models/gemma2.py
@@ -34,7 +34,8 @@
 import torch
 
 from typing import Optional, Tuple
-from ipex_llm.transformers.models.common import merge_qkv_base
+from ipex_llm.transformers.models.common import merge_qkv_base, fuse_mlp_base
+from ipex_llm.transformers.models.utils import GELU
 from ipex_llm.transformers.models.utils import should_use_fuse_rope, use_sdp, use_sdp_causal
 from transformers.cache_utils import Cache
 from transformers.models.gemma2.modeling_gemma2 import Gemma2Model, Gemma2Attention
@@ -177,3 +178,7 @@ def gemma2_attention_forward(
         attn_weights = None
 
     return attn_output, attn_weights, past_key_value
+
+
+def gemma2_mlp_forward(self, x: torch.Tensor):
+    return fuse_mlp_base(self, GELU, x)
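
Note on how the replacement takes effect: convert_forward rebinds the forward
method on each matching module instance, so once _optimize_post has run, every
Gemma2MLP calls gemma2_mlp_forward. That in turn dispatches through
fuse_mlp_base to the fused xe_linear.mlp_forward_xpu kernel whenever
mlp_fusion_check passes, and otherwise falls back to the eager
down_proj(act_fn(gate_proj(x)) * up_proj(x)) path. Below is a minimal sketch of
the rebinding idea, assuming the usual MethodType-based pattern; the helper
name and body are illustrative, not the actual ipex_llm implementation.

    # Sketch only: convert_forward_sketch is a hypothetical stand-in for
    # ipex_llm's convert_forward helper.
    import types
    import torch

    def convert_forward_sketch(model: torch.nn.Module, target_class, new_forward):
        for module in model.modules():
            if isinstance(module, target_class):
                # Bind new_forward so it receives the module as `self`,
                # overriding the class-defined forward for this instance.
                module.forward = types.MethodType(new_forward, module)

Rebinding at the instance level keeps the original transformers classes
untouched, so the optimization only affects models that go through ipex_llm's
conversion path.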