Skip to content

Commit

Permalink
remove image processor and rename text_model to LM
Browse files Browse the repository at this point in the history
  • Loading branch information
Blaizzy committed Apr 30, 2024
1 parent ab63836 commit 57c5b75
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 72 deletions.
1 change: 0 additions & 1 deletion mlx_vlm/models/idefics2/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from .idefics2 import (
ImageProcessor,
LanguageModel,
Model,
ModelConfig,
Expand Down
99 changes: 28 additions & 71 deletions mlx_vlm/models/idefics2/idefics2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,15 @@
import json
import re
from dataclasses import dataclass
from functools import partial, reduce
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from typing import Optional, Tuple

import mlx.core as mx
import mlx.nn as nn
import numpy as np
import PIL
from huggingface_hub import snapshot_download
from PIL import Image
from transformers import AutoConfig, Idefics2Config
from transformers.image_transforms import (
convert_to_rgb,
normalize,
rescale,
resize,
to_channel_dimension_format,
)
from transformers.image_utils import to_numpy_array

from ..base import BaseImageProcessor
from transformers import AutoConfig

from .language import LanguageModel, TextConfig
from .vision import VisionConfig, VisionModel

Expand Down Expand Up @@ -171,59 +159,18 @@ def __init__(self, config: ModelConfig):

def __call__(self, x: mx.array, mask: Optional[mx.array] = None, cache=None):

mask = None
# if x.shape[1] > 1:
# mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
# mask = mask.astype(x.dtype)
# mask = mx.expand_dims(mask, axis=0)
# mask = mx.repeat(mask, x.shape[0], axis=0)

if cache is None:
cache = [None] * len(self.layers)

h = mx.expand_dims(self.latents, axis=0)
h = mx.repeat(h, x.shape[0], axis=0)

for e, layer in enumerate(self.layers):
h, cache[e] = layer(h, x, mask=mask, cache=cache[e])

return self.norm(h), cache


class ImageProcessor(BaseImageProcessor):
def preprocess(self, images):
if isinstance(images, Image.Image):
images = [images]
else:
assert isinstance(images, list)

transforms = [
convert_to_rgb,
to_numpy_array,
partial(
resize,
size=self.size,
resample=self.resample,
data_format=self.data_format,
),
partial(rescale, scale=self.rescale_factor, data_format=self.data_format),
partial(
normalize,
mean=self.image_mean,
std=self.image_std,
data_format=self.data_format,
),
partial(
to_channel_dimension_format,
channel_dim=self.data_format,
input_channel_dim=self.data_format,
),
]

images = reduce(lambda x, f: [*map(f, x)], transforms, images)

return images


class MLP(nn.Module):
def __init__(self, dim, hidden_dim, output_size):
super().__init__()
Expand All @@ -246,9 +193,9 @@ def __init__(self, config: ModelConfig):

self.perceiver_resampler = Idefics2PerceiverResampler(config)

def __call__(self, x: mx.array) -> mx.array:
def __call__(self, x: mx.array, mask=None) -> mx.array:
x = self.modality_projection(x)
x = self.perceiver_resampler(x)
x = self.perceiver_resampler(x, mask=mask)
return x


Expand All @@ -258,25 +205,26 @@ def __init__(self, config: ModelConfig):
self.config = config

self.vision_model = VisionModel(config.vision_config)
self.text_model = LanguageModel(config.text_config)
self.language_model = LanguageModel(config.text_config)
self.connector = Idefics2Connector(config)

def get_input_embeddings(
self,
input_ids: Optional[mx.array] = None,
pixel_values: Optional[mx.array] = None,
pixel_attention_mask: Optional[mx.array] = None,
):
if pixel_values is None:
return self.text_model(input_ids)
return self.language_model(input_ids)

inputs_embeds = self.text_model.embed_tokens(input_ids)
inputs_embeds = self.language_model.embed_tokens(input_ids)
*_, hidden_state = self.vision_model(
pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True
pixel_values[0].transpose(0, 2, 3, 1), output_hidden_states=True
)

image_features = hidden_state[-1].astype(pixel_values.dtype)

image_features, _ = self.connector(image_features)
image_features, _ = self.connector(image_features, mask=None)

final_inputs_embeds = self._prepare_inputs_for_multimodal(
image_features, inputs_embeds, input_ids
Expand All @@ -302,9 +250,9 @@ def _prepare_inputs_for_multimodal(self, image_features, inputs_embeds, input_id
return inputs_embeds

def __call__(self, input_ids: mx.array, pixel_values: mx.array, cache=None):
input_embeddings = self.get_input_embeddings(input_ids, pixel_values)
logits, cache = self.text_model(
inputs=input_ids, cache=cache, inputs_embeds=input_embeddings
# input_embeddings = self.get_input_embeddings(input_ids, pixel_values)
logits, cache = self.language_model(
inputs=input_ids, cache=cache, inputs_embeds=None
)
return logits, cache

Expand Down Expand Up @@ -359,12 +307,21 @@ def sanitize(self, weights):
(
f"{k.split('.', 1)[1]}"
if re.match(r"^model\.", k)
else (f"text_model.{k}" if re.match(r"^lm_head\.", k) else k)
else (f"language_model.{k}" if re.match(r"^lm_head\.", k) else k)
): v
for k, v in weights.items()
}

return weights

weights = {
(
f"language_model.{k.split('.', 1)[1]}"
if re.match(
r"^text_model\.",
k,
)
else k
): v
for k, v in weights.items()
}

# Create a image classifier using torch
return weights

0 comments on commit 57c5b75

Please sign in to comment.