
Commit e0ce48e
WIP
tengomucho committed Jan 16, 2025
1 parent 0514c0c commit e0ce48e
Showing 6 changed files with 220 additions and 91 deletions.
@@ -25,7 +25,7 @@ def model_can_use_jetstream_pt(model_path: str) -> bool:
"""
config = AutoConfig.from_pretrained(model_path)
# For now few models are supported
supported_models = ["llama", "gemma", "mixtral"]
supported_models = ["llama", "gemma", "mixtral", "qwen2"]
if config.model_type not in supported_models:
return False
if jetstream_pt_available():
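
With "qwen2" added to the allow-list above, the compatibility gate now admits Qwen2 checkpoints. A minimal usage sketch; the import path and the model id "Qwen/Qwen2-7B" are illustrative assumptions, not taken from this commit:

# Hypothetical probe: can this checkpoint be served via the Jetstream PT path?
# The module path below is guessed from the diff's relative imports.
from text_generation_server.jetstream_pt_support.compatibility import model_can_use_jetstream_pt

if model_can_use_jetstream_pt("Qwen/Qwen2-7B"):
    print("Jetstream PT path available for this model")
else:
    print("Falling back to the default path")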
@@ -24,7 +24,7 @@
 from transformers import AutoConfig

 from .compatibility import model_can_use_jetstream_pt
-from .models import GemmaModel, LlamaModel, MixtralModel
+from .models import GemmaModel, LlamaModel, MixtralModel, Qwen2Model


 class OptimumJetstreamEngine(PyTorchEngine):
@@ -66,6 +66,8 @@ def load_model_info(config: "PretrainedConfig") -> Any:
         model_class = GemmaModel
     elif config.model_type == "mixtral":
         model_class = MixtralModel
+    elif config.model_type == "qwen2":
+        model_class = Qwen2Model
     else:
         raise ValueError(f"Unsupported model type {config.model_type}")
     model_info = fetch_models.ModelInfo(
@@ -101,7 +103,7 @@ def create_engine_env_data(
     head_dim_shardable = model_info.num_kv_heads == 1 and model_info.head_dim % num_devices == 0

     if num_kv_heads_shardable or head_dim_shardable:
-      shard_on_batch = False
+        shard_on_batch = False
     else:
         shard_on_batch = True
     aligned_batch_size = (batch_size + num_devices - 1) // num_devices * num_devices
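
The hunk above decides between weight sharding and batch sharding, then rounds the batch size up to a multiple of the device count. A self-contained sketch of that arithmetic with made-up values; none of these numbers come from the commit, and num_kv_heads_shardable is assumed to mean the KV heads divide evenly across devices, which this hunk does not show:

# Illustrative values only.
num_devices = 8
num_kv_heads = 4
head_dim = 128
batch_size = 5

# Assumed definition; the real one sits outside this hunk.
num_kv_heads_shardable = num_kv_heads % num_devices == 0                # 4 % 8 != 0 -> False
head_dim_shardable = num_kv_heads == 1 and head_dim % num_devices == 0  # False

# Neither weight-sharding condition holds, so shard on the batch axis,
shard_on_batch = not (num_kv_heads_shardable or head_dim_shardable)     # True
# and round the batch up to a device multiple via ceiling division:
# (5 + 8 - 1) // 8 * 8 == 12 // 8 * 8 == 8.
aligned_batch_size = (batch_size + num_devices - 1) // num_devices * num_devices
assert shard_on_batch and aligned_batch_size == 8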
@@ -407,7 +407,7 @@ def _token_encode(self, text: str, max_length: int) -> Tuple[jnp.ndarray, int]:
         tokens, true_length = pad_tokens(input_ids[0],
                                          self.tokenizer.bos_token_id,
                                          self.tokenizer.pad_token_id,
-                                         is_bos=True,
+                                         is_bos=(self.tokenizer.bos_token_id is not None),
                                          max_prefill_length=max_prefill_length,
                                          jax_padding=True,
                                          )
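
The is_bos change appears aimed at tokenizers that define no BOS token: Qwen2 tokenizers typically ship with bos_token_id set to None, so passing is_bos=True unconditionally would try to prepend a nonexistent token. A quick way to observe this; the model id is illustrative and the check is not part of the commit:

# Illustrative check. Qwen2 tokenizers typically define no BOS token,
# so bos_token_id comes back as None.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B")
is_bos = tokenizer.bos_token_id is not None
print(is_bos)  # expected: False for a BOS-less tokenizer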
@@ -1,3 +1,4 @@
 from .gemma_model_hf import GemmaModelHf as GemmaModel
 from .llama_model_exportable_hf import TransformerHf as LlamaModel
 from .mixtral_model_hf import MixtralModelHf as MixtralModel
+from .qwen2_model import Qwen2Model
