add idefics quantisation
Blaizzy committed Apr 30, 2024
1 parent 6fd6af4 commit ab63836
Showing 1 changed file with 17 additions and 8 deletions.
25 changes: 17 additions & 8 deletions mlx_vlm/utils.py
@@ -135,23 +135,31 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module:
         text_config = text_config.to_dict()
         config["vision_config"] = vision_config["vision_config"]
         config["text_config"] = text_config
+    if model_type == "idefics2":
+        config = AutoConfig.from_pretrained(model_path).to_dict()
 
     model_config = model_class.ModelConfig.from_dict(config)
     model_config.vision_config = model_class.VisionConfig.from_dict(
         config["vision_config"]
     )
     model_config.text_config = model_class.TextConfig.from_dict(config["text_config"])
+    model_config.perceiver_config = model_class.PerceiverConfig.from_dict(
+        config["perceiver_config"]
+    )
     model = model_class.Model(model_config)
 
     if hasattr(model, "sanitize"):
         weights = model.sanitize(weights)
 
-    weights = model_class.VisionModel(model_config.vision_config).sanitize(
-        weights=weights
-    )
-    weights = model_class.LanguageModel(model_config.text_config).sanitize(
-        weights=weights
-    )
+    if hasattr(model_class.VisionModel, "sanitize"):
+        weights = model_class.VisionModel(model_config.vision_config).sanitize(
+            weights=weights
+        )
+
+    if hasattr(model_class.LanguageModel, "sanitize"):
+        weights = model_class.LanguageModel(model_config.text_config).sanitize(
+            weights=weights
+        )
 
     if (quantization := config.get("quantization", None)) is not None:
         # Handle legacy models which may not have everything quantized
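
For reference, a rough sketch of what the new idefics2 branch relies on (not part of the commit; the checkpoint id below is only an example): AutoConfig.from_pretrained(...).to_dict() returns the full Hugging Face config as plain nested dictionaries, so the vision_config, text_config and the newly wired perceiver_config sections can be handed straight to the from_dict constructors above.

from transformers import AutoConfig

# Example Idefics2 checkpoint; any Idefics2 repo should expose the same nested sections.
config = AutoConfig.from_pretrained("HuggingFaceM4/idefics2-8b").to_dict()

print(config["model_type"])  # expected: "idefics2"
# Nested sections consumed by ModelConfig / VisionConfig / TextConfig / PerceiverConfig:
print(sorted(k for k in config if k.endswith("_config")))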
@@ -447,7 +455,7 @@ def quantize_model(
     quantized_config = copy.deepcopy(config)
     vision_intermediate_size = model.config.vision_config.intermediate_size
     class_predicate = lambda path, m: isinstance(m, nn.Linear) and (
-        path.split(".")[0] != "vision_tower"
+        path.split(".")[0] not in ["vision_model", "vision_tower"]
         if any(vision_intermediate_size % size != 0 for size in [64, 128])
         else not isinstance(m, nn.Embedding)
     )
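
For context, a sketch of how a predicate like this is consumed (not part of the commit; the helper name and the example width are illustrative, and the "vision_model" prefix is presumably the path Idefics2 uses where LLaVA-style models use "vision_tower"): MLX's nn.quantize walks the module tree and only quantizes sub-modules for which class_predicate returns True, so vision layers stay in full precision whenever their intermediate size is not a multiple of the supported group sizes.

import mlx.nn as nn

def make_class_predicate(vision_intermediate_size):
    # Mirrors the lambda above: only nn.Linear layers are candidates, and the
    # vision tower is excluded when its width does not divide evenly into the
    # quantization group sizes.
    skip_vision = any(vision_intermediate_size % size != 0 for size in [64, 128])

    def predicate(path, module):
        if not isinstance(module, nn.Linear):
            return False
        if skip_vision:
            return path.split(".")[0] not in ["vision_model", "vision_tower"]
        return not isinstance(module, nn.Embedding)

    return predicate

# Illustrative usage: 4304 is not divisible by 64 or 128, so the vision tower is skipped.
# nn.quantize(model, group_size=64, bits=4, class_predicate=make_class_predicate(4304))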
@@ -590,6 +598,8 @@ def load_image(image_source):
 
 
 def prepare_inputs(image_processor, processor, image, prompt):
+    from transformers.image_utils import load_image
+
     if isinstance(image, str):
         image = load_image(image)
 
@@ -602,7 +612,6 @@ def prepare_inputs(image_processor, processor, image, prompt):
     inputs = processor(prompt, image, return_tensors="np")
     pixel_values = mx.array(inputs["pixel_values"])
     input_ids = mx.array(inputs["input_ids"])
-
     return input_ids, pixel_values
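
One note on the new function-local import (the URL below is just a placeholder example): transformers.image_utils.load_image accepts a URL, a local file path, or an already-loaded PIL image and returns a PIL.Image, which is what allows prepare_inputs to take image as a plain string.

from transformers.image_utils import load_image

# Placeholder image URL; a local path or a PIL.Image.Image works the same way.
image = load_image("http://images.cocodataset.org/val2017/000000039769.jpg")
print(image.size)  # width/height of the decoded PIL image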

