From 4a4823d1dd15aabc3ac93f6aef6a48f080790b73 Mon Sep 17 00:00:00 2001 From: JosefAlbers <146810011+JosefAlbers@users.noreply.github.com> Date: Wed, 5 Jun 2024 19:40:10 +0900 Subject: [PATCH 1/9] phi3_v into mlx_vlm --- mlx_vlm/models/phi3_v/__init__.py | 8 + mlx_vlm/models/phi3_v/language.py | 17 ++ mlx_vlm/models/phi3_v/phi3vision.py | 229 ++++++++++++++++++++++ mlx_vlm/models/phi3_v/su_rope.py | 70 +++++++ mlx_vlm/models/phi3_v/vision.py | 283 ++++++++++++++++++++++++++++ mlx_vlm/utils.py | 5 + 6 files changed, 612 insertions(+) create mode 100644 mlx_vlm/models/phi3_v/__init__.py create mode 100644 mlx_vlm/models/phi3_v/language.py create mode 100644 mlx_vlm/models/phi3_v/phi3vision.py create mode 100644 mlx_vlm/models/phi3_v/su_rope.py create mode 100644 mlx_vlm/models/phi3_v/vision.py diff --git a/mlx_vlm/models/phi3_v/__init__.py b/mlx_vlm/models/phi3_v/__init__.py new file mode 100644 index 0000000..ce4363e --- /dev/null +++ b/mlx_vlm/models/phi3_v/__init__.py @@ -0,0 +1,8 @@ +from .phi3vision import ( + LanguageModel, + Model, + ModelConfig, + TextConfig, + VisionConfig, + VisionModel, +) \ No newline at end of file diff --git a/mlx_vlm/models/phi3_v/language.py b/mlx_vlm/models/phi3_v/language.py new file mode 100644 index 0000000..8efd488 --- /dev/null +++ b/mlx_vlm/models/phi3_v/language.py @@ -0,0 +1,17 @@ +from dataclasses import dataclass +import inspect + +@dataclass +class TextConfig: + @classmethod + def from_dict(cls, params): + return cls( + **{ + k: v + for k, v in params.items() + if k in inspect.signature(cls).parameters + } + ) + +class LanguageModel: + pass \ No newline at end of file diff --git a/mlx_vlm/models/phi3_v/phi3vision.py b/mlx_vlm/models/phi3_v/phi3vision.py new file mode 100644 index 0000000..9643b2d --- /dev/null +++ b/mlx_vlm/models/phi3_v/phi3vision.py @@ -0,0 +1,229 @@ +import inspect +import math +from dataclasses import dataclass +from typing import Dict, Optional, Tuple, Union +from types import SimpleNamespace +from typing import Optional + +import mlx.core as mx +import mlx.nn as nn +import numpy as np + +from .su_rope import Phi3SuScaledRotaryEmbedding +from .language import TextConfig, LanguageModel +from .vision import VisionConfig, VisionModel + +@dataclass +class ModelConfig: + text_config: TextConfig + vision_config: VisionConfig + model_type: str + vocab_size: int + + num_hidden_layers: int + intermediate_size: int + num_attention_heads: int + rms_norm_eps: float + + ignore_index: int = -100 + image_token_index: int = 257152 + hidden_size: int = 2048 + pad_token_id: int = 0 + + num_key_value_heads: int = None + rope_theta: float = 10000 + rope_traditional: bool = False + rope_scaling: Optional[Dict[str, Union[float, str]]] = None + max_position_embeddings: int = 131072 + original_max_position_embeddings: int = 4096 + + + @classmethod + def from_dict(cls, params): + return cls( + **{ + k: v + for k, v in params.items() + if k in inspect.signature(cls).parameters + } + ) + +class Attention(nn.Module): + def __init__(self, args: TextConfig): + super().__init__() + + dim = args.hidden_size + self.n_heads = n_heads = args.num_attention_heads + self.n_kv_heads = n_kv_heads = args.num_key_value_heads + self.num_hidden_layers = args.num_hidden_layers + + self.head_dim = head_dim = args.hidden_size // n_heads + self.scale = head_dim**-0.5 + + op_size = n_heads * head_dim + 2 * (n_kv_heads * head_dim) + self.qkv_proj = nn.Linear(dim, op_size, bias=False) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) + + rope_scale = 1.0 + if 
args.rope_scaling and args.rope_scaling["type"] == "su": + self.rope = Phi3SuScaledRotaryEmbedding( + head_dim, + traditional=False, + base=args.rope_theta, + scale=rope_scale, + max_position_embeddings=args.max_position_embeddings, + original_max_position_embeddings=args.original_max_position_embeddings, + short_factor=args.rope_scaling["short_factor"], + long_factor=args.rope_scaling["long_factor"], + ) + else: + if args.rope_scaling and args.rope_scaling["type"] == "linear": + rope_scale = 1 / args.rope_scaling["factor"] + self.rope = nn.RoPE( + head_dim, + traditional=args.rope_traditional, + base=args.rope_theta, + scale=rope_scale, + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + B, L, D = x.shape + + qkv = self.qkv_proj(x) + query_pos = self.n_heads * self.head_dim + queries, keys, values = mx.split( + qkv, [query_pos, query_pos + self.n_kv_heads * self.head_dim], axis=-1 + ) + + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + if cache is not None: + offset = cache[0].shape[2] + queries = self.rope(queries, offset=offset) + keys = self.rope(keys, offset=offset) + keys = mx.concatenate([cache[0], keys], axis=2) + values = mx.concatenate([cache[1], values], axis=2) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + output = mx.fast.scaled_dot_product_attention( + queries, keys, values, scale=self.scale, mask=mask + ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output), (keys, values) + + +class MLP(nn.Module): + def __init__(self, dim, hidden_dim): + super().__init__() + self.gate_up_proj = nn.Linear(dim, 2 * hidden_dim, bias=False) + self.down_proj = nn.Linear(hidden_dim, dim, bias=False) + + def __call__(self, x) -> mx.array: + x = self.gate_up_proj(x) + gate, x = mx.split(x, 2, axis=-1) + return self.down_proj(nn.silu(gate) * x) + + +class TransformerBlock(nn.Module): + def __init__(self, args: TextConfig): + super().__init__() + self.num_attention_heads = args.num_attention_heads + self.hidden_size = args.hidden_size + self.self_attn = Attention(args) + self.mlp = MLP(args.hidden_size, args.intermediate_size) + self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) + self.args = args + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + r, cache = self.self_attn(self.input_layernorm(x), mask, cache) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r + return out, cache + + +class Phi3V(nn.Module): + def __init__(self, args: TextConfig): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.num_hidden_layers = args.num_hidden_layers + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.vision_embed_tokens = VisionModel(args) + self.layers = [ + TransformerBlock(args=args) for _ in range(args.num_hidden_layers) + ] + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + def __call__( + self, + inputs: mx.array, + pixel_values=None, + image_sizes=None, + cache=None, + ): + h = self.embed_tokens(inputs) + p = np.argwhere(inputs < 0).tolist() + if pixel_values is not None: 
+ x = self.vision_embed_tokens(h, pixel_values, image_sizes, p) + mask=None + if h.shape[1] > 1: + mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) + mask = mask.astype(h.dtype) + if cache is None: + cache = [None] * len(self.layers) + for i, layer in enumerate(self.layers): + h, cache[i] = layer(h, mask, cache[i]) + return self.norm(h), cache + + +class Model(nn.Module): + def __init__(self, args: TextConfig): + super().__init__() + self.model_type = args.model_type + self.model = Phi3V(args) + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + self.config = args + + def __call__( + self, + inputs: mx.array, + pixel_values=None, + mask=None, + cache=None, + ): + out, cache = self.model(inputs, pixel_values, mask, cache) + return self.lm_head(out), cache + + @property + def layers(self): + return self.model.layers + + @property + def head_dim(self): + return self.args.hidden_size // self.args.num_attention_heads + + @property + def n_kv_heads(self): + return self.args.num_key_value_heads + + @property + def language_model(self): + return self diff --git a/mlx_vlm/models/phi3_v/su_rope.py b/mlx_vlm/models/phi3_v/su_rope.py new file mode 100644 index 0000000..1ea75a0 --- /dev/null +++ b/mlx_vlm/models/phi3_v/su_rope.py @@ -0,0 +1,70 @@ +import math +import mlx.core as mx + + +class Phi3SuScaledRotaryEmbedding: + def __init__( + self, + dims: int, + traditional: bool = False, + base: float = 10000.0, + scale: float = 1.0, + max_position_embeddings: int = 131072, + original_max_position_embeddings: int = 4096, + short_factor: list[float] | float = 1.0, + long_factor: list[float] | float = 1.0, + ): + """ + Phi3Su Scaled Rotary Embedding layer for Phi-3 models. + + Args: + dims (int): The feature dimensions to be rotated. + traditional (bool, optional): Unused. Default: ``False``. + base (int, optional): Base for the exponential scaling. + scale (float, optional): The scale used to scale the positions. Default: 1.0. + max_position_embeddings (int, optional): The maximum sequence length that this model was trained with. This is used to determine the size of the original RoPE embeddings when using long scaling. Default: 131072. + original_max_position_embeddings (int, optional): The maximum sequence length that this model was trained with. This is used to determine the size of the original RoPE embeddings when using long scaling. Default: 4096. + short_factor (float or list of floats, optional): List of scaling factors for sequences of length lesser than original_max_position_embeddings. Default: 1.0. + long_factor (float or list of floats, optional): List of scaling factors for sequences of length greater than original_max_position_embeddings. Default: 1.0. 
+ """ + self.inv_freq_short = 1.0 / ( + mx.array(short_factor, dtype=mx.float32) + * base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims) + ) + self.inv_freq_long = 1.0 / ( + scale + * mx.array(long_factor, dtype=mx.float32) + * base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims) + ) + self.original_max_position_embeddings = original_max_position_embeddings + self.scaling_factor = math.sqrt( + 1 + + math.log(max_position_embeddings / original_max_position_embeddings) + / math.log(original_max_position_embeddings) + ) + + def _get_cos_sin(self, offset, L): + position_ids = mx.arange(offset, offset + L, dtype=mx.float32)[None] + inv_freq = ( + self.inv_freq_long + if position_ids.max() + 1 > self.original_max_position_embeddings + else self.inv_freq_short + ) + inv_freq_expanded = mx.repeat( + inv_freq[None, :, None], position_ids.shape[0], axis=0 + ) + position_ids_expanded = position_ids[:, None, :] + freqs = (inv_freq_expanded @ position_ids_expanded).transpose(0, 2, 1) + emb = mx.concatenate([freqs, freqs], axis=-1) + cos = mx.cos(emb) * self.scaling_factor + sin = mx.sin(emb) * self.scaling_factor + return mx.expand_dims(cos, axis=1), mx.expand_dims(sin, axis=1) + + def __call__(self, x, offset: int = 0): + def _rotate_half(_x): + midpoint = _x.shape[-1] // 2 + x1, x2 = _x[..., :midpoint], _x[..., midpoint:] + return mx.concatenate([-x2, x1], axis=-1) + + cos, sin = self._get_cos_sin(offset, x.shape[2]) + return (x * cos) + (_rotate_half(x) * sin) \ No newline at end of file diff --git a/mlx_vlm/models/phi3_v/vision.py b/mlx_vlm/models/phi3_v/vision.py new file mode 100644 index 0000000..0ba7cf4 --- /dev/null +++ b/mlx_vlm/models/phi3_v/vision.py @@ -0,0 +1,283 @@ +import inspect +import math +from dataclasses import dataclass +from typing import Optional +from types import SimpleNamespace + +import mlx.core as mx +import mlx.nn as nn +import numpy as np + +@dataclass +class VisionConfig: + model_type: str = 'clip' + num_hidden_layers: int = 24 + hidden_size: int = 1024 + intermediate_size: int = 4096 + num_attention_heads: int = 16 + image_size: int = 336 + patch_size: int = 14 + projection_dim: int = 768 + vocab_size: int = 32000 + num_channels: int = 3 + layer_norm_eps: float = 1e-5 + image_dim_out: int = 1024, + model_name: str ='openai/clip-vit-large-patch14-336', + name: str ='clip_vision_model', + num_img_tokens: int =144 + + @classmethod + def from_dict(cls, params): + return cls( + **{ + k: v + for k, v in params.items() + if k in inspect.signature(cls).parameters + } + ) + + +def check_array_shape(arr): + shape = arr.shape + + # Check if the shape has 4 dimensions + if len(shape) != 4: + return False + + out_channels, kH, KW, _ = shape + + # Check if out_channels is the largest, and kH and KW are the same + if (out_channels >= kH) and (out_channels >= KW) and (kH == KW): + return True + else: + return False + + +class Attention(nn.Module): + def __init__( + self, + dims: int, + num_heads: int, + query_input_dims: Optional[int] = None, + key_input_dims: Optional[int] = None, + value_input_dims: Optional[int] = None, + value_dims: Optional[int] = None, + value_output_dims: Optional[int] = None, + bias: bool = False, + ): + super().__init__() + + if (dims % num_heads) != 0: + raise ValueError( + "The input feature dimensions should be divisible by the " + f"number of heads ({dims} % {num_heads}) != 0" + ) + + query_input_dims = query_input_dims or dims + key_input_dims = key_input_dims or dims + value_input_dims = value_input_dims or key_input_dims + value_dims = 
value_dims or dims + value_output_dims = value_output_dims or dims + + self.num_heads = num_heads = num_heads + head_dim = dims // num_heads + self.scale = head_dim**-0.5 + + self.q_proj = nn.Linear(query_input_dims, dims, bias=bias) + self.k_proj = nn.Linear(key_input_dims, dims, bias=bias) + self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias) + self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias) + + def __call__(self, queries, keys, values, mask=None): + queries = self.q_proj(queries) + keys = self.k_proj(keys) + values = self.v_proj(values) + + num_heads = self.num_heads + B, L, D = queries.shape + _, S, _ = keys.shape + queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) + + output = mx.fast.scaled_dot_product_attention( + queries, keys, values, scale=self.scale, mask=mask + ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + + return self.out_proj(output) + + +class MLP(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + self.activation_fn = nn.GELU(approx="fast") + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def __call__(self, x: mx.array) -> mx.array: + x = self.activation_fn(self.fc1(x)) + x = self.fc2(x) + return x + + +class EncoderLayer(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = Attention( + config.hidden_size, config.num_attention_heads, bias=True + ) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = MLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array: + y = self.layer_norm1(x) + y = self.self_attn(y, y, y, mask) + x = x + y + y = self.layer_norm2(x) + y = self.mlp(y) + return x + y + + +class Encoder(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + self.layers = [EncoderLayer(config) for _ in range(config.num_hidden_layers)] + + +class VisionEmbeddings(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = mx.zeros((config.hidden_size,)) + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + + def __call__(self, x: mx.array) -> mx.array: + batch_size = x.shape[0] + patch_embeddings = self.patch_embedding(x) + patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2) + embed_dim = patch_embeddings.shape[-1] + cls_embeddings = mx.broadcast_to( + self.class_embedding, (batch_size, 1, embed_dim) + ) + position_ids = mx.array(np.arange(self.num_positions)[None, :]) + + embeddings = mx.concatenate((cls_embeddings, patch_embeddings), axis=1) + embeddings += self.position_embedding(position_ids) + return embeddings + + +class ClipModel(nn.Module): + def __init__(self, config: VisionConfig): + 
super().__init__() + self.embeddings = VisionEmbeddings(config) + self.pre_layrnorm = nn.LayerNorm(config.hidden_size) + self.encoder = Encoder(config) + self.post_layernorm = nn.LayerNorm(config.hidden_size) + + def __call__( + self, + x: mx.array, + output_hidden_states: Optional[bool] = None, + ) -> mx.array: + x = self.embeddings(x) + x = self.pre_layrnorm(x) + + encoder_states = (x,) if output_hidden_states else None + + for l in self.encoder.layers: + x = l(x, mask=None) + if output_hidden_states: + encoder_states = encoder_states + (x,) + + pooler_output = self.post_layernorm(x[:, 0, :]) + return pooler_output, x, encoder_states + +class ClipVModel(nn.Module): + def __init__(self, config): + super().__init__() + self.vision_model = ClipModel(config) + +class VisionModel(nn.Module): + CLIP_VIT_LARGE_PATCH14_336_CONFIG = SimpleNamespace( + hidden_size=1024, + image_size=336, + intermediate_size=4096, + layer_norm_eps=1e-05, + num_attention_heads=16, + num_channels=3, + num_hidden_layers=24, + patch_size=14, + ) + + def __init__(self, config): + super().__init__() + self.img_processor = ClipVModel(self.CLIP_VIT_LARGE_PATCH14_336_CONFIG) + self.image_dim_out = image_dim_out = 1024 + self.glb_GN = mx.zeros([1, 1, image_dim_out * 4]) + self.sub_GN = mx.zeros([1, 1, 1, image_dim_out * 4]) + self.img_projection = [nn.Linear(image_dim_out * 4, config.hidden_size), nn.GELU(), nn.Linear(config.hidden_size, config.hidden_size)] + + def __call__(self, txt_embeds, img_embeds, img_sizes, positions): + # print(0, txt_embeds.shape, img_embeds.shape, img_sizes.shape) + img_embeds = mx.array(img_embeds) + img_sizes = mx.array(img_sizes) + B = img_embeds.shape[0] + img_sizes = (img_sizes // 336).tolist() + img_features = self.img_processor.vision_model(img_embeds.reshape(-1, *img_embeds.shape[2:]).transpose(0, 2, 3, 1), True)[-1][-2][:,1:] + img_features = img_features.reshape(B, -1, *img_features.shape[1:]) + C, H = self.image_dim_out, int(img_features.shape[2] ** 0.5) + output_imgs, output_len = [], [] + for _bs in range(B): + h, w = img_sizes[_bs] + B_ = h * w + def _reshape_and_concatenate(img, shape, tile_shape): + return mx.concatenate([img.reshape(shape).transpose(0, 1, 3, 2, 4, 5).reshape(tile_shape), mx.tile(self.sub_GN, (1, tile_shape[1], 1, 1))], axis=2).reshape(1, -1, 4 * C) + glb_img = _reshape_and_concatenate( img_features[_bs, :1], (1, H//2, 2, H//2, 2, C), (1, H//2, H//2, 4*C) ) + sub_img = _reshape_and_concatenate( img_features[_bs, 1:B_+1], (B_, H//2, 2, H//2, 2, C), (1, h*12, w*12, 4*C) ) + x = mx.concatenate([sub_img, self.glb_GN, glb_img], axis=1) + for l in self.img_projection: + x = l(x) + output_imgs.append(np.array(x.astype(mx.float32))) + output_len.append(int((h*w + 1) * 144 + 1 + (h + 1) * 12)) + idx = 0 + txt_embeds = np.array(txt_embeds.astype(mx.float32)) + for i, cnt in enumerate(output_len): + print(cnt) + txt_embeds[positions[idx][0], positions[idx][1] : positions[idx][1] + cnt] = output_imgs[i] + idx += cnt + txt_embeds = mx.array(txt_embeds) + return txt_embeds + + + def sanitize(self, weights): + sanitized_weights = {} + for k, v in weights.items(): + if "position_ids" in k: + continue + elif "patch_embedding.weight" in k: + if check_array_shape(v): + sanitized_weights[k] = v + else: + sanitized_weights[k] = v.transpose(0, 2, 3, 1) + else: + sanitized_weights[k] = v + + return sanitized_weights \ No newline at end of file diff --git a/mlx_vlm/utils.py b/mlx_vlm/utils.py index cb8c853..6767018 100644 --- a/mlx_vlm/utils.py +++ b/mlx_vlm/utils.py @@ -156,6 +156,9 @@ 
def load_model(model_path: Path, lazy: bool = False) -> nn.Module: config["text_config"] = text_config if model_type == "idefics2": config = AutoConfig.from_pretrained(model_path).to_dict() + if model_type == "phi3_v": + config["vision_config"] = config['img_processor'] + config["text_config"] = {} model_config = model_class.ModelConfig.from_dict(config) @@ -690,6 +693,8 @@ def prepare_inputs(image_processor, processor, image, prompt, image_token_index) pixel_values = mx.array(inputs["pixel_values"]) input_ids = mx.array(inputs["input_ids"]) mask = mx.array(inputs["attention_mask"]) + if 'image_sizes' in inputs: + return input_ids, pixel_values, inputs['image_sizes'] return input_ids, pixel_values, mask From 942eeff562b9bfbf1a4582d9ae25095d61152f40 Mon Sep 17 00:00:00 2001 From: JosefAlbers <146810011+JosefAlbers@users.noreply.github.com> Date: Thu, 6 Jun 2024 13:11:16 +0900 Subject: [PATCH 2/9] Update test_models.py --- mlx_vlm/models/phi3_v/__init__.py | 4 +- mlx_vlm/models/phi3_v/language.py | 4 +- .../phi3_v/{phi3vision.py => phi3_v.py} | 14 +- mlx_vlm/models/phi3_v/su_rope.py | 2 +- mlx_vlm/models/phi3_v/vision.py | 80 +++++++--- mlx_vlm/tests/test_models.py | 151 ++++++++++++++++++ 6 files changed, 228 insertions(+), 27 deletions(-) rename mlx_vlm/models/phi3_v/{phi3vision.py => phi3_v.py} (96%) diff --git a/mlx_vlm/models/phi3_v/__init__.py b/mlx_vlm/models/phi3_v/__init__.py index ce4363e..6e0acf1 100644 --- a/mlx_vlm/models/phi3_v/__init__.py +++ b/mlx_vlm/models/phi3_v/__init__.py @@ -1,8 +1,8 @@ -from .phi3vision import ( +from .phi3_v import ( LanguageModel, Model, ModelConfig, TextConfig, VisionConfig, VisionModel, -) \ No newline at end of file +) diff --git a/mlx_vlm/models/phi3_v/language.py b/mlx_vlm/models/phi3_v/language.py index 8efd488..f2401bf 100644 --- a/mlx_vlm/models/phi3_v/language.py +++ b/mlx_vlm/models/phi3_v/language.py @@ -1,6 +1,7 @@ from dataclasses import dataclass import inspect + @dataclass class TextConfig: @classmethod @@ -13,5 +14,6 @@ def from_dict(cls, params): } ) + class LanguageModel: - pass \ No newline at end of file + pass diff --git a/mlx_vlm/models/phi3_v/phi3vision.py b/mlx_vlm/models/phi3_v/phi3_v.py similarity index 96% rename from mlx_vlm/models/phi3_v/phi3vision.py rename to mlx_vlm/models/phi3_v/phi3_v.py index 9643b2d..5fb2c32 100644 --- a/mlx_vlm/models/phi3_v/phi3vision.py +++ b/mlx_vlm/models/phi3_v/phi3_v.py @@ -13,6 +13,7 @@ from .language import TextConfig, LanguageModel from .vision import VisionConfig, VisionModel + @dataclass class ModelConfig: text_config: TextConfig @@ -37,7 +38,6 @@ class ModelConfig: max_position_embeddings: int = 131072 original_max_position_embeddings: int = 4096 - @classmethod def from_dict(cls, params): return cls( @@ -48,6 +48,7 @@ def from_dict(cls, params): } ) + class Attention(nn.Module): def __init__(self, args: TextConfig): super().__init__() @@ -179,11 +180,12 @@ def __call__( image_sizes=None, cache=None, ): + # print('inputs', inputs) # debug h = self.embed_tokens(inputs) p = np.argwhere(inputs < 0).tolist() if pixel_values is not None: - x = self.vision_embed_tokens(h, pixel_values, image_sizes, p) - mask=None + h = self.vision_embed_tokens(pixel_values, h, image_sizes, p) + mask = None if h.shape[1] > 1: mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) mask = mask.astype(h.dtype) @@ -210,7 +212,7 @@ def __call__( cache=None, ): out, cache = self.model(inputs, pixel_values, mask, cache) - return self.lm_head(out), cache + return 
self.lm_head(out).astype(self.lm_head.weight.dtype), cache @property def layers(self): @@ -227,3 +229,7 @@ def n_kv_heads(self): @property def language_model(self): return self + + @property + def vision_model(self): + return self.model.vision_embed_tokens diff --git a/mlx_vlm/models/phi3_v/su_rope.py b/mlx_vlm/models/phi3_v/su_rope.py index 1ea75a0..83dbdfe 100644 --- a/mlx_vlm/models/phi3_v/su_rope.py +++ b/mlx_vlm/models/phi3_v/su_rope.py @@ -67,4 +67,4 @@ def _rotate_half(_x): return mx.concatenate([-x2, x1], axis=-1) cos, sin = self._get_cos_sin(offset, x.shape[2]) - return (x * cos) + (_rotate_half(x) * sin) \ No newline at end of file + return (x * cos) + (_rotate_half(x) * sin) diff --git a/mlx_vlm/models/phi3_v/vision.py b/mlx_vlm/models/phi3_v/vision.py index 0ba7cf4..f832e2d 100644 --- a/mlx_vlm/models/phi3_v/vision.py +++ b/mlx_vlm/models/phi3_v/vision.py @@ -8,9 +8,10 @@ import mlx.nn as nn import numpy as np + @dataclass class VisionConfig: - model_type: str = 'clip' + model_type: str = "phi3_v" num_hidden_layers: int = 24 hidden_size: int = 1024 intermediate_size: int = 4096 @@ -21,11 +22,11 @@ class VisionConfig: vocab_size: int = 32000 num_channels: int = 3 layer_norm_eps: float = 1e-5 - image_dim_out: int = 1024, - model_name: str ='openai/clip-vit-large-patch14-336', - name: str ='clip_vision_model', - num_img_tokens: int =144 - + image_dim_out: int = (1024,) + model_name: str = "openai/clip-vit-large-patch14-336" + name: str = "clip_vision_model" + num_img_tokens: int = 144 + @classmethod def from_dict(cls, params): return cls( @@ -187,6 +188,7 @@ def __call__(self, x: mx.array) -> mx.array: class ClipModel(nn.Module): def __init__(self, config: VisionConfig): super().__init__() + self.model_type = config.model_type self.embeddings = VisionEmbeddings(config) self.pre_layrnorm = nn.LayerNorm(config.hidden_size) self.encoder = Encoder(config) @@ -209,14 +211,18 @@ def __call__( pooler_output = self.post_layernorm(x[:, 0, :]) return pooler_output, x, encoder_states - + + class ClipVModel(nn.Module): def __init__(self, config): super().__init__() + self.model_type = config.model_type self.vision_model = ClipModel(config) + class VisionModel(nn.Module): CLIP_VIT_LARGE_PATCH14_336_CONFIG = SimpleNamespace( + model_type="phi3_v", hidden_size=1024, image_size=336, intermediate_size=4096, @@ -225,48 +231,84 @@ class VisionModel(nn.Module): num_channels=3, num_hidden_layers=24, patch_size=14, - ) + ) def __init__(self, config): super().__init__() + self.model_type = config.model_type self.img_processor = ClipVModel(self.CLIP_VIT_LARGE_PATCH14_336_CONFIG) self.image_dim_out = image_dim_out = 1024 self.glb_GN = mx.zeros([1, 1, image_dim_out * 4]) self.sub_GN = mx.zeros([1, 1, 1, image_dim_out * 4]) - self.img_projection = [nn.Linear(image_dim_out * 4, config.hidden_size), nn.GELU(), nn.Linear(config.hidden_size, config.hidden_size)] + self.img_projection = [ + nn.Linear(image_dim_out * 4, config.hidden_size), + nn.GELU(), + nn.Linear(config.hidden_size, config.hidden_size), + ] - def __call__(self, txt_embeds, img_embeds, img_sizes, positions): + def __call__( + self, + img_embeds, + txt_embeds=None, + img_sizes=None, + positions=None, + output_hidden_states=None, + ): + if output_hidden_states: + return self.img_processor.vision_model( + img_embeds, output_hidden_states=output_hidden_states + ) # print(0, txt_embeds.shape, img_embeds.shape, img_sizes.shape) img_embeds = mx.array(img_embeds) img_sizes = mx.array(img_sizes) B = img_embeds.shape[0] img_sizes = (img_sizes // 
336).tolist() - img_features = self.img_processor.vision_model(img_embeds.reshape(-1, *img_embeds.shape[2:]).transpose(0, 2, 3, 1), True)[-1][-2][:,1:] + img_features = self.img_processor.vision_model( + img_embeds.reshape(-1, *img_embeds.shape[2:]).transpose(0, 2, 3, 1), True + )[-1][-2][:, 1:] img_features = img_features.reshape(B, -1, *img_features.shape[1:]) C, H = self.image_dim_out, int(img_features.shape[2] ** 0.5) output_imgs, output_len = [], [] for _bs in range(B): h, w = img_sizes[_bs] B_ = h * w + def _reshape_and_concatenate(img, shape, tile_shape): - return mx.concatenate([img.reshape(shape).transpose(0, 1, 3, 2, 4, 5).reshape(tile_shape), mx.tile(self.sub_GN, (1, tile_shape[1], 1, 1))], axis=2).reshape(1, -1, 4 * C) - glb_img = _reshape_and_concatenate( img_features[_bs, :1], (1, H//2, 2, H//2, 2, C), (1, H//2, H//2, 4*C) ) - sub_img = _reshape_and_concatenate( img_features[_bs, 1:B_+1], (B_, H//2, 2, H//2, 2, C), (1, h*12, w*12, 4*C) ) + return mx.concatenate( + [ + img.reshape(shape) + .transpose(0, 1, 3, 2, 4, 5) + .reshape(tile_shape), + mx.tile(self.sub_GN, (1, tile_shape[1], 1, 1)), + ], + axis=2, + ).reshape(1, -1, 4 * C) + + glb_img = _reshape_and_concatenate( + img_features[_bs, :1], + (1, H // 2, 2, H // 2, 2, C), + (1, H // 2, H // 2, 4 * C), + ) + sub_img = _reshape_and_concatenate( + img_features[_bs, 1 : B_ + 1], + (B_, H // 2, 2, H // 2, 2, C), + (1, h * 12, w * 12, 4 * C), + ) x = mx.concatenate([sub_img, self.glb_GN, glb_img], axis=1) for l in self.img_projection: x = l(x) output_imgs.append(np.array(x.astype(mx.float32))) - output_len.append(int((h*w + 1) * 144 + 1 + (h + 1) * 12)) + output_len.append(int((h * w + 1) * 144 + 1 + (h + 1) * 12)) idx = 0 txt_embeds = np.array(txt_embeds.astype(mx.float32)) for i, cnt in enumerate(output_len): - print(cnt) - txt_embeds[positions[idx][0], positions[idx][1] : positions[idx][1] + cnt] = output_imgs[i] + txt_embeds[ + positions[idx][0], positions[idx][1] : positions[idx][1] + cnt + ] = output_imgs[i] idx += cnt txt_embeds = mx.array(txt_embeds) return txt_embeds - def sanitize(self, weights): sanitized_weights = {} for k, v in weights.items(): @@ -280,4 +322,4 @@ def sanitize(self, weights): else: sanitized_weights[k] = v - return sanitized_weights \ No newline at end of file + return sanitized_weights diff --git a/mlx_vlm/tests/test_models.py b/mlx_vlm/tests/test_models.py index 9514709..7701f81 100644 --- a/mlx_vlm/tests/test_models.py +++ b/mlx_vlm/tests/test_models.py @@ -331,6 +331,157 @@ def test_paligemma(self): (args.vision_config.image_size, args.vision_config.image_size), ) + def test_phi3_v(self): + from mlx_vlm.models import phi3_v + + text_config = phi3_v.TextConfig() + + vision_config = phi3_v.VisionConfig( + model_type="phi3_v", + image_dim_out=1024, + model_name="openai/clip-vit-large-patch14-336", + name="clip_vision_model", + num_img_tokens=144, + ) + + args = phi3_v.ModelConfig( + text_config=text_config, + vision_config=vision_config, + **{ + "hidden_size": 3072, + "intermediate_size": 8192, + "max_position_embeddings": 131072, + "model_type": "phi3_v", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 32, + "original_max_position_embeddings": 4096, + "rms_norm_eps": 1e-05, + "rope_scaling": { + "long_factor": [ + 1.0299999713897705, + 1.0499999523162842, + 1.0499999523162842, + 1.0799999237060547, + 1.2299998998641968, + 1.2299998998641968, + 1.2999999523162842, + 1.4499999284744263, + 1.5999999046325684, + 1.6499998569488525, + 1.8999998569488525, + 
2.859999895095825, + 3.68999981880188, + 5.419999599456787, + 5.489999771118164, + 5.489999771118164, + 9.09000015258789, + 11.579999923706055, + 15.65999984741211, + 15.769999504089355, + 15.789999961853027, + 18.360000610351562, + 21.989999771118164, + 23.079999923706055, + 30.009998321533203, + 32.35000228881836, + 32.590003967285156, + 35.56000518798828, + 39.95000457763672, + 53.840003967285156, + 56.20000457763672, + 57.95000457763672, + 59.29000473022461, + 59.77000427246094, + 59.920005798339844, + 61.190006256103516, + 61.96000671386719, + 62.50000762939453, + 63.3700065612793, + 63.48000717163086, + 63.48000717163086, + 63.66000747680664, + 63.850006103515625, + 64.08000946044922, + 64.760009765625, + 64.80001068115234, + 64.81001281738281, + 64.81001281738281, + ], + "short_factor": [ + 1.05, + 1.05, + 1.05, + 1.1, + 1.1, + 1.1, + 1.2500000000000002, + 1.2500000000000002, + 1.4000000000000004, + 1.4500000000000004, + 1.5500000000000005, + 1.8500000000000008, + 1.9000000000000008, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.1000000000000005, + 2.1000000000000005, + 2.2, + 2.3499999999999996, + 2.3499999999999996, + 2.3499999999999996, + 2.3499999999999996, + 2.3999999999999995, + 2.3999999999999995, + 2.6499999999999986, + 2.6999999999999984, + 2.8999999999999977, + 2.9499999999999975, + 3.049999999999997, + 3.049999999999997, + 3.049999999999997, + ], + "type": "su", + }, + "rope_theta": 10000.0, + "vocab_size": 32064, + }, + ) + + model = phi3_v.Model(args) + + self.language_test_runner( + model.language_model, + args.model_type, + args.vocab_size, + args.num_hidden_layers, + ) + + self.vision_test_runner( + model.vision_model, + args.vision_config.model_type, + args.vision_config.hidden_size, + args.vision_config.num_channels, + (args.vision_config.image_size, args.vision_config.image_size), + ) + if __name__ == "__main__": unittest.main() From adda5fbc8006168c33d215ae5d59abaa15a0a637 Mon Sep 17 00:00:00 2001 From: JosefAlbers <146810011+JosefAlbers@users.noreply.github.com> Date: Sun, 23 Jun 2024 23:50:19 +0900 Subject: [PATCH 3/9] copy pastes of past files into this new rebase --- mlx_vlm/models/phi3_v/__init__.py | 8 + mlx_vlm/models/phi3_v/language.py | 19 ++ mlx_vlm/models/phi3_v/phi3_v.py | 235 +++++++++++++++++++++ mlx_vlm/models/phi3_v/su_rope.py | 70 +++++++ mlx_vlm/models/phi3_v/vision.py | 325 ++++++++++++++++++++++++++++++ mlx_vlm/tests/test_models.py | 150 ++++++++++++++ mlx_vlm/utils.py | 5 + 7 files changed, 812 insertions(+) create mode 100644 mlx_vlm/models/phi3_v/__init__.py create mode 100644 mlx_vlm/models/phi3_v/language.py create mode 100644 mlx_vlm/models/phi3_v/phi3_v.py create mode 100644 mlx_vlm/models/phi3_v/su_rope.py create mode 100644 mlx_vlm/models/phi3_v/vision.py diff --git a/mlx_vlm/models/phi3_v/__init__.py b/mlx_vlm/models/phi3_v/__init__.py new file mode 100644 index 0000000..6e0acf1 --- /dev/null +++ b/mlx_vlm/models/phi3_v/__init__.py @@ -0,0 +1,8 @@ +from .phi3_v import ( + LanguageModel, + Model, + ModelConfig, + TextConfig, + VisionConfig, + VisionModel, +) diff --git a/mlx_vlm/models/phi3_v/language.py b/mlx_vlm/models/phi3_v/language.py new file mode 100644 index 0000000..f2401bf 
--- /dev/null +++ b/mlx_vlm/models/phi3_v/language.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass +import inspect + + +@dataclass +class TextConfig: + @classmethod + def from_dict(cls, params): + return cls( + **{ + k: v + for k, v in params.items() + if k in inspect.signature(cls).parameters + } + ) + + +class LanguageModel: + pass diff --git a/mlx_vlm/models/phi3_v/phi3_v.py b/mlx_vlm/models/phi3_v/phi3_v.py new file mode 100644 index 0000000..5fb2c32 --- /dev/null +++ b/mlx_vlm/models/phi3_v/phi3_v.py @@ -0,0 +1,235 @@ +import inspect +import math +from dataclasses import dataclass +from typing import Dict, Optional, Tuple, Union +from types import SimpleNamespace +from typing import Optional + +import mlx.core as mx +import mlx.nn as nn +import numpy as np + +from .su_rope import Phi3SuScaledRotaryEmbedding +from .language import TextConfig, LanguageModel +from .vision import VisionConfig, VisionModel + + +@dataclass +class ModelConfig: + text_config: TextConfig + vision_config: VisionConfig + model_type: str + vocab_size: int + + num_hidden_layers: int + intermediate_size: int + num_attention_heads: int + rms_norm_eps: float + + ignore_index: int = -100 + image_token_index: int = 257152 + hidden_size: int = 2048 + pad_token_id: int = 0 + + num_key_value_heads: int = None + rope_theta: float = 10000 + rope_traditional: bool = False + rope_scaling: Optional[Dict[str, Union[float, str]]] = None + max_position_embeddings: int = 131072 + original_max_position_embeddings: int = 4096 + + @classmethod + def from_dict(cls, params): + return cls( + **{ + k: v + for k, v in params.items() + if k in inspect.signature(cls).parameters + } + ) + + +class Attention(nn.Module): + def __init__(self, args: TextConfig): + super().__init__() + + dim = args.hidden_size + self.n_heads = n_heads = args.num_attention_heads + self.n_kv_heads = n_kv_heads = args.num_key_value_heads + self.num_hidden_layers = args.num_hidden_layers + + self.head_dim = head_dim = args.hidden_size // n_heads + self.scale = head_dim**-0.5 + + op_size = n_heads * head_dim + 2 * (n_kv_heads * head_dim) + self.qkv_proj = nn.Linear(dim, op_size, bias=False) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) + + rope_scale = 1.0 + if args.rope_scaling and args.rope_scaling["type"] == "su": + self.rope = Phi3SuScaledRotaryEmbedding( + head_dim, + traditional=False, + base=args.rope_theta, + scale=rope_scale, + max_position_embeddings=args.max_position_embeddings, + original_max_position_embeddings=args.original_max_position_embeddings, + short_factor=args.rope_scaling["short_factor"], + long_factor=args.rope_scaling["long_factor"], + ) + else: + if args.rope_scaling and args.rope_scaling["type"] == "linear": + rope_scale = 1 / args.rope_scaling["factor"] + self.rope = nn.RoPE( + head_dim, + traditional=args.rope_traditional, + base=args.rope_theta, + scale=rope_scale, + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + B, L, D = x.shape + + qkv = self.qkv_proj(x) + query_pos = self.n_heads * self.head_dim + queries, keys, values = mx.split( + qkv, [query_pos, query_pos + self.n_kv_heads * self.head_dim], axis=-1 + ) + + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + if cache is not None: + offset = cache[0].shape[2] + queries = self.rope(queries, 
offset=offset) + keys = self.rope(keys, offset=offset) + keys = mx.concatenate([cache[0], keys], axis=2) + values = mx.concatenate([cache[1], values], axis=2) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + output = mx.fast.scaled_dot_product_attention( + queries, keys, values, scale=self.scale, mask=mask + ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output), (keys, values) + + +class MLP(nn.Module): + def __init__(self, dim, hidden_dim): + super().__init__() + self.gate_up_proj = nn.Linear(dim, 2 * hidden_dim, bias=False) + self.down_proj = nn.Linear(hidden_dim, dim, bias=False) + + def __call__(self, x) -> mx.array: + x = self.gate_up_proj(x) + gate, x = mx.split(x, 2, axis=-1) + return self.down_proj(nn.silu(gate) * x) + + +class TransformerBlock(nn.Module): + def __init__(self, args: TextConfig): + super().__init__() + self.num_attention_heads = args.num_attention_heads + self.hidden_size = args.hidden_size + self.self_attn = Attention(args) + self.mlp = MLP(args.hidden_size, args.intermediate_size) + self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) + self.args = args + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + r, cache = self.self_attn(self.input_layernorm(x), mask, cache) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r + return out, cache + + +class Phi3V(nn.Module): + def __init__(self, args: TextConfig): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.num_hidden_layers = args.num_hidden_layers + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.vision_embed_tokens = VisionModel(args) + self.layers = [ + TransformerBlock(args=args) for _ in range(args.num_hidden_layers) + ] + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + def __call__( + self, + inputs: mx.array, + pixel_values=None, + image_sizes=None, + cache=None, + ): + # print('inputs', inputs) # debug + h = self.embed_tokens(inputs) + p = np.argwhere(inputs < 0).tolist() + if pixel_values is not None: + h = self.vision_embed_tokens(pixel_values, h, image_sizes, p) + mask = None + if h.shape[1] > 1: + mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) + mask = mask.astype(h.dtype) + if cache is None: + cache = [None] * len(self.layers) + for i, layer in enumerate(self.layers): + h, cache[i] = layer(h, mask, cache[i]) + return self.norm(h), cache + + +class Model(nn.Module): + def __init__(self, args: TextConfig): + super().__init__() + self.model_type = args.model_type + self.model = Phi3V(args) + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + self.config = args + + def __call__( + self, + inputs: mx.array, + pixel_values=None, + mask=None, + cache=None, + ): + out, cache = self.model(inputs, pixel_values, mask, cache) + return self.lm_head(out).astype(self.lm_head.weight.dtype), cache + + @property + def layers(self): + return self.model.layers + + @property + def head_dim(self): + return self.args.hidden_size // self.args.num_attention_heads + + @property + def n_kv_heads(self): + return self.args.num_key_value_heads + + @property + def language_model(self): + return self + + @property + def vision_model(self): + return self.model.vision_embed_tokens diff --git a/mlx_vlm/models/phi3_v/su_rope.py 
b/mlx_vlm/models/phi3_v/su_rope.py new file mode 100644 index 0000000..83dbdfe --- /dev/null +++ b/mlx_vlm/models/phi3_v/su_rope.py @@ -0,0 +1,70 @@ +import math +import mlx.core as mx + + +class Phi3SuScaledRotaryEmbedding: + def __init__( + self, + dims: int, + traditional: bool = False, + base: float = 10000.0, + scale: float = 1.0, + max_position_embeddings: int = 131072, + original_max_position_embeddings: int = 4096, + short_factor: list[float] | float = 1.0, + long_factor: list[float] | float = 1.0, + ): + """ + Phi3Su Scaled Rotary Embedding layer for Phi-3 models. + + Args: + dims (int): The feature dimensions to be rotated. + traditional (bool, optional): Unused. Default: ``False``. + base (int, optional): Base for the exponential scaling. + scale (float, optional): The scale used to scale the positions. Default: 1.0. + max_position_embeddings (int, optional): The maximum sequence length that this model was trained with. This is used to determine the size of the original RoPE embeddings when using long scaling. Default: 131072. + original_max_position_embeddings (int, optional): The maximum sequence length that this model was trained with. This is used to determine the size of the original RoPE embeddings when using long scaling. Default: 4096. + short_factor (float or list of floats, optional): List of scaling factors for sequences of length lesser than original_max_position_embeddings. Default: 1.0. + long_factor (float or list of floats, optional): List of scaling factors for sequences of length greater than original_max_position_embeddings. Default: 1.0. + """ + self.inv_freq_short = 1.0 / ( + mx.array(short_factor, dtype=mx.float32) + * base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims) + ) + self.inv_freq_long = 1.0 / ( + scale + * mx.array(long_factor, dtype=mx.float32) + * base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims) + ) + self.original_max_position_embeddings = original_max_position_embeddings + self.scaling_factor = math.sqrt( + 1 + + math.log(max_position_embeddings / original_max_position_embeddings) + / math.log(original_max_position_embeddings) + ) + + def _get_cos_sin(self, offset, L): + position_ids = mx.arange(offset, offset + L, dtype=mx.float32)[None] + inv_freq = ( + self.inv_freq_long + if position_ids.max() + 1 > self.original_max_position_embeddings + else self.inv_freq_short + ) + inv_freq_expanded = mx.repeat( + inv_freq[None, :, None], position_ids.shape[0], axis=0 + ) + position_ids_expanded = position_ids[:, None, :] + freqs = (inv_freq_expanded @ position_ids_expanded).transpose(0, 2, 1) + emb = mx.concatenate([freqs, freqs], axis=-1) + cos = mx.cos(emb) * self.scaling_factor + sin = mx.sin(emb) * self.scaling_factor + return mx.expand_dims(cos, axis=1), mx.expand_dims(sin, axis=1) + + def __call__(self, x, offset: int = 0): + def _rotate_half(_x): + midpoint = _x.shape[-1] // 2 + x1, x2 = _x[..., :midpoint], _x[..., midpoint:] + return mx.concatenate([-x2, x1], axis=-1) + + cos, sin = self._get_cos_sin(offset, x.shape[2]) + return (x * cos) + (_rotate_half(x) * sin) diff --git a/mlx_vlm/models/phi3_v/vision.py b/mlx_vlm/models/phi3_v/vision.py new file mode 100644 index 0000000..f832e2d --- /dev/null +++ b/mlx_vlm/models/phi3_v/vision.py @@ -0,0 +1,325 @@ +import inspect +import math +from dataclasses import dataclass +from typing import Optional +from types import SimpleNamespace + +import mlx.core as mx +import mlx.nn as nn +import numpy as np + + +@dataclass +class VisionConfig: + model_type: str = "phi3_v" + num_hidden_layers: 
int = 24 + hidden_size: int = 1024 + intermediate_size: int = 4096 + num_attention_heads: int = 16 + image_size: int = 336 + patch_size: int = 14 + projection_dim: int = 768 + vocab_size: int = 32000 + num_channels: int = 3 + layer_norm_eps: float = 1e-5 + image_dim_out: int = (1024,) + model_name: str = "openai/clip-vit-large-patch14-336" + name: str = "clip_vision_model" + num_img_tokens: int = 144 + + @classmethod + def from_dict(cls, params): + return cls( + **{ + k: v + for k, v in params.items() + if k in inspect.signature(cls).parameters + } + ) + + +def check_array_shape(arr): + shape = arr.shape + + # Check if the shape has 4 dimensions + if len(shape) != 4: + return False + + out_channels, kH, KW, _ = shape + + # Check if out_channels is the largest, and kH and KW are the same + if (out_channels >= kH) and (out_channels >= KW) and (kH == KW): + return True + else: + return False + + +class Attention(nn.Module): + def __init__( + self, + dims: int, + num_heads: int, + query_input_dims: Optional[int] = None, + key_input_dims: Optional[int] = None, + value_input_dims: Optional[int] = None, + value_dims: Optional[int] = None, + value_output_dims: Optional[int] = None, + bias: bool = False, + ): + super().__init__() + + if (dims % num_heads) != 0: + raise ValueError( + "The input feature dimensions should be divisible by the " + f"number of heads ({dims} % {num_heads}) != 0" + ) + + query_input_dims = query_input_dims or dims + key_input_dims = key_input_dims or dims + value_input_dims = value_input_dims or key_input_dims + value_dims = value_dims or dims + value_output_dims = value_output_dims or dims + + self.num_heads = num_heads = num_heads + head_dim = dims // num_heads + self.scale = head_dim**-0.5 + + self.q_proj = nn.Linear(query_input_dims, dims, bias=bias) + self.k_proj = nn.Linear(key_input_dims, dims, bias=bias) + self.v_proj = nn.Linear(value_input_dims, value_dims, bias=bias) + self.out_proj = nn.Linear(value_dims, value_output_dims, bias=bias) + + def __call__(self, queries, keys, values, mask=None): + queries = self.q_proj(queries) + keys = self.k_proj(keys) + values = self.v_proj(values) + + num_heads = self.num_heads + B, L, D = queries.shape + _, S, _ = keys.shape + queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) + + output = mx.fast.scaled_dot_product_attention( + queries, keys, values, scale=self.scale, mask=mask + ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + + return self.out_proj(output) + + +class MLP(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + self.activation_fn = nn.GELU(approx="fast") + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def __call__(self, x: mx.array) -> mx.array: + x = self.activation_fn(self.fc1(x)) + x = self.fc2(x) + return x + + +class EncoderLayer(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = Attention( + config.hidden_size, config.num_attention_heads, bias=True + ) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = MLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + def __call__(self, x: mx.array, mask: Optional[mx.array] = None) -> mx.array: + y = self.layer_norm1(x) + y = 
self.self_attn(y, y, y, mask) + x = x + y + y = self.layer_norm2(x) + y = self.mlp(y) + return x + y + + +class Encoder(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + self.layers = [EncoderLayer(config) for _ in range(config.num_hidden_layers)] + + +class VisionEmbeddings(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = mx.zeros((config.hidden_size,)) + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + 1 + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + + def __call__(self, x: mx.array) -> mx.array: + batch_size = x.shape[0] + patch_embeddings = self.patch_embedding(x) + patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2) + embed_dim = patch_embeddings.shape[-1] + cls_embeddings = mx.broadcast_to( + self.class_embedding, (batch_size, 1, embed_dim) + ) + position_ids = mx.array(np.arange(self.num_positions)[None, :]) + + embeddings = mx.concatenate((cls_embeddings, patch_embeddings), axis=1) + embeddings += self.position_embedding(position_ids) + return embeddings + + +class ClipModel(nn.Module): + def __init__(self, config: VisionConfig): + super().__init__() + self.model_type = config.model_type + self.embeddings = VisionEmbeddings(config) + self.pre_layrnorm = nn.LayerNorm(config.hidden_size) + self.encoder = Encoder(config) + self.post_layernorm = nn.LayerNorm(config.hidden_size) + + def __call__( + self, + x: mx.array, + output_hidden_states: Optional[bool] = None, + ) -> mx.array: + x = self.embeddings(x) + x = self.pre_layrnorm(x) + + encoder_states = (x,) if output_hidden_states else None + + for l in self.encoder.layers: + x = l(x, mask=None) + if output_hidden_states: + encoder_states = encoder_states + (x,) + + pooler_output = self.post_layernorm(x[:, 0, :]) + return pooler_output, x, encoder_states + + +class ClipVModel(nn.Module): + def __init__(self, config): + super().__init__() + self.model_type = config.model_type + self.vision_model = ClipModel(config) + + +class VisionModel(nn.Module): + CLIP_VIT_LARGE_PATCH14_336_CONFIG = SimpleNamespace( + model_type="phi3_v", + hidden_size=1024, + image_size=336, + intermediate_size=4096, + layer_norm_eps=1e-05, + num_attention_heads=16, + num_channels=3, + num_hidden_layers=24, + patch_size=14, + ) + + def __init__(self, config): + super().__init__() + self.model_type = config.model_type + self.img_processor = ClipVModel(self.CLIP_VIT_LARGE_PATCH14_336_CONFIG) + self.image_dim_out = image_dim_out = 1024 + self.glb_GN = mx.zeros([1, 1, image_dim_out * 4]) + self.sub_GN = mx.zeros([1, 1, 1, image_dim_out * 4]) + self.img_projection = [ + nn.Linear(image_dim_out * 4, config.hidden_size), + nn.GELU(), + nn.Linear(config.hidden_size, config.hidden_size), + ] + + def __call__( + self, + img_embeds, + txt_embeds=None, + img_sizes=None, + positions=None, + output_hidden_states=None, + ): + if output_hidden_states: + return self.img_processor.vision_model( + img_embeds, output_hidden_states=output_hidden_states + ) + # print(0, txt_embeds.shape, img_embeds.shape, img_sizes.shape) + img_embeds = mx.array(img_embeds) + img_sizes = 
mx.array(img_sizes) + B = img_embeds.shape[0] + img_sizes = (img_sizes // 336).tolist() + img_features = self.img_processor.vision_model( + img_embeds.reshape(-1, *img_embeds.shape[2:]).transpose(0, 2, 3, 1), True + )[-1][-2][:, 1:] + img_features = img_features.reshape(B, -1, *img_features.shape[1:]) + C, H = self.image_dim_out, int(img_features.shape[2] ** 0.5) + output_imgs, output_len = [], [] + for _bs in range(B): + h, w = img_sizes[_bs] + B_ = h * w + + def _reshape_and_concatenate(img, shape, tile_shape): + return mx.concatenate( + [ + img.reshape(shape) + .transpose(0, 1, 3, 2, 4, 5) + .reshape(tile_shape), + mx.tile(self.sub_GN, (1, tile_shape[1], 1, 1)), + ], + axis=2, + ).reshape(1, -1, 4 * C) + + glb_img = _reshape_and_concatenate( + img_features[_bs, :1], + (1, H // 2, 2, H // 2, 2, C), + (1, H // 2, H // 2, 4 * C), + ) + sub_img = _reshape_and_concatenate( + img_features[_bs, 1 : B_ + 1], + (B_, H // 2, 2, H // 2, 2, C), + (1, h * 12, w * 12, 4 * C), + ) + x = mx.concatenate([sub_img, self.glb_GN, glb_img], axis=1) + for l in self.img_projection: + x = l(x) + output_imgs.append(np.array(x.astype(mx.float32))) + output_len.append(int((h * w + 1) * 144 + 1 + (h + 1) * 12)) + idx = 0 + txt_embeds = np.array(txt_embeds.astype(mx.float32)) + for i, cnt in enumerate(output_len): + txt_embeds[ + positions[idx][0], positions[idx][1] : positions[idx][1] + cnt + ] = output_imgs[i] + idx += cnt + txt_embeds = mx.array(txt_embeds) + return txt_embeds + + def sanitize(self, weights): + sanitized_weights = {} + for k, v in weights.items(): + if "position_ids" in k: + continue + elif "patch_embedding.weight" in k: + if check_array_shape(v): + sanitized_weights[k] = v + else: + sanitized_weights[k] = v.transpose(0, 2, 3, 1) + else: + sanitized_weights[k] = v + + return sanitized_weights diff --git a/mlx_vlm/tests/test_models.py b/mlx_vlm/tests/test_models.py index 613995d..a709ca8 100644 --- a/mlx_vlm/tests/test_models.py +++ b/mlx_vlm/tests/test_models.py @@ -468,6 +468,156 @@ def test_multi_modality(self): args.vision_config.num_channels, (args.vision_config.image_size, args.vision_config.image_size), ) + def test_phi3_v(self): + from mlx_vlm.models import phi3_v + + text_config = phi3_v.TextConfig() + + vision_config = phi3_v.VisionConfig( + model_type="phi3_v", + image_dim_out=1024, + model_name="openai/clip-vit-large-patch14-336", + name="clip_vision_model", + num_img_tokens=144, + ) + + args = phi3_v.ModelConfig( + text_config=text_config, + vision_config=vision_config, + **{ + "hidden_size": 3072, + "intermediate_size": 8192, + "max_position_embeddings": 131072, + "model_type": "phi3_v", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 32, + "original_max_position_embeddings": 4096, + "rms_norm_eps": 1e-05, + "rope_scaling": { + "long_factor": [ + 1.0299999713897705, + 1.0499999523162842, + 1.0499999523162842, + 1.0799999237060547, + 1.2299998998641968, + 1.2299998998641968, + 1.2999999523162842, + 1.4499999284744263, + 1.5999999046325684, + 1.6499998569488525, + 1.8999998569488525, + 2.859999895095825, + 3.68999981880188, + 5.419999599456787, + 5.489999771118164, + 5.489999771118164, + 9.09000015258789, + 11.579999923706055, + 15.65999984741211, + 15.769999504089355, + 15.789999961853027, + 18.360000610351562, + 21.989999771118164, + 23.079999923706055, + 30.009998321533203, + 32.35000228881836, + 32.590003967285156, + 35.56000518798828, + 39.95000457763672, + 53.840003967285156, + 56.20000457763672, + 57.95000457763672, + 59.29000473022461, + 
59.77000427246094, + 59.920005798339844, + 61.190006256103516, + 61.96000671386719, + 62.50000762939453, + 63.3700065612793, + 63.48000717163086, + 63.48000717163086, + 63.66000747680664, + 63.850006103515625, + 64.08000946044922, + 64.760009765625, + 64.80001068115234, + 64.81001281738281, + 64.81001281738281, + ], + "short_factor": [ + 1.05, + 1.05, + 1.05, + 1.1, + 1.1, + 1.1, + 1.2500000000000002, + 1.2500000000000002, + 1.4000000000000004, + 1.4500000000000004, + 1.5500000000000005, + 1.8500000000000008, + 1.9000000000000008, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.000000000000001, + 2.1000000000000005, + 2.1000000000000005, + 2.2, + 2.3499999999999996, + 2.3499999999999996, + 2.3499999999999996, + 2.3499999999999996, + 2.3999999999999995, + 2.3999999999999995, + 2.6499999999999986, + 2.6999999999999984, + 2.8999999999999977, + 2.9499999999999975, + 3.049999999999997, + 3.049999999999997, + 3.049999999999997, + ], + "type": "su", + }, + "rope_theta": 10000.0, + "vocab_size": 32064, + }, + ) + + model = phi3_v.Model(args) + + self.language_test_runner( + model.language_model, + args.model_type, + args.vocab_size, + args.num_hidden_layers, + ) + + self.vision_test_runner( + model.vision_model, + args.vision_config.model_type, + args.vision_config.hidden_size, + args.vision_config.num_channels, + (args.vision_config.image_size, args.vision_config.image_size), + ) if __name__ == "__main__": diff --git a/mlx_vlm/utils.py b/mlx_vlm/utils.py index 1b44586..42998a1 100644 --- a/mlx_vlm/utils.py +++ b/mlx_vlm/utils.py @@ -159,6 +159,9 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module: config["text_config"] = text_config if model_type == "idefics2": config = AutoConfig.from_pretrained(model_path).to_dict() + if model_type == "phi3_v": + config["vision_config"] = config['img_processor'] + config["text_config"] = {} model_config = model_class.ModelConfig.from_dict(config) @@ -705,6 +708,8 @@ def prepare_inputs(image_processor, processor, image, prompt, image_token_index) pixel_values = mx.array(inputs["pixel_values"]) input_ids = mx.array(inputs["input_ids"]) mask = mx.array(inputs["attention_mask"]) + if 'image_sizes' in inputs: + return input_ids, pixel_values, inputs['image_sizes'] return input_ids, pixel_values, mask From 45338b1a5b4869fb304044e006b9501d8725e721 Mon Sep 17 00:00:00 2001 From: JosefAlbers <146810011+JosefAlbers@users.noreply.github.com> Date: Mon, 24 Jun 2024 01:39:57 +0000 Subject: [PATCH 4/9] precommit --- mlx_vlm/models/phi3_v/language.py | 2 +- mlx_vlm/models/phi3_v/phi3_v.py | 5 ++--- mlx_vlm/models/phi3_v/su_rope.py | 1 + mlx_vlm/models/phi3_v/vision.py | 2 +- mlx_vlm/tests/test_models.py | 1 + mlx_vlm/utils.py | 6 +++--- 6 files changed, 9 insertions(+), 8 deletions(-) diff --git a/mlx_vlm/models/phi3_v/language.py b/mlx_vlm/models/phi3_v/language.py index f2401bf..763f188 100644 --- a/mlx_vlm/models/phi3_v/language.py +++ b/mlx_vlm/models/phi3_v/language.py @@ -1,5 +1,5 @@ -from dataclasses import dataclass import inspect +from dataclasses import dataclass @dataclass diff --git a/mlx_vlm/models/phi3_v/phi3_v.py b/mlx_vlm/models/phi3_v/phi3_v.py index 5fb2c32..bd56c76 100644 --- a/mlx_vlm/models/phi3_v/phi3_v.py +++ 
b/mlx_vlm/models/phi3_v/phi3_v.py @@ -1,16 +1,15 @@ import inspect import math from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union from types import SimpleNamespace -from typing import Optional +from typing import Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn import numpy as np +from .language import LanguageModel, TextConfig from .su_rope import Phi3SuScaledRotaryEmbedding -from .language import TextConfig, LanguageModel from .vision import VisionConfig, VisionModel diff --git a/mlx_vlm/models/phi3_v/su_rope.py b/mlx_vlm/models/phi3_v/su_rope.py index 83dbdfe..d5f5943 100644 --- a/mlx_vlm/models/phi3_v/su_rope.py +++ b/mlx_vlm/models/phi3_v/su_rope.py @@ -1,4 +1,5 @@ import math + import mlx.core as mx diff --git a/mlx_vlm/models/phi3_v/vision.py b/mlx_vlm/models/phi3_v/vision.py index f832e2d..075769d 100644 --- a/mlx_vlm/models/phi3_v/vision.py +++ b/mlx_vlm/models/phi3_v/vision.py @@ -1,8 +1,8 @@ import inspect import math from dataclasses import dataclass -from typing import Optional from types import SimpleNamespace +from typing import Optional import mlx.core as mx import mlx.nn as nn diff --git a/mlx_vlm/tests/test_models.py b/mlx_vlm/tests/test_models.py index a709ca8..58e0c7e 100644 --- a/mlx_vlm/tests/test_models.py +++ b/mlx_vlm/tests/test_models.py @@ -468,6 +468,7 @@ def test_multi_modality(self): args.vision_config.num_channels, (args.vision_config.image_size, args.vision_config.image_size), ) + def test_phi3_v(self): from mlx_vlm.models import phi3_v diff --git a/mlx_vlm/utils.py b/mlx_vlm/utils.py index 42998a1..8853b34 100644 --- a/mlx_vlm/utils.py +++ b/mlx_vlm/utils.py @@ -160,7 +160,7 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module: if model_type == "idefics2": config = AutoConfig.from_pretrained(model_path).to_dict() if model_type == "phi3_v": - config["vision_config"] = config['img_processor'] + config["vision_config"] = config["img_processor"] config["text_config"] = {} model_config = model_class.ModelConfig.from_dict(config) @@ -708,8 +708,8 @@ def prepare_inputs(image_processor, processor, image, prompt, image_token_index) pixel_values = mx.array(inputs["pixel_values"]) input_ids = mx.array(inputs["input_ids"]) mask = mx.array(inputs["attention_mask"]) - if 'image_sizes' in inputs: - return input_ids, pixel_values, inputs['image_sizes'] + if "image_sizes" in inputs: + return input_ids, pixel_values, inputs["image_sizes"] return input_ids, pixel_values, mask From 9f39e1f2e49151e44a9beec99e8c4ed3a116b384 Mon Sep 17 00:00:00 2001 From: JosefAlbers <146810011+JosefAlbers@users.noreply.github.com> Date: Mon, 24 Jun 2024 01:50:52 +0000 Subject: [PATCH 5/9] precommit --- mlx_vlm/models/phi3_v/language.py | 1 + mlx_vlm/models/phi3_v/su_rope.py | 1 + 2 files changed, 2 insertions(+) diff --git a/mlx_vlm/models/phi3_v/language.py b/mlx_vlm/models/phi3_v/language.py index 54903a7..763f188 100644 --- a/mlx_vlm/models/phi3_v/language.py +++ b/mlx_vlm/models/phi3_v/language.py @@ -1,6 +1,7 @@ import inspect from dataclasses import dataclass + @dataclass class TextConfig: @classmethod diff --git a/mlx_vlm/models/phi3_v/su_rope.py b/mlx_vlm/models/phi3_v/su_rope.py index 83dbdfe..d5f5943 100644 --- a/mlx_vlm/models/phi3_v/su_rope.py +++ b/mlx_vlm/models/phi3_v/su_rope.py @@ -1,4 +1,5 @@ import math + import mlx.core as mx From 20cf225c61fb24abf02efedcd664a6e8cd50b092 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Mon, 24 Jun 2024 16:13:30 +0200 Subject: [PATCH 6/9] remove debug print --- 
mlx_vlm/models/phi3_v/vision.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mlx_vlm/models/phi3_v/vision.py b/mlx_vlm/models/phi3_v/vision.py index 075769d..e41a013 100644 --- a/mlx_vlm/models/phi3_v/vision.py +++ b/mlx_vlm/models/phi3_v/vision.py @@ -258,7 +258,6 @@ def __call__( return self.img_processor.vision_model( img_embeds, output_hidden_states=output_hidden_states ) - # print(0, txt_embeds.shape, img_embeds.shape, img_sizes.shape) img_embeds = mx.array(img_embeds) img_sizes = mx.array(img_sizes) B = img_embeds.shape[0] From 20f3143aede8b412bb3fa7938d140ca6fe4aab01 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Mon, 24 Jun 2024 17:04:03 +0200 Subject: [PATCH 7/9] add prompt format --- mlx_vlm/prompt_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mlx_vlm/prompt_utils.py b/mlx_vlm/prompt_utils.py index 3350636..3d3f08b 100644 --- a/mlx_vlm/prompt_utils.py +++ b/mlx_vlm/prompt_utils.py @@ -18,6 +18,8 @@ def get_message_json(model_name, prompt): } elif model_name.lower() in ["llava-qwen2", "llava", "llava_next"]: message = {"role": "user", "content": f"\n{prompt}"} + elif model_name.lower() == "phi3_v": + message = {"role": "user", "content": f"<|image_1|>\n{prompt}"} elif model_name.lower() == "multi_modality": message = {"role": "user", "content": f"{prompt}"} elif model_name.lower() == "paligemma": From 91d38b8306066fa78c4146e530d84e1571d93f47 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Mon, 24 Jun 2024 17:59:18 +0200 Subject: [PATCH 8/9] add condition to fix quantisation --- mlx_vlm/utils.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/mlx_vlm/utils.py b/mlx_vlm/utils.py index 8853b34..dd6d80a 100644 --- a/mlx_vlm/utils.py +++ b/mlx_vlm/utils.py @@ -542,11 +542,12 @@ def quantize_model( new_bias[:out_features] = module.bias module.bias = new_bias - quantized_config["vision_config"]["intermediate_size"] = ( - ((vision_intermediate_size // divisor) + 1) * divisor - if vision_intermediate_size % divisor != 0 - else vision_intermediate_size - ) + if "vision_config" in quantized_config: + quantized_config["vision_config"]["intermediate_size"] = ( + ((vision_intermediate_size // divisor) + 1) * divisor + if vision_intermediate_size % divisor != 0 + else vision_intermediate_size + ) nn.quantize(model, q_group_size, q_bits) quantized_config["quantization"] = {"group_size": q_group_size, "bits": q_bits} From 4c5048fbe91aa7bfc99ee4400d56f90590ac087e Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Mon, 24 Jun 2024 17:59:30 +0200 Subject: [PATCH 9/9] bump version --- mlx_vlm/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx_vlm/version.py b/mlx_vlm/version.py index 00ec2dc..9b36b86 100644 --- a/mlx_vlm/version.py +++ b/mlx_vlm/version.py @@ -1 +1 @@ -__version__ = "0.0.9" +__version__ = "0.0.10"
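
For reference, the user-visible effect of PATCH 7/9 is the new phi3_v branch in get_message_json: the prompt is wrapped in the <|image_1|> placeholder that the Phi-3-vision chat template expects. A minimal sketch, assuming the package layout shown in the diff above:

    from mlx_vlm.prompt_utils import get_message_json

    message = get_message_json("phi3_v", "Describe this image.")
    # Per the branch added in PATCH 7/9:
    # {"role": "user", "content": "<|image_1|>\nDescribe this image."}
    print(message)

PATCH 8/9 only adds the "vision_config" in quantized_config guard so text-only configs are not touched; the rounding rule itself is unchanged. A standalone sketch of that rule (pad_to_multiple and the example values are illustrative, not part of mlx_vlm):

    def pad_to_multiple(intermediate_size: int, divisor: int) -> int:
        # Round up to the next multiple of `divisor`, mirroring how
        # quantize_model pads vision_config["intermediate_size"] so the
        # quantized weight shapes stay divisible by the group size.
        if intermediate_size % divisor != 0:
            return ((intermediate_size // divisor) + 1) * divisor
        return intermediate_size

    assert pad_to_multiple(4304, 64) == 4352  # padded up
    assert pad_to_multiple(4096, 64) == 4096  # already aligned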