diff --git a/mlx_vlm/models/phi3_v/__init__.py b/mlx_vlm/models/phi3_v/__init__.py
new file mode 100644
index 0000000..6e0acf1
--- /dev/null
+++ b/mlx_vlm/models/phi3_v/__init__.py
@@ -0,0 +1,8 @@
+from .phi3_v import (
+    LanguageModel,
+    Model,
+    ModelConfig,
+    TextConfig,
+    VisionConfig,
+    VisionModel,
+)
diff --git a/mlx_vlm/models/phi3_v/language.py b/mlx_vlm/models/phi3_v/language.py
new file mode 100644
index 0000000..f2401bf
--- /dev/null
+++ b/mlx_vlm/models/phi3_v/language.py
@@ -0,0 +1,19 @@
+from dataclasses import dataclass
+import inspect
+
+
+@dataclass
+class TextConfig:
+    @classmethod
+    def from_dict(cls, params):
+        return cls(
+            **{
+                k: v
+                for k, v in params.items()
+                if k in inspect.signature(cls).parameters
+            }
+        )
+
+
+class LanguageModel:
+    pass
diff --git a/mlx_vlm/models/phi3_v/phi3_v.py b/mlx_vlm/models/phi3_v/phi3_v.py
new file mode 100644
index 0000000..5fb2c32
--- /dev/null
+++ b/mlx_vlm/models/phi3_v/phi3_v.py
@@ -0,0 +1,235 @@
+import inspect
+import math
+from dataclasses import dataclass
+from typing import Dict, Optional, Tuple, Union
+from types import SimpleNamespace
+from typing import Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+import numpy as np
+
+from .su_rope import Phi3SuScaledRotaryEmbedding
+from .language import TextConfig, LanguageModel
+from .vision import VisionConfig, VisionModel
+
+
+@dataclass
+class ModelConfig:
+    text_config: TextConfig
+    vision_config: VisionConfig
+    model_type: str
+    vocab_size: int
+
+    num_hidden_layers: int
+    intermediate_size: int
+    num_attention_heads: int
+    rms_norm_eps: float
+
+    ignore_index: int = -100
+    image_token_index: int = 257152
+    hidden_size: int = 2048
+    pad_token_id: int = 0
+
+    num_key_value_heads: int = None
+    rope_theta: float = 10000
+    rope_traditional: bool = False
+    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
+    max_position_embeddings: int = 131072
+    original_max_position_embeddings: int = 4096
+
+    @classmethod
+    def from_dict(cls, params):
+        return cls(
+            **{
+                k: v
+                for k, v in params.items()
+                if k in inspect.signature(cls).parameters
+            }
+        )
+
+
+class Attention(nn.Module):
+    def __init__(self, args: TextConfig):
+        super().__init__()
+
+        dim = args.hidden_size
+        self.n_heads = n_heads = args.num_attention_heads
+        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
+        self.num_hidden_layers = args.num_hidden_layers
+
+        self.head_dim = head_dim = args.hidden_size // n_heads
+        self.scale = head_dim**-0.5
+
+        op_size = n_heads * head_dim + 2 * (n_kv_heads * head_dim)
+        self.qkv_proj = nn.Linear(dim, op_size, bias=False)
+        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)
+
+        rope_scale = 1.0
+        if args.rope_scaling and args.rope_scaling["type"] == "su":
+            self.rope = Phi3SuScaledRotaryEmbedding(
+                head_dim,
+                traditional=False,
+                base=args.rope_theta,
+                scale=rope_scale,
+                max_position_embeddings=args.max_position_embeddings,
+                original_max_position_embeddings=args.original_max_position_embeddings,
+                short_factor=args.rope_scaling["short_factor"],
+                long_factor=args.rope_scaling["long_factor"],
+            )
+        else:
+            if args.rope_scaling and args.rope_scaling["type"] == "linear":
+                rope_scale = 1 / args.rope_scaling["factor"]
+            self.rope = nn.RoPE(
+                head_dim,
+                traditional=args.rope_traditional,
+                base=args.rope_theta,
+                scale=rope_scale,
+            )
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Tuple[mx.array, mx.array]] = None,
+    ) -> mx.array:
+        B, L, D = x.shape
+
+        qkv = self.qkv_proj(x)
+        query_pos = self.n_heads * self.head_dim
+        queries, keys, values = mx.split(
+            qkv, [query_pos, query_pos + self.n_kv_heads * self.head_dim], axis=-1
+        )
+
+        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
+        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+
+        if cache is not None:
+            offset = cache[0].shape[2]
+            queries = self.rope(queries, offset=offset)
+            keys = self.rope(keys, offset=offset)
+            keys = mx.concatenate([cache[0], keys], axis=2)
+            values = mx.concatenate([cache[1], values], axis=2)
+        else:
+            queries = self.rope(queries)
+            keys = self.rope(keys)
+
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+        return self.o_proj(output), (keys, values)
+
+
+class MLP(nn.Module):
+    def __init__(self, dim, hidden_dim):
+        super().__init__()
+        self.gate_up_proj = nn.Linear(dim, 2 * hidden_dim, bias=False)
+        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
+
+    def __call__(self, x) -> mx.array:
+        x = self.gate_up_proj(x)
+        gate, x = mx.split(x, 2, axis=-1)
+        return self.down_proj(nn.silu(gate) * x)
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, args: TextConfig):
+        super().__init__()
+        self.num_attention_heads = args.num_attention_heads
+        self.hidden_size = args.hidden_size
+        self.self_attn = Attention(args)
+        self.mlp = MLP(args.hidden_size, args.intermediate_size)
+        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.post_attention_layernorm = nn.RMSNorm(
+            args.hidden_size, eps=args.rms_norm_eps
+        )
+        self.args = args
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Tuple[mx.array, mx.array]] = None,
+    ) -> mx.array:
+        r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
+        h = x + r
+        r = self.mlp(self.post_attention_layernorm(h))
+        out = h + r
+        return out, cache
+
+
+class Phi3V(nn.Module):
+    def __init__(self, args: TextConfig):
+        super().__init__()
+        self.args = args
+        self.vocab_size = args.vocab_size
+        self.num_hidden_layers = args.num_hidden_layers
+        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
+        self.vision_embed_tokens = VisionModel(args)
+        self.layers = [
+            TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
+        ]
+        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        pixel_values=None,
+        image_sizes=None,
+        cache=None,
+    ):
+        # print('inputs', inputs)  # debug
+        h = self.embed_tokens(inputs)
+        p = np.argwhere(inputs < 0).tolist()
+        if pixel_values is not None:
+            h = self.vision_embed_tokens(pixel_values, h, image_sizes, p)
+        mask = None
+        if h.shape[1] > 1:
+            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
+            mask = mask.astype(h.dtype)
+        if cache is None:
+            cache = [None] * len(self.layers)
+        for i, layer in enumerate(self.layers):
+            h, cache[i] = layer(h, mask, cache[i])
+        return self.norm(h), cache
+
+
+class Model(nn.Module):
+    def __init__(self, args: TextConfig):
+        super().__init__()
+        self.model_type = args.model_type
+        self.model = Phi3V(args)
+        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
+        self.config = args
+
+    def __call__(
+        self,
+        inputs: mx.array,
+        pixel_values=None,
+        mask=None,
+        cache=None,
+    ):
+        out, cache = self.model(inputs, pixel_values, mask, cache)
+        return self.lm_head(out).astype(self.lm_head.weight.dtype), cache
+
+    @property
+    def layers(self):
+        return self.model.layers
+
+    @property
+    def head_dim(self):
+        return self.config.hidden_size // self.config.num_attention_heads
+
+    @property
+    def n_kv_heads(self):
+        return self.config.num_key_value_heads
+
+    @property
+    def language_model(self):
+        return self
+
+    @property
+    def vision_model(self):
+        return self.model.vision_embed_tokens
diff --git a/mlx_vlm/models/phi3_v/su_rope.py b/mlx_vlm/models/phi3_v/su_rope.py
new file mode 100644
index 0000000..83dbdfe
--- /dev/null
+++ b/mlx_vlm/models/phi3_v/su_rope.py
@@ -0,0 +1,70 @@
+import math
+import mlx.core as mx
+
+
+class Phi3SuScaledRotaryEmbedding:
+    def __init__(
+        self,
+        dims: int,
+        traditional: bool = False,
+        base: float = 10000.0,
+        scale: float = 1.0,
+        max_position_embeddings: int = 131072,
+        original_max_position_embeddings: int = 4096,
+        short_factor: list[float] | float = 1.0,
+        long_factor: list[float] | float = 1.0,
+    ):
+        """
+        Phi3Su Scaled Rotary Embedding layer for Phi-3 models.
+
+        Args:
+            dims (int): The feature dimensions to be rotated.
+            traditional (bool, optional): Unused. Default: ``False``.
+            base (float, optional): Base for the exponential scaling. Default: 10000.0.
+            scale (float, optional): The scale used to scale the positions. Default: 1.0.
+            max_position_embeddings (int, optional): The maximum sequence length supported after the long RoPE scaling is applied. Default: 131072.
+            original_max_position_embeddings (int, optional): The maximum sequence length the model was originally trained with; used to choose between the short and long scaling factors. Default: 4096.
+            short_factor (float or list of floats, optional): List of scaling factors for sequences of length less than original_max_position_embeddings. Default: 1.0.
+            long_factor (float or list of floats, optional): List of scaling factors for sequences of length greater than original_max_position_embeddings. Default: 1.0.
+        """
+        self.inv_freq_short = 1.0 / (
+            mx.array(short_factor, dtype=mx.float32)
+            * base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
+        )
+        self.inv_freq_long = 1.0 / (
+            scale
+            * mx.array(long_factor, dtype=mx.float32)
+            * base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
+        )
+        self.original_max_position_embeddings = original_max_position_embeddings
+        self.scaling_factor = math.sqrt(
+            1
+            + math.log(max_position_embeddings / original_max_position_embeddings)
+            / math.log(original_max_position_embeddings)
+        )
+
+    def _get_cos_sin(self, offset, L):
+        position_ids = mx.arange(offset, offset + L, dtype=mx.float32)[None]
+        inv_freq = (
+            self.inv_freq_long
+            if position_ids.max() + 1 > self.original_max_position_embeddings
+            else self.inv_freq_short
+        )
+        inv_freq_expanded = mx.repeat(
+            inv_freq[None, :, None], position_ids.shape[0], axis=0
+        )
+        position_ids_expanded = position_ids[:, None, :]
+        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(0, 2, 1)
+        emb = mx.concatenate([freqs, freqs], axis=-1)
+        cos = mx.cos(emb) * self.scaling_factor
+        sin = mx.sin(emb) * self.scaling_factor
+        return mx.expand_dims(cos, axis=1), mx.expand_dims(sin, axis=1)
+
+    def __call__(self, x, offset: int = 0):
+        def _rotate_half(_x):
+            midpoint = _x.shape[-1] // 2
+            x1, x2 = _x[..., :midpoint], _x[..., midpoint:]
+            return mx.concatenate([-x2, x1], axis=-1)
+
+        cos, sin = self._get_cos_sin(offset, x.shape[2])
+        return (x * cos) + (_rotate_half(x) * sin)
diff --git a/mlx_vlm/models/phi3_v/vision.py b/mlx_vlm/models/phi3_v/vision.py
new file mode 100644
index 0000000..f832e2d
--- /dev/null
+++ b/mlx_vlm/models/phi3_v/vision.py
@@ -0,0 +1,325 @@
+import inspect
+import math
+from dataclasses import dataclass
+from typing import Optional
+from types import SimpleNamespace
+
+import mlx.core as mx
+import mlx.nn as nn
+import numpy as np
+
+
+@dataclass
+class VisionConfig:
+    model_type: str = "phi3_v"
+    num_hidden_layers: int = 24
+    hidden_size: int = 1024
+    intermediate_size: int = 4096
+    num_attention_heads: int = 16
+    image_size: int = 336
+    patch_size: int = 14
+    projection_dim: int = 768
+    vocab_size: int = 32000
+    num_channels: int = 3
+    layer_norm_eps: float = 1e-5
+    image_dim_out: int = 1024
+    model_name: str = "openai/clip-vit-large-patch14-336"
+    name: str = "clip_vision_model"
+    num_img_tokens: int = 144
+
+    @classmethod
+    def from_dict(cls, params):
+        return cls(
+            **{
+                k: v
+                for k, v in params.items()
+                if k in inspect.signature(cls).parameters
+            }
+        )
+
+
+def check_array_shape(arr):
+    shape = arr.shape
+
+    # Check if the shape has 4 dimensions
+    if len(shape) != 4:
+        return False
+
+    out_channels, kH, KW, _ = shape
+
+    # Check if out_channels is the largest, and kH and KW are the same
+    if (out_channels >= kH) and (out_channels >= KW) and (kH == KW):
+        return True
+    else:
+        return False
+
+
+class Attention(nn.Module):
+    def __init__(
+        self,
+        dims: int,
+        num_heads: int,
+        query_input_dims: Optional[int] = None,
+        key_input_dims: Optional[int] = None,
+        value_input_dims: Optional[int] = None,
+        value_dims: Optional[int] = None,
+        value_output_dims: Optional[int] = None,
+        bias: bool = False,
+    ):
+        super().__init__()
+
+        if (dims % num_heads) != 0:
+            raise ValueError(
+                "The input feature dimensions should be divisible by the "
+                f"number of heads ({dims} % {num_heads}) != 0"
+            )
+
+        query_input_dims = query_input_dims or dims
+        key_input_dims = key_input_dims or dims
+        value_input_dims = value_input_dims or key_input_dims
+        value_dims = value_dims or dims
+        value_output_dims = value_output_dims or dims
+
+        self.num_heads = num_heads
+        head_dim = dims // num_heads
+        self.scale = head_dim**-0.5
+
+        self.q_proj = nn.Linear(query_input_dims, dims, bias=bias)
+        self.k_proj = nn.Linear(key_input_dims, dims, bias=bias)
+        self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias)
+        self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias)
+
+    def __call__(self, queries, keys, values, mask=None):
+        queries = self.q_proj(queries)
+        keys = self.k_proj(keys)
+        values = self.v_proj(values)
+
+        num_heads = self.num_heads
+        B, L, D = queries.shape
+        _, S, _ = keys.shape
+        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+        values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+
+        return self.out_proj(output)
+
+
+class MLP(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.activation_fn = nn.GELU(approx="fast")
+        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        x = self.activation_fn(self.fc1(x))
+        x = self.fc2(x)
+        return x
+
+
+class EncoderLayer(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.embed_dim = config.hidden_size
+        self.self_attn = Attention(
+            config.hidden_size, config.num_attention_heads, bias=True
+        )
+        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+        self.mlp = MLP(config)
+        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+    def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array:
+        y = self.layer_norm1(x)
+        y = self.self_attn(y, y, y, mask)
+        x = x + y
+        y = self.layer_norm2(x)
+        y = self.mlp(y)
+        return x + y
+
+
+class Encoder(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.layers = [EncoderLayer(config) for _ in range(config.num_hidden_layers)]
+
+
+class VisionEmbeddings(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.image_size = config.image_size
+        self.patch_size = config.patch_size
+
+        self.class_embedding = mx.zeros((config.hidden_size,))
+
+        self.patch_embedding = nn.Conv2d(
+            in_channels=config.num_channels,
+            out_channels=self.embed_dim,
+            kernel_size=self.patch_size,
+            stride=self.patch_size,
+            bias=False,
+        )
+
+        self.num_patches = (self.image_size // self.patch_size) ** 2
+        self.num_positions = self.num_patches + 1
+        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+
+    def __call__(self, x: mx.array) -> mx.array:
+        batch_size = x.shape[0]
+        patch_embeddings = self.patch_embedding(x)
+        patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2)
+        embed_dim = patch_embeddings.shape[-1]
+        cls_embeddings = mx.broadcast_to(
+            self.class_embedding, (batch_size, 1, embed_dim)
+        )
+        position_ids = mx.array(np.arange(self.num_positions)[None, :])
+
+        embeddings = mx.concatenate((cls_embeddings, patch_embeddings), axis=1)
+        embeddings += self.position_embedding(position_ids)
+        return embeddings
+
+
+class ClipModel(nn.Module):
+    def __init__(self, config: VisionConfig):
+        super().__init__()
+        self.model_type = config.model_type
+        self.embeddings = VisionEmbeddings(config)
+        self.pre_layrnorm = nn.LayerNorm(config.hidden_size)
+        self.encoder = Encoder(config)
+        self.post_layernorm = nn.LayerNorm(config.hidden_size)
+
+    def __call__(
+        self,
+        x: mx.array,
+        output_hidden_states: Optional[bool] = None,
+    ) -> mx.array:
+        x = self.embeddings(x)
+        x = self.pre_layrnorm(x)
+
+        encoder_states = (x,) if output_hidden_states else None
+
+        for l in self.encoder.layers:
+            x = l(x, mask=None)
+            if output_hidden_states:
+                encoder_states = encoder_states + (x,)
+
+        pooler_output = self.post_layernorm(x[:, 0, :])
+        return pooler_output, x, encoder_states
+
+
+class ClipVModel(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.model_type = config.model_type
+        self.vision_model = ClipModel(config)
+
+
+class VisionModel(nn.Module):
+    CLIP_VIT_LARGE_PATCH14_336_CONFIG = SimpleNamespace(
+        model_type="phi3_v",
+        hidden_size=1024,
+        image_size=336,
+        intermediate_size=4096,
+        layer_norm_eps=1e-05,
+        num_attention_heads=16,
+        num_channels=3,
+        num_hidden_layers=24,
+        patch_size=14,
+    )
+
+    def __init__(self, config):
+        super().__init__()
+        self.model_type = config.model_type
+        self.img_processor = ClipVModel(self.CLIP_VIT_LARGE_PATCH14_336_CONFIG)
+        self.image_dim_out = image_dim_out = 1024
+        self.glb_GN = mx.zeros([1, 1, image_dim_out * 4])
+        self.sub_GN = mx.zeros([1, 1, 1, image_dim_out * 4])
+        self.img_projection = [
+            nn.Linear(image_dim_out * 4, config.hidden_size),
+            nn.GELU(),
+            nn.Linear(config.hidden_size, config.hidden_size),
+        ]
+
+    def __call__(
+        self,
+        img_embeds,
+        txt_embeds=None,
+        img_sizes=None,
+        positions=None,
+        output_hidden_states=None,
+    ):
+        if output_hidden_states:
+            return self.img_processor.vision_model(
+                img_embeds, output_hidden_states=output_hidden_states
+            )
+        # print(0, txt_embeds.shape, img_embeds.shape, img_sizes.shape)
+        img_embeds = mx.array(img_embeds)
+        img_sizes = mx.array(img_sizes)
+        B = img_embeds.shape[0]
+        img_sizes = (img_sizes // 336).tolist()
+        img_features = self.img_processor.vision_model(
+            img_embeds.reshape(-1, *img_embeds.shape[2:]).transpose(0, 2, 3, 1), True
+        )[-1][-2][:, 1:]
+        img_features = img_features.reshape(B, -1, *img_features.shape[1:])
+        C, H = self.image_dim_out, int(img_features.shape[2] ** 0.5)
+        output_imgs, output_len = [], []
+        for _bs in range(B):
+            h, w = img_sizes[_bs]
+            B_ = h * w
+
+            def _reshape_and_concatenate(img, shape, tile_shape):
+                return mx.concatenate(
+                    [
+                        img.reshape(shape)
+                        .transpose(0, 1, 3, 2, 4, 5)
+                        .reshape(tile_shape),
+                        mx.tile(self.sub_GN, (1, tile_shape[1], 1, 1)),
+                    ],
+                    axis=2,
+                ).reshape(1, -1, 4 * C)
+
+            glb_img = _reshape_and_concatenate(
+                img_features[_bs, :1],
+                (1, H // 2, 2, H // 2, 2, C),
+                (1, H // 2, H // 2, 4 * C),
+            )
+            sub_img = _reshape_and_concatenate(
+                img_features[_bs, 1 : B_ + 1],
+                (B_, H // 2, 2, H // 2, 2, C),
+                (1, h * 12, w * 12, 4 * C),
+            )
+            x = mx.concatenate([sub_img, self.glb_GN, glb_img], axis=1)
+            for l in self.img_projection:
+                x = l(x)
+            output_imgs.append(np.array(x.astype(mx.float32)))
+            output_len.append(int((h * w + 1) * 144 + 1 + (h + 1) * 12))
+        idx = 0
+        txt_embeds = np.array(txt_embeds.astype(mx.float32))
+        for i, cnt in enumerate(output_len):
+            txt_embeds[
+                positions[idx][0], positions[idx][1] : positions[idx][1] + cnt
+            ] = output_imgs[i]
+            idx += cnt
+        txt_embeds = mx.array(txt_embeds)
+        return txt_embeds
+
+    def sanitize(self, weights):
+        sanitized_weights = {}
+        for k, v in weights.items():
+            if "position_ids" in k:
+                continue
+            elif "patch_embedding.weight" in k:
+                if check_array_shape(v):
+                    sanitized_weights[k] = v
+                else:
+                    sanitized_weights[k] = v.transpose(0, 2, 3, 1)
+            else:
+                sanitized_weights[k] = v
+
+        return sanitized_weights
diff --git a/mlx_vlm/tests/test_models.py b/mlx_vlm/tests/test_models.py
index 613995d..a709ca8 100644
--- a/mlx_vlm/tests/test_models.py
+++ b/mlx_vlm/tests/test_models.py
@@ -468,6 +468,156 @@ def test_multi_modality(self):
             args.vision_config.num_channels,
             (args.vision_config.image_size, args.vision_config.image_size),
         )
+    def test_phi3_v(self):
+        from mlx_vlm.models import phi3_v
+
+        text_config = phi3_v.TextConfig()
+
+        vision_config = phi3_v.VisionConfig(
+            model_type="phi3_v",
+            image_dim_out=1024,
+            model_name="openai/clip-vit-large-patch14-336",
+            name="clip_vision_model",
+            num_img_tokens=144,
+        )
+
+        args = phi3_v.ModelConfig(
+            text_config=text_config,
+            vision_config=vision_config,
+            **{
+                "hidden_size": 3072,
+                "intermediate_size": 8192,
+                "max_position_embeddings": 131072,
+                "model_type": "phi3_v",
+                "num_attention_heads": 32,
+                "num_hidden_layers": 32,
+                "num_key_value_heads": 32,
+                "original_max_position_embeddings": 4096,
+                "rms_norm_eps": 1e-05,
+                "rope_scaling": {
+                    "long_factor": [
+                        1.0299999713897705,
+                        1.0499999523162842,
+                        1.0499999523162842,
+                        1.0799999237060547,
+                        1.2299998998641968,
+                        1.2299998998641968,
+                        1.2999999523162842,
+                        1.4499999284744263,
+                        1.5999999046325684,
+                        1.6499998569488525,
+                        1.8999998569488525,
+                        2.859999895095825,
+                        3.68999981880188,
+                        5.419999599456787,
+                        5.489999771118164,
+                        5.489999771118164,
+                        9.09000015258789,
+                        11.579999923706055,
+                        15.65999984741211,
+                        15.769999504089355,
+                        15.789999961853027,
+                        18.360000610351562,
+                        21.989999771118164,
+                        23.079999923706055,
+                        30.009998321533203,
+                        32.35000228881836,
+                        32.590003967285156,
+                        35.56000518798828,
+                        39.95000457763672,
+                        53.840003967285156,
+                        56.20000457763672,
+                        57.95000457763672,
+                        59.29000473022461,
+                        59.77000427246094,
+                        59.920005798339844,
+                        61.190006256103516,
+                        61.96000671386719,
+                        62.50000762939453,
+                        63.3700065612793,
+                        63.48000717163086,
+                        63.48000717163086,
+                        63.66000747680664,
+                        63.850006103515625,
+                        64.08000946044922,
+                        64.760009765625,
+                        64.80001068115234,
+                        64.81001281738281,
+                        64.81001281738281,
+                    ],
+                    "short_factor": [
+                        1.05,
+                        1.05,
+                        1.05,
+                        1.1,
+                        1.1,
+                        1.1,
+                        1.2500000000000002,
+                        1.2500000000000002,
+                        1.4000000000000004,
+                        1.4500000000000004,
+                        1.5500000000000005,
+                        1.8500000000000008,
+                        1.9000000000000008,
+                        2.000000000000001,
+                        2.000000000000001,
+                        2.000000000000001,
+                        2.000000000000001,
+                        2.000000000000001,
+                        2.000000000000001,
+                        2.000000000000001,
+                        2.000000000000001,
+                        2.000000000000001,
+                        2.000000000000001,
+                        2.000000000000001,
+                        2.000000000000001,
+                        2.000000000000001,
+                        2.000000000000001,
+                        2.000000000000001,
+                        2.000000000000001,
+                        2.000000000000001,
+                        2.000000000000001,
+                        2.000000000000001,
+                        2.1000000000000005,
+                        2.1000000000000005,
+                        2.2,
+                        2.3499999999999996,
+                        2.3499999999999996,
+                        2.3499999999999996,
+                        2.3499999999999996,
+                        2.3999999999999995,
+                        2.3999999999999995,
+                        2.6499999999999986,
+                        2.6999999999999984,
+                        2.8999999999999977,
+                        2.9499999999999975,
+                        3.049999999999997,
+                        3.049999999999997,
+                        3.049999999999997,
+                    ],
+                    "type": "su",
+                },
+                "rope_theta": 10000.0,
+                "vocab_size": 32064,
+            },
+        )
+
+        model = phi3_v.Model(args)
+
+        self.language_test_runner(
+            model.language_model,
+            args.model_type,
+            args.vocab_size,
+            args.num_hidden_layers,
+        )
+
+        self.vision_test_runner(
+            model.vision_model,
+            args.vision_config.model_type,
+            args.vision_config.hidden_size,
+            args.vision_config.num_channels,
+            (args.vision_config.image_size, args.vision_config.image_size),
+        )


 if __name__ == "__main__":
diff --git a/mlx_vlm/utils.py b/mlx_vlm/utils.py
index 1b44586..42998a1 100644
--- a/mlx_vlm/utils.py
+++ b/mlx_vlm/utils.py
@@ -159,6 +159,9 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
         config["text_config"] = text_config
     if model_type == "idefics2":
         config = AutoConfig.from_pretrained(model_path).to_dict()
+    if model_type == "phi3_v":
+        config["vision_config"] = config["img_processor"]
+        config["text_config"] = {}

     model_config = model_class.ModelConfig.from_dict(config)

@@ -705,6 +708,8 @@ def prepare_inputs(image_processor, processor, image, prompt, image_token_index)
     pixel_values = mx.array(inputs["pixel_values"])
     input_ids = mx.array(inputs["input_ids"])
     mask = mx.array(inputs["attention_mask"])
+    if "image_sizes" in inputs:
+        return input_ids, pixel_values, inputs["image_sizes"]
     return input_ids, pixel_values, mask