diff --git a/mlx_vlm/models/idefics2/__init__.py b/mlx_vlm/models/idefics2/__init__.py
index 8b6ddbd..8868208 100644
--- a/mlx_vlm/models/idefics2/__init__.py
+++ b/mlx_vlm/models/idefics2/__init__.py
@@ -1,5 +1,4 @@
 from .idefics2 import (
-    ImageProcessor,
     LanguageModel,
     Model,
     ModelConfig,
diff --git a/mlx_vlm/models/idefics2/idefics2.py b/mlx_vlm/models/idefics2/idefics2.py
index 9ff118f..b78f6c2 100644
--- a/mlx_vlm/models/idefics2/idefics2.py
+++ b/mlx_vlm/models/idefics2/idefics2.py
@@ -3,27 +3,15 @@
 import json
 import re
 from dataclasses import dataclass
-from functools import partial, reduce
 from pathlib import Path
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Optional, Tuple
 
 import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
-import PIL
 from huggingface_hub import snapshot_download
-from PIL import Image
-from transformers import AutoConfig, Idefics2Config
-from transformers.image_transforms import (
-    convert_to_rgb,
-    normalize,
-    rescale,
-    resize,
-    to_channel_dimension_format,
-)
-from transformers.image_utils import to_numpy_array
-
-from ..base import BaseImageProcessor
+from transformers import AutoConfig
+
 
 from .language import LanguageModel, TextConfig
 from .vision import VisionConfig, VisionModel
@@ -171,59 +159,18 @@ def __init__(self, config: ModelConfig):
     def __call__(self, x: mx.array, mask: Optional[mx.array] = None, cache=None):
-        mask = None
-        # if x.shape[1] > 1:
-        #     mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
-        #     mask = mask.astype(x.dtype)
-        #     mask = mx.expand_dims(mask, axis=0)
-        #     mask = mx.repeat(mask, x.shape[0], axis=0)
-
         if cache is None:
             cache = [None] * len(self.layers)
 
         h = mx.expand_dims(self.latents, axis=0)
         h = mx.repeat(h, x.shape[0], axis=0)
+
         for e, layer in enumerate(self.layers):
             h, cache[e] = layer(h, x, mask=mask, cache=cache[e])
 
         return self.norm(h), cache
 
 
-class ImageProcessor(BaseImageProcessor):
-    def preprocess(self, images):
-        if isinstance(images, Image.Image):
-            images = [images]
-        else:
-            assert isinstance(images, list)
-
-        transforms = [
-            convert_to_rgb,
-            to_numpy_array,
-            partial(
-                resize,
-                size=self.size,
-                resample=self.resample,
-                data_format=self.data_format,
-            ),
-            partial(rescale, scale=self.rescale_factor, data_format=self.data_format),
-            partial(
-                normalize,
-                mean=self.image_mean,
-                std=self.image_std,
-                data_format=self.data_format,
-            ),
-            partial(
-                to_channel_dimension_format,
-                channel_dim=self.data_format,
-                input_channel_dim=self.data_format,
-            ),
-        ]
-
-        images = reduce(lambda x, f: [*map(f, x)], transforms, images)
-
-        return images
-
-
 class MLP(nn.Module):
     def __init__(self, dim, hidden_dim, output_size):
         super().__init__()
@@ -246,9 +193,9 @@ def __init__(self, config: ModelConfig):
         self.perceiver_resampler = Idefics2PerceiverResampler(config)
 
-    def __call__(self, x: mx.array) -> mx.array:
+    def __call__(self, x: mx.array, mask=None) -> mx.array:
         x = self.modality_projection(x)
-        x = self.perceiver_resampler(x)
+        x = self.perceiver_resampler(x, mask=mask)
         return x
 
 
@@ -258,25 +205,26 @@ def __init__(self, config: ModelConfig):
         self.config = config
         self.vision_model = VisionModel(config.vision_config)
-        self.text_model = LanguageModel(config.text_config)
+        self.language_model = LanguageModel(config.text_config)
         self.connector = Idefics2Connector(config)
 
     def get_input_embeddings(
        self,
         input_ids: Optional[mx.array] = None,
         pixel_values: Optional[mx.array] = None,
+        pixel_attention_mask: Optional[mx.array] = None,
     ):
         if pixel_values is None:
-            return self.text_model(input_ids)
+            return self.language_model(input_ids)
 
-        inputs_embeds = self.text_model.embed_tokens(input_ids)
+        inputs_embeds = self.language_model.embed_tokens(input_ids)
 
         *_, hidden_state = self.vision_model(
-            pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True
+            pixel_values[0].transpose(0, 2, 3, 1), output_hidden_states=True
         )
 
         image_features = hidden_state[-1].astype(pixel_values.dtype)
-        image_features, _ = self.connector(image_features)
+        image_features, _ = self.connector(image_features, mask=None)
 
         final_inputs_embeds = self._prepare_inputs_for_multimodal(
             image_features, inputs_embeds, input_ids
@@ -302,9 +250,9 @@ def _prepare_inputs_for_multimodal(self, image_features, inputs_embeds, input_id
         return inputs_embeds
 
     def __call__(self, input_ids: mx.array, pixel_values: mx.array, cache=None):
         input_embeddings = self.get_input_embeddings(input_ids, pixel_values)
-        logits, cache = self.text_model(
+        logits, cache = self.language_model(
             inputs=input_ids, cache=cache, inputs_embeds=input_embeddings
         )
         return logits, cache
 
@@ -359,11 +307,17 @@ def sanitize(self, weights):
             (
                 f"{k.split('.', 1)[1]}"
                 if re.match(r"^model\.", k)
-                else (f"text_model.{k}" if re.match(r"^lm_head\.", k) else k)
+                else (f"language_model.{k}" if re.match(r"^lm_head\.", k) else k)
             ): v
             for k, v in weights.items()
        }
-        return weights
-
+        weights = {
+            (
+                f"language_model.{k.split('.', 1)[1]}"
+                if re.match(r"^text_model\.", k)
+                else k
+            ): v
+            for k, v in weights.items()
+        }
 
-# Create a image classifier using torch
+        return weights
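
Note on the sanitize() change above: checkpoint keys are now remapped in two
passes, first stripping the "model." prefix and routing "lm_head.*" under
"language_model.", then renaming any remaining "text_model.*" keys to
"language_model.*". A minimal, self-contained sketch of that remapping logic
(plain Python, no MLX required; the sample keys below are illustrative, not
taken from a real checkpoint):

    import re

    def sanitize(weights):
        # Pass 1: strip a leading "model." prefix; move "lm_head.*"
        # under "language_model.".
        weights = {
            (
                k.split(".", 1)[1]
                if re.match(r"^model\.", k)
                else (f"language_model.{k}" if re.match(r"^lm_head\.", k) else k)
            ): v
            for k, v in weights.items()
        }
        # Pass 2: rename any remaining "text_model.*" keys to "language_model.*".
        weights = {
            (
                f"language_model.{k.split('.', 1)[1]}"
                if re.match(r"^text_model\.", k)
                else k
            ): v
            for k, v in weights.items()
        }
        return weights

    print(sanitize({
        "model.text_model.embed_tokens.weight": 0,
        "lm_head.weight": 1,
        "model.vision_model.embeddings.weight": 2,
    }))
    # {'language_model.embed_tokens.weight': 0,
    #  'language_model.lm_head.weight': 1,
    #  'vision_model.embeddings.weight': 2}

The two passes keep the HF-checkpoint layout loadable while the module tree
uses the new language_model attribute name introduced in this patch.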