From e0ce48e0fa51862afd1a7344b0248bf7b5cc900a Mon Sep 17 00:00:00 2001
From: Alvaro Moran
Date: Thu, 16 Jan 2025 16:39:54 +0000
Subject: [PATCH] WIP

---
 .../jetstream_pt_support/compatibility.py     |   2 +-
 .../jetstream_pt_support/engine_loader.py     |   6 +-
 .../jetstream_pt_support/generator.py         |   2 +-
 .../jetstream_pt_support/models/__init__.py   |   1 +
 .../models/qwen2_model.py                     | 292 ++++++++++++------
 .../tests/test_decode_jetstream.py            |   8 +-
 6 files changed, 220 insertions(+), 91 deletions(-)

diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/compatibility.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/compatibility.py
index d1fc325d..456eac26 100644
--- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/compatibility.py
+++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/compatibility.py
@@ -25,7 +25,7 @@ def model_can_use_jetstream_pt(model_path: str) -> bool:
     """
     config = AutoConfig.from_pretrained(model_path)
     # For now few models are supported
-    supported_models = ["llama", "gemma", "mixtral"]
+    supported_models = ["llama", "gemma", "mixtral", "qwen2"]
     if config.model_type not in supported_models:
         return False
     if jetstream_pt_available():
diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py
index 24bd600a..f5110a02 100644
--- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py
+++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/engine_loader.py
@@ -24,7 +24,7 @@ from transformers import AutoConfig
 
 from .compatibility import model_can_use_jetstream_pt
-from .models import GemmaModel, LlamaModel, MixtralModel
+from .models import GemmaModel, LlamaModel, MixtralModel, Qwen2Model
 
 
 class OptimumJetstreamEngine(PyTorchEngine):
@@ -66,6 +66,8 @@ def load_model_info(config: "PretrainedConfig") -> Any:
         model_class = GemmaModel
     elif config.model_type == "mixtral":
         model_class = MixtralModel
+    elif config.model_type == "qwen2":
+        model_class = Qwen2Model
     else:
         raise ValueError(f"Unsupported model type {config.model_type}")
     model_info = fetch_models.ModelInfo(
@@ -101,7 +103,7 @@ def create_engine_env_data(
     head_dim_shardable = model_info.num_kv_heads == 1 and model_info.head_dim % num_devices == 0
 
     if num_kv_heads_shardable or head_dim_shardable:
-        shard_on_batch = False
+        shard_on_batch = False
     else:
         shard_on_batch = True
     aligned_batch_size = (batch_size + num_devices - 1) // num_devices * num_devices
diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py
index 1baed358..64148eea 100644
--- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py
+++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/generator.py
@@ -407,7 +407,7 @@ def _token_encode(self, text: str, max_length: int) -> Tuple[jnp.ndarray, int]:
         tokens, true_length = pad_tokens(input_ids[0],
                                          self.tokenizer.bos_token_id,
                                          self.tokenizer.pad_token_id,
-                                         is_bos=True,
+                                         is_bos=(self.tokenizer.bos_token_id is not None),
                                          max_prefill_length=max_prefill_length,
                                          jax_padding=True,
                                          )
diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/__init__.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/__init__.py
index 9855bde6..4a835fe1 100644
--- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/__init__.py
+++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/__init__.py
@@ -1,3 +1,4 @@
 from .gemma_model_hf import GemmaModelHf as GemmaModel
 from .llama_model_exportable_hf import TransformerHf as LlamaModel
 from .mixtral_model_hf import MixtralModelHf as MixtralModel
+from .qwen2_model import Qwen2Model
diff --git a/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/qwen2_model.py b/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/qwen2_model.py
index 16e17a98..02b5b126 100644
--- a/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/qwen2_model.py
+++ b/text-generation-inference/server/text_generation_server/jetstream_pt_support/models/qwen2_model.py
@@ -2,22 +2,25 @@
 Qwen2 model implementation, based on Jetstream implementation of Llama model.
 """
 
-from typing import Any, List, Optional
 import copy
+from typing import Any, List, Optional
+
 import jax
-import math
 import torch
 import torch.nn.functional as F
-from jetstream_pt.model_base import ModuleBase
 from jetstream_pt.layers import (
-    Attention,
-    RMSNorm,
-    get_quantized_embedding_layer,
-    get_quantized_linear_layer,
+    AttentionKernel,
+    Int8KVAttentionKernel,
+    RMSNorm,
+    apply_rotary_emb,
+    get_quantized_embedding_layer,
+    get_quantized_linear_layer,
 )
-from torch import nn
+from jetstream_pt.model_base import ModuleBase
 
-from . import model_args
+# Use llama's functions and classes that are the same as in Qwen2
+from jetstream_pt.third_party.llama.model_exportable import model_args
+from transformers import GenerationConfig, GenerationMixin, Qwen2Config
 
 
 class FeedForward(ModuleBase):
@@ -82,8 +85,145 @@ def forward(self, x):
         result = self.w2(F.silu(self.w1(x)) * self.w3(x))
         return result
 
+
+class QwenAttention(ModuleBase):
+    """Attention module."""
+
+    def __init__(
+        self, n_heads, n_kv_heads, head_dim, hidden_size, device, env, layer_id
+    ):
+        super().__init__()
+        self.n_heads = n_heads
+        self.n_kv_heads = n_kv_heads
+        self.head_dim = head_dim
+        self.n_rep = self.n_heads // self.n_kv_heads
+        self.env = env
+        self.hidden_size = hidden_size
+        self.layer_id = layer_id
+
+        LinearLayer = get_quantized_linear_layer(env.quant_config)
+        linear_kwargs = {}
+        if LinearLayer != torch.nn.Linear:
+            linear_kwargs = {"quant_config": env.quant_config}
 
-class TransformerBlock(ModuleBase):
+        self.wo = LinearLayer(
+            n_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+            device=device,
+            **linear_kwargs,
+        )
+
+        Kernel = (
+            Int8KVAttentionKernel
+            if env.quant_config.enable_kv_quantization
+            else AttentionKernel
+        )
+        self.attention_kernel = Kernel(env, self.layer_id)
+
+        self.q_size = n_heads * self.head_dim
+        self.kv_size = self.n_kv_heads * self.head_dim
+        if self.env.qkv_fusion:
+            self._register_load_state_dict_pre_hook(self.load_hook)
+            self.wqkv = LinearLayer(
+                hidden_size,
+                (n_heads + 2 * self.n_kv_heads) * self.head_dim,
+                bias=True,
+                device=device,
+                **linear_kwargs,
+            )
+        else:
+            self.wq = LinearLayer(
+                hidden_size,
+                n_heads * self.head_dim,
+                bias=True,
+                device=device,
+                **linear_kwargs,
+            )
+            self.wk = LinearLayer(
+                hidden_size,
+                self.n_kv_heads * self.head_dim,
+                bias=True,
+                device=device,
+                **linear_kwargs,
+            )
+            self.wv = LinearLayer(
+                hidden_size,
+                self.n_kv_heads * self.head_dim,
+                bias=True,
+                device=device,
+                **linear_kwargs,
+            )
+
+    def load_hook(self, state_dict, prefix, *args):
+        if prefix + "wq.weight" in state_dict:
+            wq = state_dict.pop(prefix + "wq.weight")
+            wk = state_dict.pop(prefix + "wk.weight")
+            wv = state_dict.pop(prefix + "wv.weight")
+            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        freqs_cis: torch.Tensor,
+        mask: Optional[torch.Tensor],
+        cache,
+        start=None,
+        end=None,
+        ragged_batch_index=None,
+        ragged_block_index=None,
+    ):
+        with jax.named_scope("attn_linear_before_cache"):
+            bsz, seqlen = x.shape[0], x.shape[-2]
+
+            # qkv fuse
+            if self.env.qkv_fusion:
+                xq, xk, xv = self.wqkv(x).split(
+                    [self.q_size, self.kv_size, self.kv_size], dim=-1
+                )
+            else:
+                xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+            xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
+            xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
+            xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
+
+            shard_axis = 0 if self.env.shard_on_batch else 2
+            self.env.apply_sharding(xq, axis=shard_axis)
+            self.env.apply_sharding(xk, axis=shard_axis)
+            self.env.apply_sharding(xv, axis=shard_axis)
+
+        with jax.named_scope("attn_rope"):
+            xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
+
+        xk = xk.transpose(1, 2)
+        xv = xv.transpose(1, 2)
+        xq = xq.transpose(1, 2)
+
+        if mask.ndim == 2:
+            if seqlen == 1:
+                mask = mask[:, None, None, :]
+            else:
+                mask = mask[None, None, :, :]
+
+        # if cache is not None and cache.cache_k is not None:
+        #     print(f"xq {xq.shape} xk {xk.shape} cache shape {cache.cache_k.shape}")
+        output = self.attention_kernel(
+            xq=xq,
+            xk=xk,
+            xv=xv,
+            mask=mask,
+            # cache[self.layer_id],
+            cache=cache,
+            start=start,
+            end=end,
+            ragged_batch_index=ragged_batch_index,
+            ragged_block_index=ragged_block_index,
+        ).type_as(xq)
+        # print(f"output {output.shape}")
+        output = output.transpose(-3, -2).contiguous().view(bsz, seqlen, -1)
+        return self.wo(output)
+
+
+class Qwen2DecoderLayer(ModuleBase):
     """Transformer block."""
 
     def __init__(
@@ -99,7 +239,7 @@ def __init__(
         self.head_dim = args.dim // args.n_heads
         self.args = args
 
-        self.attention = Attention(
+        self.attention = QwenAttention(
             args.n_heads,
             args.n_kv_heads or args.n_heads,
             args.dim // args.n_heads,
@@ -132,6 +272,10 @@ def __init__(
         self.attention.annotate_sharding("wk.weight", 0)
         self.attention.annotate_sharding("wv.weight", 0)
         self.attention.annotate_sharding("wo.weight", 1)
+        self.attention.annotate_sharding("wq.bias", 0)
+        self.attention.annotate_sharding("wk.bias", 0)
+        self.attention.annotate_sharding("wv.bias", 0)
+        self.attention.annotate_sharding("wo.bias", -1)
 
         self.hf_name("feed_forward", "mlp")
         self.hf_name("attention_norm", "input_layernorm")
@@ -168,71 +312,71 @@ def forward(
         return out
 
 
-def apply_scaling(freqs: torch.Tensor, config: model_args.RopeScalingArgs):
-    # Values obtained from grid search
-    scale_factor = config.factor
-    low_freq_factor = config.low_freq_factor
-    high_freq_factor = config.high_freq_factor
-    old_context_len = config.original_max_position_embeddings
-
-    low_freq_wavelen = old_context_len / low_freq_factor
-    high_freq_wavelen = old_context_len / high_freq_factor
-    new_freqs = []
-    for freq in freqs:
-        wavelen = 2 * math.pi / freq
-        if wavelen < high_freq_wavelen:
-            new_freqs.append(freq)
-        elif wavelen > low_freq_wavelen:
-            new_freqs.append(freq / scale_factor)
-        else:
-            assert low_freq_wavelen != high_freq_wavelen
-            smooth = (old_context_len / wavelen - low_freq_factor) / (
-                high_freq_factor - low_freq_factor
-            )
-            new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
-    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
-
-
 def precompute_freqs_cis(
     dim: int,
     end: int,
     theta: float = 10000.0,
-    rope_scaling_config: model_args.RopeScalingArgs = None,
 ):
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
     t = torch.arange(end, device=freqs.device, dtype=torch.float32)
-    if rope_scaling_config is not None:
-        freqs = apply_scaling(freqs, rope_scaling_config)
     freqs = torch.outer(t, freqs)
     freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
     return freqs_cis
 
 
-class Transformer(ModuleBase):
-    """Transformer module."""
+class Qwen2Model(ModuleBase, GenerationMixin):
+    """Qwen2 module."""
 
     def __init__(
         self,
-        params: model_args.ModelArgs,
+        config: Qwen2Config,
+        device,
         env,
     ):
+        if config.sliding_window is not None:
+            raise ValueError("Sliding window is not supported for Qwen2 model")
+        if config.rope_scaling is not None:
+            raise ValueError("Rope scaling is not supported for Qwen2 model")
         super().__init__()
+        self.config = config
+        self.generation_config = GenerationConfig.from_model_config(config)
+
+        # NOTE: these parameters are deduced from the config's intermediate_size and hidden_size,
+        # so as to be compatible with the original Jetstream/PyTorch model.
+        ffn_dim_multiplier = config.intermediate_size / int(8 * config.hidden_size / 3)
+        multiple_of = 1
+        params = model_args.ModelArgs(
+            dim=config.hidden_size,
+            n_layers=config.num_hidden_layers,
+            n_heads=config.num_attention_heads,
+            n_kv_heads=config.num_key_value_heads,
+            vocab_size=config.vocab_size,
+            multiple_of=multiple_of,
+            ffn_dim_multiplier=ffn_dim_multiplier,
+            norm_eps=config.rms_norm_eps,
+            max_seq_len=env.cache_len,
+            bf16_enable=env.bf16_enable,
+            rope_theta=config.rope_theta,
+        )
+        params.device = device
         self.env = env
+
         self.params = params
-        self.vocab_size = params.vocab_size
-        self.n_layers = params.n_layers
+        self.vocab_size = config.vocab_size
+        self.n_layers = config.num_hidden_layers
 
         Embedding = get_quantized_embedding_layer(env.quant_config)
         self.tok_embeddings = Embedding(
-            params.vocab_size,
-            params.dim,
-            device=params.device,
+            config.vocab_size,
+            config.hidden_size,
+            device=device,
         )
 
         self.layers = torch.nn.ModuleList()
-        for layer_id in range(params.n_layers):
-            self.layers.append(TransformerBlock(layer_id, params, env))
-        self.norm = RMSNorm(params.dim, eps=params.norm_eps, device=params.device)
+        for layer_id in range(config.num_hidden_layers):
+            self.layers.append(Qwen2DecoderLayer(layer_id, params, env))
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=params.device)
 
         LinearLayer = get_quantized_linear_layer(env.quant_config)
         linear_kwargs = {}
@@ -240,18 +384,16 @@ def __init__(
             linear_kwargs["quant_config"] = env.quant_config
 
         self.output = LinearLayer(
-            params.dim,
-            params.vocab_size,
+            config.hidden_size,
+            config.vocab_size,
             bias=False,
             device=params.device,
             **linear_kwargs,
         )
-        # TODO what to do with this
         freqs_cis = precompute_freqs_cis(
-            self.params.dim // self.params.n_heads,
+            config.hidden_size // config.num_attention_heads,
             self.params.max_seq_len * 2,
-            theta=self.params.rope_theta,
-            rope_scaling_config=self.params.rope_scaling_args,
+            theta=self.config.rope_theta,
         )
 
         self.register_buffer("freqs_cis", freqs_cis)
@@ -319,36 +461,6 @@ def forward(
         output = self.output(h).float()
         return output
 
-    @classmethod
-    def from_hf_model_id(cls, model_id, env, is_tiny=False):
-        if is_tiny:
-            name = "llama-2-tiny"
-        else:
-            name = {
-                "meta-llama/Llama-2-7b-chat-hf": "llama-2-7b",
-                "meta-llama/Llama-2-7b-hf": "llama-2-7b",
-                "meta-llama/Llama-2-13b-chat-hf": "llama-2-13b",
-                "meta-llama/Llama-2-13b-hf": "llama-2-13b",
-                "meta-llama/Llama-2-70b-hf": "llama-2-70b",
-                "meta-llama/Llama-2-70b-chat-hf": "llama-2-70b",
-                "meta-llama/Meta-Llama-3-8B": "llama-3-8b",
-                "meta-llama/Meta-Llama-3-8B-Instruct": "llama-3-8b",
-                "meta-llama/Meta-Llama-3-70B": "llama-3-70b",
-                "meta-llama/Meta-Llama-3-70B-Instruct": "llama-3-70b",
-                "meta-llama/Llama-3.1-8B": "llama-3.1-8b",
-                "meta-llama/Llama-3.1-8B-Instruct": "llama-3.1-8b",
-                "meta-llama/Llama-3.2-1B": "llama-3.2-1b",
-                "meta-llama/Llama-3.2-1B-Instruct": "llama-3.2-1b",
-                "meta-llama/Llama-3.3-70B": "llama-3.3-70b",
-                "meta-llama/Llama-3.3-70B-Instruct": "llama-3.3-70b",
-            }.get(model_id)
-        assert name
-        args = model_args.get_model_args(
-            name, env.cache_len, env.batch_size, env.bf16_enable
-        )
-        args.device = "meta"
-        model = cls(args, env)
-        return model
 
     def convert_hf_weights(self, hf_weights):
@@ -363,6 +475,8 @@ def transform(val, n_heads):
 
         updated = copy.copy(hf_weights)
         for key, value in hf_weights.items():
+            if "bias" in key:
+                continue
             if "q_proj" in key:
                 updated[key] = transform(value, self.params.n_heads)
             if "k_proj" in key:
@@ -372,3 +486,9 @@ def transform(val, n_heads):
         res = super().convert_hf_weights(updated)
         res["freqs_cis"] = self.freqs_cis
         return res
+
+    @classmethod
+    def from_config(cls, config, env):
+        device = "meta"
+        model = cls(config, device, env)
+        return model
diff --git a/text-generation-inference/tests/test_decode_jetstream.py b/text-generation-inference/tests/test_decode_jetstream.py
index 9bf72947..79e4ba20 100644
--- a/text-generation-inference/tests/test_decode_jetstream.py
+++ b/text-generation-inference/tests/test_decode_jetstream.py
@@ -70,9 +70,15 @@ def test_decode_single_jetstream_pytorch_slow(params, do_sample):
             sequence_length=256,
             expected_text=" Winston Smith, his chin nuzzled into his breast, stretched, and looked out across the city",
             max_new_tokens=20,
+        ),
+        DecodeTestParams(
+            model_id="Qwen/Qwen2.5-0.5B",
+            sequence_length=256,
+            expected_text=" Winston Smith, his chin nuzzled into his breast, stretched, and looked out across the city",
+            max_new_tokens=20,
         )
     ],
-    ids=["TinyLLama-v0", "gemma-2b", "Mixtral-tiny", "Trendyol-LLM-7b-base-v0.1", "Llama-3.2-1B"],
+    ids=["TinyLLama-v0", "gemma-2b", "Mixtral-tiny", "Trendyol-LLM-7b-base-v0.1", "Llama-3.2-1B", "Qwen2.5-0.5B"],
 )
 def test_decode_single_jetstream_pytorch(params, do_sample):
     params.do_sample = do_sample
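
Note (reviewer sketch, not part of the patch): a minimal way to exercise the new Qwen2 path,
assuming the server package is importable as `text_generation_server` (the directory layout
used by the files above) and that the jetstream_pt runtime is installed on the TPU host.
"Qwen/Qwen2.5-0.5B" is the checkpoint added to the test matrix above.

    from transformers import AutoConfig

    from text_generation_server.jetstream_pt_support.compatibility import (
        model_can_use_jetstream_pt,
    )

    # The config reports model_type == "qwen2", which compatibility.py now whitelists
    # and engine_loader.load_model_info maps to the new Qwen2Model class.
    config = AutoConfig.from_pretrained("Qwen/Qwen2.5-0.5B")
    print(config.model_type)  # "qwen2"

    # Should print True when jetstream_pt is available on the host.
    print(model_can_use_jetstream_pt("Qwen/Qwen2.5-0.5B"))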