diff --git a/.gitignore b/.gitignore index 405a7e8..0e197bc 100644 --- a/.gitignore +++ b/.gitignore @@ -20,3 +20,4 @@ var/ .installed.cfg *.egg .DS_Store +*.log diff --git a/README.md b/README.md index c4c271d..d411678 100644 --- a/README.md +++ b/README.md @@ -1,43 +1,119 @@ # MLX-VLM -MLX-VLM a package for running Vision LLMs on your Mac using MLX. +MLX-VLM is a package for inference and fine-tuning of Vision Language Models (VLMs) on your Mac using MLX. +## Table of Contents +- [Installation](#installation) +- [Usage](#usage) + - [Command Line Interface (CLI)](#command-line-interface-cli) + - [Chat UI with Gradio](#chat-ui-with-gradio) + - [Python Script](#python-script) +- [Multi-Image Chat Support](#multi-image-chat-support) + - [Supported Models](#supported-models) + - [Usage Examples](#usage-examples) +- [Fine-tuning](#fine-tuning) -## Get started +## Installation -The easiest way to get started is to install the `mlx-vlm` package: - -**With `pip`**: +The easiest way to get started is to install the `mlx-vlm` package using pip: ```sh pip install mlx-vlm ``` -## Inference +## Usage + +### Command Line Interface (CLI) + +Generate output from a model using the CLI: -**CLI** ```sh -python -m mlx_vlm.generate --model qnguyen3/nanoLLaVA --max-tokens 100 --temp 0.0 +python -m mlx_vlm.generate --model mlx-community/Qwen2-VL-2B-Instruct-4bit --max-tokens 100 --temp 0.0 --image http://images.cocodataset.org/val2017/000000039769.jpg ``` -**Chat UI with Gradio** +### Chat UI with Gradio + +Launch a chat interface using Gradio: + ```sh -python -m mlx_vlm.chat_ui --model qnguyen3/nanoLLaVA +python -m mlx_vlm.chat_ui --model mlx-community/Qwen2-VL-2B-Instruct-4bit ``` -**Script** +### Python Script + +Here's an example of how to use MLX-VLM in a Python script: + ```python import mlx.core as mx from mlx_vlm import load, generate +from mlx_vlm.prompt_utils import apply_chat_template + +# Load the model +model_path = "mlx-community/Qwen2-VL-2B-Instruct-4bit" +model, processor = load(model_path) + +# Prepare input +image = ["http://images.cocodataset.org/val2017/000000039769.jpg"] +prompt = "Describe this image." + +# Apply chat template +formatted_prompt = apply_chat_template( + processor, config, prompt, num_images=len(image) +) + +# Generate output +output = generate(model, processor, image, formatted_prompt, verbose=False) +print(output) +``` + +## Multi-Image Chat Support + +MLX-VLM supports analyzing multiple images simultaneously with select models. This feature enables more complex visual reasoning tasks and comprehensive analysis across multiple images in a single conversation. + +### Supported Models + +The following models support multi-image chat: + +1. Idefics 2 +2. LLaVA (Interleave) +3. Qwen2-VL +4. Phi3-Vision +5. Pixtral + +### Usage Examples + +#### Python Script + +```python +from mlx_vlm import load, generate +from mlx_vlm.prompt_utils import apply_chat_template -model_path = "mlx-community/llava-1.5-7b-4bit" +model_path = "mlx-community/Qwen2-VL-2B-Instruct-4bit" model, processor = load(model_path) -prompt = processor.tokenizer.apply_chat_template( - [{"role": "user", "content": f"\nWhat are these?"}], - tokenize=False, - add_generation_prompt=True, +images = ["path/to/image1.jpg", "path/to/image2.jpg"] +prompt = "Compare these two images." 
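# `config` is used below but never created in this snippet; a minimal way to obtain it
# (assuming `load_config` and `get_model_path` from `mlx_vlm.utils`, the same helpers
# `mlx_vlm/generate.py` uses elsewhere in this patch) would be:
# from mlx_vlm.utils import get_model_path, load_config
config = load_config(get_model_path(model_path))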
+ +formatted_prompt = apply_chat_template( + processor, config, prompt, num_images=len(images) ) -output = generate(model, processor, "http://images.cocodataset.org/val2017/000000039769.jpg", prompt, verbose=False) +output = generate(model, processor, images, formatted_prompt, verbose=False) +print(output) ``` + +#### Command Line + +```sh +python -m mlx_vlm.generate --model mlx-community/Qwen2-VL-2B-Instruct-4bit --max-tokens 100 --prompt "Compare these images" --image path/to/image1.jpg path/to/image2.jpg +``` + +These examples demonstrate how to use multiple images with MLX-VLM for more complex visual reasoning tasks. + +# Fine-tuning + +MLX-VLM supports fine-tuning models with LoRA and QLoRA. + +## LoRA & QLoRA + +To learn more about LoRA, please refer to the [LoRA.md](./mlx_vlm/LoRA.md) file. diff --git a/mlx_vlm/LORA.MD b/mlx_vlm/LORA.MD new file mode 100644 index 0000000..4833846 --- /dev/null +++ b/mlx_vlm/LORA.MD @@ -0,0 +1,77 @@ +# LoRA Training Script + +## Overview + +`lora.py` is a Python script for fine-tuning a vision language models (VLMs) using Low-Rank Adaptation (LoRA or QLoRA). This script allows you to train the model on your custom dataset, adjusting various parameters through command-line arguments. + +## Requirements + +- Python 3.7+ +- Required Python packages: `mlx-vlm`, `numpy`, `transformers`, `datasets`, `PIL` + +## Supported Models +- Qwen2 +- LLaVA (except for LLaVA-Next) +- Pixtral +- Idefics 2 +- Deepseek-VL +- Paligemma + +## Coming Soon +- LLaVA-Next +- Phi3_vision + +## Usage + +To use the script, run it from the command line with the desired arguments: + +``` +python lora.py --dataset /path/to/your/dataset [other options] +``` + +## Dataset format + +The dataset should be a Hugging Face dataset with a `images` column and a `messages` column. + +``` +{ + "images": ..., + "messages": ..., +} +``` + +Support for other formats and column names will be added soon. + +## Arguments + +The script accepts the following command-line arguments: + +- `--model_path`: Path to the pre-trained model (default: "mlx-community/Qwen2-VL-2B-Instruct-bf16") +- `--dataset`: Path to your dataset (required) +- `--learning_rate`: Learning rate for the optimizer (default: 1e-4) +- `--batch_size`: Batch size for training (default: 1) +- `--epochs`: Number of epochs to train (default: 1) +- `--steps`: Number of steps per epoch (default: 0) +- `--print_every`: Print loss every n steps (default: 10) +- `--output_path`: Path to save the trained adapter (default: "adapters.safetensors") + +## Example + +Here's an example of how to run the script with custom parameters: + +``` +python lora.py --dataset /path/to/your/dataset --model_path /path/to/your/model --epochs 2 --batch_size 4 --learning_rate 5e-5 +``` + +## Output + +The script will print the training loss at regular intervals (defined by `--print_every`). After training, it will save the LoRA adapter to the specified output path. + +## Note + +If you want to use QLoRA, you need to pass a pre-quantized model to the script using the `--model_path` argument (i.e. `mlx-community/Qwen2-VL-2B-Instruct-4bit`). +Make sure you have the necessary permissions to read the dataset and write the output file. Also, ensure that your system has sufficient computational resources to handle the specified batch size and model. + +## Contributing + +Feel free to submit issues or pull requests if you find any bugs or have suggestions for improvements. 
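For reference, the README snippets above call `apply_chat_template(processor, config, prompt, ...)` but never construct `config`. Below is a self-contained sketch of the intended flow, assuming `load_config` and `get_model_path` from `mlx_vlm.utils` (the helpers `mlx_vlm/generate.py` uses in this patch):

```python
from mlx_vlm import generate, load
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import get_model_path, load_config

model_path = "mlx-community/Qwen2-VL-2B-Instruct-4bit"

# Load the model/processor and the model config consumed by apply_chat_template
model, processor = load(model_path)
config = load_config(get_model_path(model_path))

# One or more images; Qwen2-VL is among the models that support multi-image prompts
images = ["http://images.cocodataset.org/val2017/000000039769.jpg"]
prompt = "Describe this image."

# Insert the right number of image tokens and apply the chat template
formatted_prompt = apply_chat_template(
    processor, config, prompt, num_images=len(images)
)

output = generate(model, processor, images, formatted_prompt, verbose=False)
print(output)
```

LORA.MD above describes the expected dataset as a Hugging Face dataset with `images` and `messages` columns. Here is a hedged sketch of building such a dataset; the image paths and conversations are illustrative placeholders, and pushing to the Hub (or any layout `datasets.load_dataset` can read) is one way to make it loadable by `lora.py`:

```python
from datasets import Dataset

# Illustrative rows only: one image per sample plus a chat-style conversation.
data = {
    "images": ["path/to/cat.jpg", "path/to/dog.jpg"],
    "messages": [
        [
            {"role": "user", "content": "What animal is in this image?"},
            {"role": "assistant", "content": "A cat sitting on a sofa."},
        ],
        [
            {"role": "user", "content": "What animal is in this image?"},
            {"role": "assistant", "content": "A dog playing in a park."},
        ],
    ],
}

dataset = Dataset.from_dict(data)
# One option: push to the Hub so lora.py can fetch it via datasets.load_dataset
# dataset.push_to_hub("your-username/my-vlm-dataset")
```

Note that the script itself defines its flags with hyphens (`--model-path`, `--batch-size`, `--learning-rate`, `--output-path`), so the underscored spellings shown in LORA.MD should be read accordingly.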
diff --git a/mlx_vlm/__init__.py b/mlx_vlm/__init__.py index 03cb6f1..63e9873 100644 --- a/mlx_vlm/__init__.py +++ b/mlx_vlm/__init__.py @@ -1,2 +1,3 @@ -from .utils import convert, generate, load +from .prompt_utils import apply_chat_template, get_message_json +from .utils import convert, generate, load, prepare_inputs from .version import __version__ diff --git a/mlx_vlm/generate.py b/mlx_vlm/generate.py index 3e95cc2..3e85b30 100644 --- a/mlx_vlm/generate.py +++ b/mlx_vlm/generate.py @@ -1,13 +1,11 @@ import argparse import codecs -import mlx.core as mx - from .prompt_utils import apply_chat_template from .utils import generate, get_model_path, load, load_config, load_image_processor DEFAULT_MODEL_PATH = "mlx-community/nanoLLaVA-1.5-8bit" -DEFAULT_IMAGE = "http://images.cocodataset.org/val2017/000000039769.jpg" +DEFAULT_IMAGE = ["http://images.cocodataset.org/val2017/000000039769.jpg"] DEFAULT_PROMPT = "What are these?" DEFAULT_MAX_TOKENS = 100 DEFAULT_TEMP = 0.5 @@ -25,9 +23,16 @@ def parse_arguments(): default=DEFAULT_MODEL_PATH, help="The path to the local model directory or Hugging Face repo.", ) + parser.add_argument( + "--adapter-path", + type=str, + default=None, + help="The path to the adapter weights.", + ) parser.add_argument( "--image", type=str, + nargs="+", default=DEFAULT_IMAGE, help="URL or path of the image to process.", ) @@ -50,22 +55,28 @@ def parse_arguments(): return parser.parse_args() -def get_model_and_processors(model_path): +def get_model_and_processors(model_path, adapter_path): model_path = get_model_path(model_path) config = load_config(model_path) - model, processor = load(model_path, {"trust_remote_code": True}) + model, processor = load( + model_path, {"trust_remote_code": True}, adapter_path=adapter_path + ) image_processor = load_image_processor(model_path) return model, processor, image_processor, config def main(): args = parse_arguments() - model, processor, image_processor, config = get_model_and_processors(args.model) + if isinstance(args.image, str): + args.image = [args.image] + + model, processor, image_processor, config = get_model_and_processors( + args.model, args.adapter_path + ) prompt = codecs.decode(args.prompt, "unicode_escape") - if model.config.model_type != "paligemma": - prompt = apply_chat_template(processor, config, prompt) + prompt = apply_chat_template(processor, config, prompt, num_images=len(args.image)) output = generate( model, diff --git a/mlx_vlm/lora.py b/mlx_vlm/lora.py new file mode 100644 index 0000000..99822ed --- /dev/null +++ b/mlx_vlm/lora.py @@ -0,0 +1,169 @@ +import argparse +import json +import logging + +import mlx.optimizers as optim +from datasets import load_dataset +from tqdm import tqdm + +from .prompt_utils import apply_chat_template +from .trainer import Dataset, Trainer, save_adapter +from .trainer.utils import find_all_linear_names, get_peft_model +from .utils import load, load_image_processor + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def custom_print(*args, **kwargs): + tqdm.write(" ".join(map(str, args)), **kwargs) + + +def main(args): + logger.info(f"\033[32mLoading model from {args.model_path}\033[0m") + model, processor = load( + args.model_path, processor_config={"trust_remote_code": True} + ) + config = model.config.__dict__ + image_processor = load_image_processor(args.model_path) + + logger.info(f"\033[32mLoading dataset from {args.dataset}\033[0m") + dataset = load_dataset(args.dataset, split=args.split) + + if "messages" not in dataset.column_names: 
+ raise ValueError("Dataset must have a 'messages' column") + if "images" not in dataset.column_names: + raise ValueError("Dataset must have an 'images' column") + + if args.apply_chat_template: + logger.info(f"\033[32mApplying chat template to the dataset\033[0m") + + def process_data(examples): + if config["model_type"] == "pixtral": + conversations = apply_chat_template( + config=config, + processor=processor, + prompt=examples["messages"], + return_messages=True, + ) + examples["messages"] = [ + json.dumps(item, ensure_ascii=False) for item in conversations + ] + else: + examples["messages"] = apply_chat_template( + config=config, + processor=processor, + prompt=examples["messages"], + return_messages=True, + ) + return examples + + dataset = dataset.map(process_data) + + dataset = Dataset( + dataset, + config, + processor, + image_processor=image_processor, + ) + + logger.info(f"\033[32mSetting up LoRA\033[0m") + list_of_modules = find_all_linear_names(model.language_model) + model = get_peft_model( + model, + list_of_modules, + rank=args.lora_rank, + alpha=args.lora_alpha, + dropout=args.lora_dropout, + ) + + logger.info(f"\033[32mSetting up optimizer\033[0m") + optimizer = optim.Adam(learning_rate=args.learning_rate) + + logger.info(f"\033[32mSetting up trainer\033[0m") + trainer = Trainer(model, optimizer) + + model.train() + + # Training loop + logger.info(f"\033[32mTraining model\033[0m") + for epoch in range(args.epochs): + if args.steps == 0: + args.steps = len(dataset) // args.batch_size + + progress_bar = tqdm(range(args.steps), position=0, leave=True) + for i in progress_bar: + loss = trainer.train_step( + dataset[i * args.batch_size : (i + 1) * args.batch_size] + ) + # Update progress bar + progress_bar.update(1) + progress_bar.set_postfix( + {"Epoch": epoch, "Step": i, "Loss": f"{loss.item():.4f}"} + ) + + if i % args.print_every == 0: + # Log additional information + custom_print( + { + "Epoch": epoch, + "Step": i, + "Loss": f"{loss.item():.4f}", + } + ) + + # Save the adapter + save_adapter(model, args.output_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Train NanoLLaVA model") + parser.add_argument( + "--model-path", + type=str, + default="mlx-community/Qwen2-VL-2B-Instruct-bf16", + help="Path to the pre-trained model", + ) + parser.add_argument( + "--dataset", type=str, required=True, help="Path to the dataset" + ) + parser.add_argument( + "--split", type=str, default="train", help="Split to use for training" + ) + parser.add_argument( + "--apply-chat-template", + action="store_false", + help="Apply chat template to the dataset", + ) + parser.add_argument( + "--learning-rate", + type=float, + default=1e-4, + help="Learning rate for the optimizer", + ) + parser.add_argument( + "--batch-size", type=int, default=1, help="Batch size for training" + ) + parser.add_argument( + "--epochs", type=int, default=1, help="Number of epochs to train" + ) + parser.add_argument( + "--steps", type=int, default=0, help="Number of steps per epoch" + ) + parser.add_argument( + "--print-every", type=int, default=10, help="Print loss every n steps" + ) + parser.add_argument( + "--lora-alpha", type=int, default=0.1, help="LoRA alpha parameter" + ) + parser.add_argument("--lora-rank", type=int, default=10, help="LoRA rank") + parser.add_argument("--lora-dropout", type=float, default=0.1, help="LoRA dropout") + parser.add_argument( + "--output-path", + type=str, + default="adapters", + help="Path to save the trained adapter", + ) + + args = 
parser.parse_args() + main(args) diff --git a/mlx_vlm/models/idefics2/idefics2.py b/mlx_vlm/models/idefics2/idefics2.py index 1c78365..52085dd 100644 --- a/mlx_vlm/models/idefics2/idefics2.py +++ b/mlx_vlm/models/idefics2/idefics2.py @@ -197,6 +197,7 @@ def __call__(self, x: mx.array, mask=None) -> mx.array: class Model(nn.Module): def __init__(self, config: ModelConfig): + super().__init__() self.model_type = config.model_type self.config = config diff --git a/mlx_vlm/models/llava/language.py b/mlx_vlm/models/llava/language.py index 732b636..a7f11b4 100644 --- a/mlx_vlm/models/llava/language.py +++ b/mlx_vlm/models/llava/language.py @@ -21,6 +21,7 @@ class TextConfig: rope_theta: float = 10000 rope_traditional: bool = False rope_scaling: Optional[Dict[str, Union[float, str]]] = None + tie_word_embeddings: bool = False @classmethod def from_dict(cls, params): @@ -58,9 +59,14 @@ def __init__(self, config: TextConfig): head_dim = config.hidden_size // n_heads self.scale = head_dim**-0.5 - self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) - self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) - self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + if config.model_type == "qwen2": + attention_bias = True + else: + attention_bias = False + + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) rope_scale = ( @@ -184,12 +190,13 @@ def __init__(self, config: TextConfig): super().__init__() self.config = config self.model_type = config.model_type - if self.model_type != "llama": + if self.model_type not in ["llama", "qwen2"]: raise ValueError( f"Model type {self.model_type} not supported. 
Currently only 'llama' is supported" ) self.model = Llama(config) - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + if not config.tie_word_embeddings: + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) def __call__( self, @@ -199,7 +206,11 @@ def __call__( mask: Optional[mx.array] = None, ): out = self.model(inputs, cache, inputs_embeds) - return self.lm_head(out) + if self.config.tie_word_embeddings: + out = self.model.embed_tokens.as_linear(out) + else: + out = self.lm_head(out) + return out @staticmethod def sanitize(weights): diff --git a/mlx_vlm/models/llava/llava.py b/mlx_vlm/models/llava/llava.py index 39aae4a..39298a9 100644 --- a/mlx_vlm/models/llava/llava.py +++ b/mlx_vlm/models/llava/llava.py @@ -56,6 +56,7 @@ def __call__(self, x: mx.array) -> mx.array: class Model(nn.Module): def __init__(self, config: ModelConfig): + super().__init__() self.config = config self.vision_tower = VisionModel(config.vision_config) self.language_model = LanguageModel(config.text_config) diff --git a/mlx_vlm/models/llava/vision.py b/mlx_vlm/models/llava/vision.py index 5a5ec42..31c2734 100644 --- a/mlx_vlm/models/llava/vision.py +++ b/mlx_vlm/models/llava/vision.py @@ -151,31 +151,44 @@ def __init__(self, config: VisionConfig): self.image_size = config.image_size self.patch_size = config.patch_size - self.class_embedding = mx.zeros((config.hidden_size,)) + if config.model_type == "siglip_vision_model": + bias = True + self.class_embedding = None + else: + bias = False + self.class_embedding = mx.zeros((config.hidden_size,)) self.patch_embedding = nn.Conv2d( in_channels=config.num_channels, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, - bias=False, + bias=bias, ) self.num_patches = (self.image_size // self.patch_size) ** 2 - self.num_positions = self.num_patches + 1 + self.num_positions = ( + self.num_patches + 1 + if config.model_type == "clip_vision_model" + else self.num_patches + ) self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) def __call__(self, x: mx.array) -> mx.array: batch_size = x.shape[0] patch_embeddings = self.patch_embedding(x) patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2) - embed_dim = patch_embeddings.shape[-1] - cls_embeddings = mx.broadcast_to( - self.class_embedding, (batch_size, 1, embed_dim) - ) + if self.config.model_type == "siglip_vision_model": + embeddings = patch_embeddings + else: + embed_dim = patch_embeddings.shape[-1] + cls_embeddings = mx.broadcast_to( + self.class_embedding, (batch_size, 1, embed_dim) + ) + embeddings = mx.concatenate((cls_embeddings, patch_embeddings), axis=1) + position_ids = mx.array(np.arange(self.num_positions)[None, :]) - embeddings = mx.concatenate((cls_embeddings, patch_embeddings), axis=1) embeddings += self.position_embedding(position_ids) return embeddings @@ -183,8 +196,10 @@ def __call__(self, x: mx.array) -> mx.array: class ClipVisionModel(nn.Module): def __init__(self, config: VisionConfig): super().__init__() + self.config = config self.embeddings = VisionEmbeddings(config) - self.pre_layrnorm = nn.LayerNorm(config.hidden_size) + if self.config.model_type == "clip_vision_model": + self.pre_layrnorm = nn.LayerNorm(config.hidden_size) self.encoder = Encoder(config) self.post_layernorm = nn.LayerNorm(config.hidden_size) @@ -194,7 +209,8 @@ def __call__( output_hidden_states: Optional[bool] = None, ) -> mx.array: x = self.embeddings(x) - x = self.pre_layrnorm(x) + if self.config.model_type == 
"clip_vision_model": + x = self.pre_layrnorm(x) encoder_states = (x,) if output_hidden_states else None @@ -212,7 +228,7 @@ def __init__(self, config: VisionConfig): super().__init__() self.model_type = config.model_type - if self.model_type != "clip_vision_model": + if self.model_type not in ["clip_vision_model", "siglip_vision_model"]: raise ValueError(f"Unsupported model type: {self.model_type}") self.vision_model = ClipVisionModel(config) diff --git a/mlx_vlm/models/llava_bunny/language.py b/mlx_vlm/models/llava_bunny/language.py index 153a650..a5a4fb0 100644 --- a/mlx_vlm/models/llava_bunny/language.py +++ b/mlx_vlm/models/llava_bunny/language.py @@ -167,6 +167,7 @@ def __call__( inputs: mx.array, cache=None, inputs_embeds: Optional[mx.array] = None, + mask: Optional[mx.array] = None, ): # for passing merged input embeddings if inputs_embeds is None: @@ -199,7 +200,7 @@ def __call__( inputs_embeds: Optional[mx.array] = None, mask: Optional[mx.array] = None, ): - out = self.model(inputs, cache=cache, inputs_embeds=inputs_embeds) + out = self.model(inputs, cache=cache, inputs_embeds=inputs_embeds, mask=None) return out def sanitize(self, weights): diff --git a/mlx_vlm/models/llava_bunny/llava_bunny.py b/mlx_vlm/models/llava_bunny/llava_bunny.py index 0e145a2..bbbb5cd 100644 --- a/mlx_vlm/models/llava_bunny/llava_bunny.py +++ b/mlx_vlm/models/llava_bunny/llava_bunny.py @@ -118,6 +118,7 @@ def __call__( class Model(nn.Module): def __init__(self, config: ModelConfig): + super().__init__() self.model_type = config.model_type self.config = config @@ -151,31 +152,28 @@ def get_input_embeddings( def _prepare_inputs_for_multimodal(self, image_features, inputs_embeds, input_ids): image_token_index = self.config.image_token_index - num_images, num_image_patches, embed_dim = image_features.shape - - # Positions of tokens in input_ids, assuming batch size is 1 - image_positions = np.where(input_ids[0] == image_token_index)[0].tolist() + batch_size, seq_length, embed_dim = inputs_embeds.shape + num_images, num_image_patches, _ = image_features.shape - if len(image_positions) != num_images: - raise ValueError( - f"The number of image tokens ({len(image_positions)}) does not " - f" match the number of image inputs ({num_images})." 
- ) + # Positions of tokens in input_ids for each batch + image_positions = mx.argmax(input_ids == image_token_index, axis=1) - text_segments = [] - start_idx = 0 + final_embeddings = [] + for b in range(batch_size): + text_segments = [] + start_idx = 0 + position = int(image_positions[b].item()) - for position in image_positions: - text_segments.append(inputs_embeds[:, start_idx:position]) - start_idx = position + 1 + text_segments.append(inputs_embeds[b : b + 1, start_idx:position]) + text_segments.append(image_features[b : b + 1]) + text_segments.append(inputs_embeds[b : b + 1, position + 1 :]) - image_embeddings = mx.split(image_features, image_features.shape[0]) - final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p] - final_embeddings += [inputs_embeds[:, start_idx:]] + batch_embeddings = mx.concatenate(text_segments, axis=1) + final_embeddings.append(batch_embeddings) # Create a final embedding of shape - # (1, num_image_patches*num_images + sequence_len, embed_dim) - return mx.concatenate(final_embeddings, axis=1) + # (batch_size, num_image_patches + sequence_len, embed_dim) + return mx.concatenate(final_embeddings, axis=0) def __call__( self, @@ -187,7 +185,7 @@ def __call__( ): input_embeddings = self.get_input_embeddings(input_ids, pixel_values) logits = self.language_model( - inputs=input_ids, cache=cache, inputs_embeds=input_embeddings + inputs=input_ids, cache=cache, inputs_embeds=input_embeddings, mask=mask ) return logits diff --git a/mlx_vlm/models/llava_bunny/vision.py b/mlx_vlm/models/llava_bunny/vision.py index 636cbf7..df3e3c5 100644 --- a/mlx_vlm/models/llava_bunny/vision.py +++ b/mlx_vlm/models/llava_bunny/vision.py @@ -207,9 +207,9 @@ def __call__(self, x: mx.array) -> mx.array: batch_size = x.shape[0] patch_embeddings = self.patch_embedding(x) patch_embeddings = mx.flatten(patch_embeddings, start_axis=1, end_axis=2) - self.position_ids = mx.array(np.arange(self.num_positions)[None, :]) + position_ids = mx.array(np.arange(self.num_positions)[None, :]) embeddings = patch_embeddings - embeddings += self.position_embedding(self.position_ids) + embeddings += self.position_embedding(position_ids) return embeddings diff --git a/mlx_vlm/models/llava_next/llava_next.py b/mlx_vlm/models/llava_next/llava_next.py index 878d7ca..29abea1 100644 --- a/mlx_vlm/models/llava_next/llava_next.py +++ b/mlx_vlm/models/llava_next/llava_next.py @@ -56,6 +56,7 @@ def __call__(self, x: mx.array) -> mx.array: class Model(nn.Module): def __init__(self, config: ModelConfig): + super().__init__() self.config = config self.vision_tower = VisionModel(config.vision_config) self.language_model = LanguageModel(config.text_config) diff --git a/mlx_vlm/models/multi_modality/multi_modality.py b/mlx_vlm/models/multi_modality/multi_modality.py index 52a0bc9..c1d9df9 100644 --- a/mlx_vlm/models/multi_modality/multi_modality.py +++ b/mlx_vlm/models/multi_modality/multi_modality.py @@ -238,6 +238,7 @@ def __call__(self, x: Union[mx.array, Tuple]) -> mx.array: class Model(nn.Module): def __init__(self, config: ModelConfig): + super().__init__() self.config = config self.vision_model = VisionModel(config.vision_config) self.language_model = LanguageModel(config.text_config) diff --git a/mlx_vlm/models/paligemma/paligemma.py b/mlx_vlm/models/paligemma/paligemma.py index 7400738..fabdac5 100644 --- a/mlx_vlm/models/paligemma/paligemma.py +++ b/mlx_vlm/models/paligemma/paligemma.py @@ -52,6 +52,7 @@ def __call__(self, x: mx.array) -> mx.array: class Model(nn.Module): def __init__(self, 
config: ModelConfig): + super().__init__() self.model_type = config.model_type self.config = config diff --git a/mlx_vlm/models/phi3_v/phi3_v.py b/mlx_vlm/models/phi3_v/phi3_v.py index 770f001..01285f5 100644 --- a/mlx_vlm/models/phi3_v/phi3_v.py +++ b/mlx_vlm/models/phi3_v/phi3_v.py @@ -209,9 +209,10 @@ def __call__( pixel_values=None, mask=None, cache=None, + image_sizes=None, **kwargs, ): - out = self.model(inputs, pixel_values, mask, cache) + out = self.model(inputs, pixel_values, image_sizes, cache) return self.lm_head(out).astype(self.lm_head.weight.dtype) @property diff --git a/mlx_vlm/models/pixtral/pixtral.py b/mlx_vlm/models/pixtral/pixtral.py index b49397b..fb51a84 100644 --- a/mlx_vlm/models/pixtral/pixtral.py +++ b/mlx_vlm/models/pixtral/pixtral.py @@ -56,6 +56,7 @@ def __call__(self, x: mx.array) -> mx.array: class Model(nn.Module): def __init__(self, config: ModelConfig): + super().__init__() self.config = config self.vision_tower = VisionModel(config.vision_config) self.language_model = LanguageModel(config.text_config) @@ -74,24 +75,33 @@ def get_input_embeddings( # Get the input embeddings from the language model inputs_embeds = self.language_model.model.embed_tokens(input_ids) + # Get number of images + num_images = len(pixel_values[0]) + # Get the ouptut hidden states from the vision model + if isinstance(pixel_values, list): + if input_ids.shape[0] == 1: # Batch size is 1 + pixel_values = mx.concatenate( + [mx.array(pv) for pv in pixel_values[0]], axis=1 + )[None, ...] + else: # Batch size is greater than 1 + pixel_values = mx.concatenate( + [mx.array(pv) for pv in pixel_values], axis=0 + ) + if pixel_values.ndim == 3: + pixel_values = pixel_values[None, ...] + + pixel_values = mx.split(pixel_values, num_images, axis=2) + + # Pass pixel_values as list of images, as each image is individually run through conv2d and position encoding + # Reference code from transformers: https://github.com/huggingface/transformers/blob/main/src/transformers/models/pixtral/modeling_pixtral.py#L479C9-L479C21 + # and mistral_inference: https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/vision_encoder.py#L85 *_, hidden_states = self.vision_tower( - pixel_values.transpose(0, 2, 3, 1), output_hidden_states=True + [pv.transpose(0, 2, 3, 1) for pv in pixel_values], output_hidden_states=True ) - # Select the hidden states from the desired layer selected_image_feature = hidden_states[self.vision_feature_layer] - if self.vision_feature_select_strategy == "default": - selected_image_feature = selected_image_feature[:, 1:] - elif self.vision_feature_select_strategy == "full": - selected_image_feature = selected_image_feature - else: - raise ValueError( - "Unexpected feature selection strategy: " - f"{self.vision_feature_select_strategy}" - ) - # Pass image features through the multi-modal projector image_features = self.multi_modal_projector(selected_image_feature) @@ -117,7 +127,8 @@ def _merge_input_ids_with_image_features( text_segments.append(inputs_embeds[:, start_idx:position]) start_idx = position + 1 - image_embeddings = mx.split(image_features, image_features.shape[0]) + # Split image features into separate embeddings for each image + image_embeddings = mx.split(image_features, num_image_patches, axis=1) final_embeddings = [v for p in zip(text_segments, image_embeddings) for v in p] final_embeddings += [inputs_embeds[:, start_idx:]] @@ -181,7 +192,7 @@ def from_pretrained(path_or_hf_repo: str): def sanitize(self, weights): def transform_key(key): - if 
"vision_tower" in key: + if "vision_tower" in key and "vision_model" not in key: if "transformer" in key: key = key.replace("vision_tower", "vision_tower.vision_model") if "patch_conv" in key: diff --git a/mlx_vlm/models/pixtral/vision.py b/mlx_vlm/models/pixtral/vision.py index 2db77f4..1f015ba 100644 --- a/mlx_vlm/models/pixtral/vision.py +++ b/mlx_vlm/models/pixtral/vision.py @@ -1,6 +1,6 @@ import inspect from dataclasses import dataclass -from typing import Optional +from typing import List, Optional import mlx.core as mx import mlx.nn as nn @@ -253,11 +253,11 @@ def __init__(self, config: VisionConfig): def __call__( self, - x: mx.array, + x: List[mx.array], output_hidden_states: Optional[bool] = None, ) -> mx.array: - B, H, W, C = x.shape - patch_embeds_list = [self.patch_conv(img[None, :]) for img in x] + B, H, W, C = x[0].shape + patch_embeds_list = [self.patch_conv(img) for img in x] patch_embeds = mx.concatenate( [p.reshape(B, -1, p.shape[-1]) for p in patch_embeds_list], axis=1 @@ -299,7 +299,7 @@ def __init__(self, config: VisionConfig): self.vision_model = PixtralVisionModel(config) def __call__( - self, x: mx.array, output_hidden_states: Optional[bool] = None + self, x: List[mx.array], output_hidden_states: Optional[bool] = None ) -> mx.array: return self.vision_model(x, output_hidden_states) diff --git a/mlx_vlm/models/qwen2_vl/language.py b/mlx_vlm/models/qwen2_vl/language.py index cebcd22..ef2518d 100644 --- a/mlx_vlm/models/qwen2_vl/language.py +++ b/mlx_vlm/models/qwen2_vl/language.py @@ -282,6 +282,10 @@ def __init__(self, args: TextConfig): self.args = args self.model_type = args.model_type self.model = Qwen2Model(args) + + if args.model_type != "qwen2_vl": + raise ValueError(f"Unsupported model type: {args.model_type}") + if not args.tie_word_embeddings: self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) diff --git a/mlx_vlm/models/qwen2_vl/qwen2_vl.py b/mlx_vlm/models/qwen2_vl/qwen2_vl.py index 8746eee..98f8b20 100644 --- a/mlx_vlm/models/qwen2_vl/qwen2_vl.py +++ b/mlx_vlm/models/qwen2_vl/qwen2_vl.py @@ -18,7 +18,6 @@ class ModelConfig: text_config: TextConfig vision_config: VisionConfig - rope_scaling: dict model_type: str ignore_index: int = -100 image_token_index: int = 151655 @@ -39,6 +38,7 @@ def from_dict(cls, params): class Model(nn.Module): def __init__(self, config: ModelConfig): + super().__init__() self.config = config self.vision_tower = VisionModel(config.vision_config) self.language_model = LanguageModel(config.text_config) @@ -61,6 +61,9 @@ def get_input_embeddings( pixel_values, image_grid_thw, output_hidden_states=False ) + if hidden_states.ndim == 2: + hidden_states = hidden_states[None, :, :] + # Insert special image tokens in the input_ids final_inputs_embeds = self._merge_input_ids_with_image_features( hidden_states, inputs_embeds, input_ids diff --git a/mlx_vlm/models/qwen2_vl/vision.py b/mlx_vlm/models/qwen2_vl/vision.py index f07f880..7b78447 100644 --- a/mlx_vlm/models/qwen2_vl/vision.py +++ b/mlx_vlm/models/qwen2_vl/vision.py @@ -19,7 +19,7 @@ class VisionConfig: mlp_ratio: float = 4.0 in_channels: int = 3 layer_norm_eps: float = 1e-6 - spatial_patch_size = 14 + spatial_patch_size: int = 14 spatial_merge_size: int = 2 temporal_patch_size: int = 2 @@ -320,10 +320,22 @@ def __call__( grid_thw: mx.array, output_hidden_states: Optional[bool] = None, ) -> mx.array: + hidden_states = self.patch_embed(hidden_states) rotary_pos_emb = self.rot_pos_emb(grid_thw) - cu_seqlens = mx.repeat(grid_thw[:, 1] * grid_thw[:, 2], 
grid_thw[:, 0]) + # Assuming grid_thw has shape (batch_size, 3) + batch_size = grid_thw.shape[0] + + # Calculate cu_seqlens for each item in the batch + cu_seqlens = [] + for i in range(batch_size): + seq_len = grid_thw[i, 1] * grid_thw[i, 2] + cu_seqlens.append(mx.repeat(seq_len, grid_thw[i, 0])) + + # Concatenate the cu_seqlens for all items in the batch + cu_seqlens = mx.concatenate(cu_seqlens) + cu_seqlens = mx.cumsum(cu_seqlens.astype(mx.int32)) cu_seqlens = mx.pad(cu_seqlens, (1, 0), mode="constant", constant_values=0) diff --git a/mlx_vlm/prompt_utils.py b/mlx_vlm/prompt_utils.py index b47a9a1..1289fed 100644 --- a/mlx_vlm/prompt_utils.py +++ b/mlx_vlm/prompt_utils.py @@ -1,56 +1,151 @@ -def get_message_json(model_name, prompt): +def get_message_json( + model_name, prompt, role="user", skip_image_token=False, num_images=1 +): """ Get the appropriate JSON message based on the specified model. Args: model_name (str): The model for which to generate the message. prompt (str): The text prompt to be included in the message. - *args: Additional positional arguments (unused). - **kwargs: Additional keyword arguments (unused). + role (str): The role of the message (default: "user"). + skip_image_token (bool): Whether to skip adding image tokens (default: False). + num_images (int): Number of image tokens to add (default: 1). Returns: dict: A dictionary representing the JSON message for the specified model. """ - if model_name.lower() in ["idefics2", "qwen2_vl", "llava"]: - message = { - "role": "user", - "content": [{"type": "image"}, {"type": "text", "text": prompt}], - } - - elif model_name.lower() in ["llava-qwen2", "llava_next", "bunny-llama"]: - message = {"role": "user", "content": f"\n{prompt}"} - elif model_name.lower() == "phi3_v": - message = {"role": "user", "content": f"<|image_1|>\n{prompt}"} - elif model_name.lower() == "multi_modality": - message = {"role": "user", "content": f"{prompt}"} - elif model_name.lower() == "paligemma": - message = prompt - elif model_name.lower() == "pixtral": - message = { - "role": "user", - "content": [{"type": "image"}, {"type": "text", "content": prompt}], - } + model_name = model_name.lower() + + def create_message(role, prompt): + return {"role": role, "content": prompt} + + def add_image_tokens(message, token_format): + if role == "user" and not skip_image_token: + if isinstance(message["content"], list): + message["content"].extend([{"type": "image"}] * num_images) + else: + if model_name == "phi3_v": + message["content"] = f"{token_format}{message['content']}" + else: + message["content"] = ( + f"{token_format * num_images}{message['content']}" + ) + if role == "assistant" and model_name == "pixtral": + message["content"] = message["content"][0]["content"] + return message + + message_formats = { + "message_list_with_image": lambda: add_image_tokens( + {"role": role, "content": [{"type": "text", "text": prompt}]}, "" + ), + "message_list_with_image_type": lambda: add_image_tokens( + {"role": role, "content": [{"type": "text", "content": prompt}]}, "" + ), + "message_with_image_token": lambda: add_image_tokens( + create_message(role, prompt), "" + ), + "message_with_image_token_new_line": lambda: add_image_tokens( + create_message(role, prompt), "\n" + ), + "message_with_numbered_image_tokens": lambda: add_image_tokens( + create_message(role, prompt), + " ".join([f"<|image_{i+1}|>" for i in range(num_images)]), + ), + "prompt_only": lambda: prompt, + } + + model_to_format = { + "idefics2": "message_list_with_image", + "qwen2_vl": 
"message_list_with_image", + "llava": "message_list_with_image", + "llava_next": "message_list_with_image", + "llava-qwen2": "message_with_image_token_new_line", + "bunny-llama": "message_with_image_token_new_line", + "phi3_v": "message_with_numbered_image_tokens", + "multi_modality": "message_with_image_token", + "pixtral": "message_list_with_image_type", + "paligemma": "prompt_only", + } + + if num_images > 1 and model_name in [ + "llava_next", + "llava-qwen2", + "bunny-llama", + "paligemma", + "multi_modality", + ]: + raise ValueError( + f"Model {model_name} does not support multi-image chat. Please only use 1 image." + ) + + format_key = model_to_format.get(model_name) + if format_key: + return message_formats[format_key]() else: raise ValueError(f"Unsupported model: {model_name}") - return message +def apply_chat_template( + processor, + config, + prompt, + add_generation_prompt=True, + return_messages=False, + num_images=1, +): + config = config if isinstance(config, dict) else config.__dict__ + + def process_single_prompt(p, is_first=True): + if isinstance(p, str): + return get_message_json( + config["model_type"], + p, + skip_image_token=not is_first, + num_images=num_images, + ) + elif isinstance(p, dict) and "role" in p: + return get_message_json( + config["model_type"], + p["content"], + p["role"], + skip_image_token=not is_first, + num_images=num_images, + ) + else: + raise ValueError("Invalid prompt type") + + messages = [] + if isinstance(prompt, list): + if isinstance(prompt[0], dict): + messages = [process_single_prompt(p, i == 0) for i, p in enumerate(prompt)] + else: + messages = [ + msg + for prompts in prompt + for i, p in enumerate(prompts) + for msg in [process_single_prompt(p, i == 0)] + ] + else: + messages = [process_single_prompt(prompt)] + + if return_messages: + return messages -def apply_chat_template(processor, config, prompt): - message = get_message_json(config["model_type"], prompt) + if config["model_type"] == "paligemma": + return messages[-1] if "chat_template" in processor.__dict__.keys(): return processor.apply_chat_template( - [message], + messages, tokenize=False, - add_generation_prompt=True, + add_generation_prompt=add_generation_prompt, ) elif "tokenizer" in processor.__dict__.keys(): return processor.tokenizer.apply_chat_template( - [message], + messages, tokenize=False, - add_generation_prompt=True, + add_generation_prompt=add_generation_prompt, ) else: diff --git a/mlx_vlm/tests/test_models.py b/mlx_vlm/tests/test_models.py index 56c6366..79b13ba 100644 --- a/mlx_vlm/tests/test_models.py +++ b/mlx_vlm/tests/test_models.py @@ -62,7 +62,6 @@ def vision_test_runner( shape=(batch_size, image_size[0], image_size[1], num_channels) ) - # Perform a forward pass hidden_states = vision_tower(input_tensor, output_hidden_states=True, **kwargs) # Check vision hidden feature layer's shape matches the expected hidden size @@ -620,6 +619,50 @@ def test_phi3_v(self): (config.vision_config.image_size, config.vision_config.image_size), ) + def test_pixtral(self): + from mlx_vlm.models import pixtral + + text_config = pixtral.TextConfig( + model_type="mistral", + hidden_size=4096, + num_hidden_layers=32, + intermediate_size=11008, + num_attention_heads=32, + rms_norm_eps=1e-5, + vocab_size=32000, + num_key_value_heads=32, + rope_theta=10000.0, + rope_traditional=False, + rope_scaling=None, + ) + + vision_config = pixtral.VisionConfig( + model_type="pixtral", + num_hidden_layers=24, + hidden_size=1024, + intermediate_size=4096, + num_attention_heads=16, + 
image_size=336, + patch_size=14, + projection_dim=768, + vocab_size=32000, + num_channels=3, + rms_norm_eps=1e-6, + ) + + config = pixtral.ModelConfig( + text_config=text_config, + vision_config=vision_config, + model_type="pixtral", + ignore_index=-100, + image_token_index=32000, + vocab_size=32000, + vision_feature_layer=-2, + vision_feature_select_strategy="default", + ) + + model = pixtral.Model(config) + def test_qwen2_vl(self): from mlx_vlm.models import qwen2_vl @@ -656,7 +699,6 @@ def test_qwen2_vl(self): model_type="qwen2_vl", text_config=text_config, vision_config=vision_config, - rope_scaling=text_config.rope_scaling, image_token_index=151655, vocab_size=32000, ) diff --git a/mlx_vlm/tests/test_trainer.py b/mlx_vlm/tests/test_trainer.py new file mode 100644 index 0000000..519bfe4 --- /dev/null +++ b/mlx_vlm/tests/test_trainer.py @@ -0,0 +1,133 @@ +import unittest +from unittest.mock import MagicMock, patch + +import mlx.core as mx +import mlx.nn as nn + +from mlx_vlm.trainer.trainer import Dataset, Trainer, TrainingArgs +from mlx_vlm.utils import prepare_inputs + + +class TestDataset(unittest.TestCase): + def setUp(self): + self.mock_hf_dataset = MagicMock() + self.mock_config = {"model_type": "test_model", "image_token_index": 1} + self.mock_processor = MagicMock() + self.mock_image_processor = MagicMock() + + @patch("mlx_vlm.utils.prepare_inputs") + def test_dataset_initialization(self, mock_prepare_inputs): + dataset = Dataset( + self.mock_hf_dataset, + self.mock_config, + self.mock_processor, + self.mock_image_processor, + take=10, + split="train", + ) + + self.assertEqual(len(dataset), len(self.mock_hf_dataset["train"].take(10))) + self.assertEqual(dataset.config, self.mock_config) + self.assertEqual(dataset.processor, self.mock_processor) + self.assertEqual(dataset.image_processor, self.mock_image_processor) + + @patch("mlx_vlm.trainer.trainer.get_prompt") + @patch("mlx_vlm.utils.prepare_inputs") + def test_dataset_getitem(self, mock_prepare_inputs, mock_get_prompt): + dataset = Dataset( + self.mock_hf_dataset, + self.mock_config, + self.mock_processor, + self.mock_image_processor, + ) + + mock_item = { + "images": ["image1.jpg"], + "messages": [{"role": "user", "content": "Hello"}], + } + self.mock_hf_dataset.__getitem__.return_value = mock_item + + mock_get_prompt.return_value = "Mocked prompt" + + mock_prepare_inputs.return_value = ( + mx.array([1, 2, 3]), # input_ids + mx.array( + [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]] + ), # pixel_values + mx.array([1, 1, 1]), # mask + (1, 1, 1), # image_grid_thw + [224, 224], # image_sizes + ) + + result = dataset[0] + + mock_prepare_inputs.assert_called_once() + self.assertIn("pixel_values", result) + self.assertIn("input_ids", result) + self.assertIn("attention_mask", result) + self.assertIn("image_grid_thw", result) + self.assertIn("image_sizes", result) + + # Check if the returned values match the mocked input + self.assertTrue(mx.array_equal(result["input_ids"], mx.array([1, 2, 3]))) + self.assertTrue( + mx.array_equal( + result["pixel_values"], + mx.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]), + ) + ) + self.assertTrue(mx.array_equal(result["attention_mask"], mx.array([1, 1, 1]))) + self.assertEqual(result["image_grid_thw"], (1, 1, 1)) + self.assertEqual(result["image_sizes"], [224, 224]) + + +class TestTrainer(unittest.TestCase): + def setUp(self): + self.mock_model = MagicMock(spec=nn.Module) + self.mock_optimizer = MagicMock() + self.trainer = Trainer(self.mock_model, self.mock_optimizer) + + def 
test_trainer_initialization(self): + self.assertEqual(self.trainer.model, self.mock_model) + self.assertEqual(self.trainer.optimizer, self.mock_optimizer) + self.assertFalse(self.trainer.train_on_completions) + self.assertEqual(self.trainer.assistant_id, 77091) + + @patch("mlx.nn.losses.cross_entropy") + def test_loss_fn(self, mock_cross_entropy): + batch = { + "pixel_values": mx.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]), + "input_ids": mx.array([[1, 2, 3], [4, 5, 6]]), + "attention_mask": mx.array([[1, 1, 1], [1, 1, 0]]), + "image_grid_thw": (1, 1, 1), + "image_sizes": [224, 224], + } + + self.mock_model.return_value = mx.array([[[0.1, 0.2, 0.3]], [[0.4, 0.5, 0.6]]]) + mock_cross_entropy.return_value = mx.array([[0.1, 0.2], [0.3, 0.4]]) + + loss = self.trainer.loss_fn(self.mock_model, batch) + + self.assertIsInstance(loss, mx.array) + self.assertEqual(loss.shape, ()) # Scalar value + + @patch.object(Trainer, "loss_fn") + @patch("mlx.nn.value_and_grad") + def test_train_step(self, mock_value_and_grad, mock_loss_fn): + mock_batch = MagicMock() + mock_loss = mx.array(0.5) + mock_grads = {"param1": mx.array([0.1, 0.2]), "param2": mx.array([0.3, 0.4])} + + mock_value_and_grad.return_value = lambda *args, **kwargs: ( + mock_loss, + mock_grads, + ) + + loss = self.trainer.train_step(mock_batch) + + self.mock_optimizer.update.assert_called_once_with(self.mock_model, mock_grads) + self.assertEqual(loss, mock_loss) + + +if __name__ == "__main__": + unittest.main() diff --git a/mlx_vlm/tests/test_trainer_utils.py b/mlx_vlm/tests/test_trainer_utils.py new file mode 100644 index 0000000..c7e344c --- /dev/null +++ b/mlx_vlm/tests/test_trainer_utils.py @@ -0,0 +1,59 @@ +import unittest +from unittest.mock import MagicMock, patch + +import mlx.nn as nn + +from mlx_vlm.trainer.utils import ( + find_all_linear_names, + get_module_by_name, + get_peft_model, + set_module_by_name, +) + + +class TestTrainerUtils(unittest.TestCase): + + def test_get_module_by_name(self): + model = MagicMock() + model.layer1.layer2.layer3 = "test_module" + + result = get_module_by_name(model, "layer1.layer2.layer3") + self.assertEqual(result, "test_module") + + def test_set_module_by_name(self): + model = MagicMock() + new_module = MagicMock() + + set_module_by_name(model, "layer1.layer2.layer3", new_module) + self.assertEqual(model.layer1.layer2.layer3, new_module) + + @patch("mlx_vlm.trainer.utils.freeze_model") + @patch("mlx_vlm.trainer.utils.print_trainable_parameters") + def test_get_peft_model(self, mock_print, mock_freeze): + model = MagicMock() + model.language_model.named_modules.return_value = [ + ("layer1", nn.Linear(256, 512)), + ("layer2", nn.QuantizedLinear(256, 512, 8)), + ] + + result = get_peft_model(model, ["layer1", "layer2"]) + + self.assertTrue(mock_freeze.called) + self.assertTrue(mock_print.called) + self.assertTrue(hasattr(model.config, "lora")) + + def test_find_all_linear_names(self): + model = MagicMock() + model.named_modules.return_value = [ + ("layer1", nn.Linear(256, 512)), + ("layer2", nn.QuantizedLinear(256, 512, 8)), + ("mm_projector", nn.Linear(256, 512)), + ("lm_head", nn.Linear(256, 512)), + ] + + result = find_all_linear_names(model) + self.assertEqual(set(result), {"layer1", "layer2"}) + + +if __name__ == "__main__": + unittest.main() diff --git a/mlx_vlm/trainer/__init__.py b/mlx_vlm/trainer/__init__.py new file mode 100644 index 0000000..813f6d1 --- /dev/null +++ b/mlx_vlm/trainer/__init__.py @@ -0,0 +1,9 @@ +from .lora import LoRaLayer, replace_lora_with_linear +from .trainer import 
Dataset, Trainer, save_adapter +from .utils import ( + apply_lora_layers, + count_parameters, + find_all_linear_names, + get_peft_model, + print_trainable_parameters, +) diff --git a/mlx_vlm/trainer/lora.py b/mlx_vlm/trainer/lora.py new file mode 100644 index 0000000..c139e47 --- /dev/null +++ b/mlx_vlm/trainer/lora.py @@ -0,0 +1,70 @@ +import math +from typing import Union + +import mlx.core as mx +import mlx.nn as nn + + +class LoRaLayer(nn.Module): + def __init__( + self, + linear: Union[nn.Linear, nn.QuantizedLinear], + rank: int, + alpha: float = 0.1, + dropout: float = 0.0, + ): + super().__init__() + + self.original_layer = linear + + self.dropout = nn.Dropout(p=dropout) + + output_dims, input_dims = linear.weight.shape + if isinstance(linear, nn.QuantizedLinear): + input_dims *= 32 // linear.bits + + std_dev = 1 / math.sqrt(rank) + + self.A = mx.random.uniform( + low=-std_dev, + high=std_dev, + shape=(input_dims, rank), + ) + self.B = mx.zeros((rank, output_dims)) + self.alpha = alpha + + def __call__(self, x): + y = self.original_layer(x) + lora_update = (self.dropout(x) @ self.A) @ self.B + return y + (self.alpha * lora_update).astype(x.dtype) + + +def replace_lora_with_linear(model): + for i, layer in enumerate(model.layers): + if isinstance(layer, LoRaLayer): + # Compute the final merged weight + lora_update = layer.alpha * (layer.A @ layer.B) + updated_weight = layer.original_layer.weight + lora_update + use_bias = layer.original_layer.bias is not None + + updated_bias = layer.original_layer.bias + + # Create a new Linear layer with the updated parameters + new_linear_layer = nn.Linear( + updated_weight.size(1), updated_weight.size(0), bias=use_bias + ) + + new_linear_layer.weight = updated_weight + + if use_bias: + new_linear_layer.bias = updated_bias + + if isinstance(layer.original_layer, nn.QuantizedLinear): + new_linear_layer = nn.QuantizedLinear.from_linear( + new_linear_layer, + new_linear_layer.group_size, + new_linear_layer.bits, + ) + + # Replace the LoRaLayer with the new Linear layer in the model + model.layers[i] = new_linear_layer diff --git a/mlx_vlm/trainer/trainer.py b/mlx_vlm/trainer/trainer.py new file mode 100644 index 0000000..f02674e --- /dev/null +++ b/mlx_vlm/trainer/trainer.py @@ -0,0 +1,275 @@ +import json +import os +import time +import warnings +from dataclasses import dataclass, field +from pathlib import Path +from typing import Union + +import mlx.core as mx +import mlx.nn as nn +import numpy as np +from mlx.utils import tree_flatten + + +def get_prompt(model_type, processor, conversation): + if model_type == "paligemma": + return conversation + + if "chat_template" in processor.__dict__.keys(): + prompt = processor.apply_chat_template( + conversation, + tokenize=False, + add_generation_prompt=False, + ) + elif "tokenizer" in processor.__dict__.keys(): + prompt = processor.tokenizer.apply_chat_template( + conversation, + tokenize=False, + add_generation_prompt=False, + ) + + return prompt + + +class Dataset: + def __init__( + self, + hf_dataset, + config, + processor, + image_processor=None, + take=None, + split=None, + ): + if split is not None: + self.dataset = hf_dataset[split] + else: + self.dataset = hf_dataset + if take is not None: + self.dataset = self.dataset.take(take) + self.processor = processor + self.config = config + self.image_processor = image_processor + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, idx): + from mlx_vlm.utils import prepare_inputs + + item = self.dataset[idx] + + images = 
item["images"] + conversations = item["messages"] + prompts = [] + + if isinstance(conversations, list) and isinstance(conversations[0], list): + for conversation in conversations: + if self.config["model_type"] == "pixtral": + conversation = [json.loads(i) for i in conversation] + if len(conversations) > 1: + warnings.warn( + "Pixtral batch processing is not supported yet. Set batch size to 1." + ) + + prompt = get_prompt( + self.config["model_type"], self.processor, conversation + ) + prompts.append(prompt) + + else: + if self.config["model_type"] == "pixtral": + conversations = [json.loads(i) for i in conversations] + prompt = get_prompt( + self.config["model_type"], self.processor, conversations + ) + prompts.append(prompt) + + image_token_index = self.config["image_token_index"] + + inputs = prepare_inputs( + self.image_processor, self.processor, images, prompts, image_token_index + ) + input_ids, pixel_values, mask = inputs[:3] + kwargs = {k: v for k, v in zip(["image_grid_thw", "image_sizes"], inputs[3:])} + if mask is None: + mask = mx.ones_like(input_ids) + + return { + "pixel_values": pixel_values, + "input_ids": input_ids, + "attention_mask": mask, + **kwargs, + } + + +def grad_checkpoint(layer): + """ + Update all instances of type(layer) to use gradient checkpointing. + """ + fn = type(layer).__call__ + + def checkpointed_fn(model, *args, **kwargs): + def inner_fn(params, *args, **kwargs): + model.update(params) + return fn(model, *args, **kwargs) + + return mx.checkpoint(inner_fn)(model.trainable_parameters(), *args, **kwargs) + + type(layer).__call__ = checkpointed_fn + + +@dataclass +class TrainingArgs: + batch_size: int = field(default=4, metadata={"help": "Minibatch size."}) + iters: int = field(default=100, metadata={"help": "Iterations to train for."}) + val_batches: int = field( + default=25, + metadata={ + "help": "Number of validation batches, -1 uses the entire validation set." 
+ }, + ) + steps_per_report: int = field( + default=10, + metadata={"help": "Number of training steps between loss reporting."}, + ) + steps_per_eval: int = field( + default=200, metadata={"help": "Number of training steps between validations."} + ) + steps_per_save: int = field( + default=100, metadata={"help": "Save the model every number steps"} + ) + max_seq_length: int = field( + default=2048, metadata={"help": "Maximum sequence length."} + ) + adapter_file: str = field( + default="adapters.safetensors", + metadata={"help": "Save/load path for the trained adapter weights."}, + ) + grad_checkpoint: bool = field( + default=False, + metadata={"help": "Use gradient checkpointing to reduce memory use."}, + ) + + +def default_loss(model, inputs, targets, lengths): + logits = model(inputs) + logits = logits.astype(mx.float32) + + length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None] + + ce = nn.losses.cross_entropy(logits, targets) * length_mask + ntoks = length_mask.sum() + ce = ce.sum() / ntoks + + return ce, ntoks + + +class Trainer: + def __init__( + self, model, optimizer, train_on_completions=False, assistant_id=77091 + ): + self.model = model + self.optimizer = optimizer + self.train_on_completions = train_on_completions + self.assistant_id = assistant_id + + def loss_fn(self, model, batch): + pixel_values = batch["pixel_values"] + input_ids = batch["input_ids"] + attention_mask = batch["attention_mask"] + lengths = mx.sum(attention_mask, axis=1) + labels = input_ids[:, 1:] + + batch_size, seq_length = input_ids.shape + + if self.train_on_completions: + weight_mask = mx.ones_like(attention_mask) + + assistant_response_index = np.where(input_ids == self.assistant_id)[1] + range_matrix = mx.repeat( + mx.expand_dims(mx.arange(seq_length), 0), batch_size, axis=0 + ) + assistant_mask = range_matrix <= mx.array(assistant_response_index).reshape( + -1, 1 + ) + # Apply the mask to weight_mask + weight_mask = mx.where( + assistant_mask, mx.zeros_like(weight_mask), weight_mask + )[:, 1:] + else: + weight_mask = None + + input_ids = input_ids[:, :-1] + + kwargs = ( + { + "image_grid_thw": batch["image_grid_thw"], + "image_sizes": batch["image_sizes"], + } + if "image_grid_thw" in batch or "image_sizes" in batch + else {} + ) + + # Forward pass + logits = model(input_ids, pixel_values, attention_mask, **kwargs) + + # Cast to float32 + logits.astype(mx.float32) + + # Ensure logits and labels have the same sequence length + def align_logits_with_labels(logits, labels): + if logits.shape[1] < labels.shape[1]: + pad_length = labels.shape[1] - logits.shape[1] + pad_width = ((0, 0), (0, pad_length), (0, 0)) + return mx.pad(logits, pad_width, mode="constant", constant_values=-100) + elif logits.shape[1] > labels.shape[1]: + return logits[:, -labels.shape[1] :, :] + return logits + + logits = align_logits_with_labels(logits, labels) + + length_mask = mx.arange(input_ids.shape[1])[None, :] < lengths[:, None] + + # Compute loss only on non-padded tokens + ce = ( + nn.losses.cross_entropy( + logits, + labels, + weights=weight_mask, + ) + * length_mask + ) + ntoks = length_mask.sum() + ce = ce.sum() / ntoks + + return ce + + def train_step(self, batch): + loss_and_grad_fn = nn.value_and_grad(self.model, self.loss_fn) + loss, grads = loss_and_grad_fn(self.model, batch) + self.optimizer.update(self.model, grads) + + return loss + + @mx.compile + def train_epoch(self, dataloader): + total_loss = 0 + for batch in dataloader: + loss = self.train_step(batch) + mx.eval(self.model, 
self.optimizer.state) + total_loss += loss + return total_loss / len(dataloader) + + +def save_adapter( + model: nn.Module, + adapter_file: Union[str, Path], +): + path = Path(adapter_file) + if hasattr(model.config, "lora"): + with open(path.parent / "adapter_config.json", "w") as f: + json.dump(model.config.lora, f) + flattened_tree = tree_flatten(model.trainable_parameters()) + mx.save_safetensors(str(adapter_file), dict(flattened_tree)) diff --git a/mlx_vlm/trainer/utils.py b/mlx_vlm/trainer/utils.py new file mode 100644 index 0000000..9873ed7 --- /dev/null +++ b/mlx_vlm/trainer/utils.py @@ -0,0 +1,150 @@ +from pathlib import Path + +import mlx.nn as nn +from mlx.utils import tree_flatten + +from .lora import LoRaLayer + + +def get_module_by_name(model, name): + parts = name.split(".") + module = model + for part in parts: + if part.isdigit(): + module = module[int(part)] + else: + module = getattr(module, part) + return module + + +def set_module_by_name(model, name, new_module): + parts = name.split(".") + module = model + for part in parts[:-1]: + if part.isdigit(): + module = module[int(part)] + else: + module = getattr(module, part) + if parts[-1].isdigit(): + module[int(parts[-1])] = new_module + else: + setattr(module, parts[-1], new_module) + + +def get_peft_model( + model, linear_layers, rank=10, alpha=0.1, dropout=0.1, freeze=True, verbose=True +): + if freeze: + freeze_model(model) + + for name, module in model.language_model.named_modules(): + if isinstance(module, nn.Linear) or isinstance(module, nn.QuantizedLinear): + if name.split(".")[-1] in linear_layers: + lora_layer = LoRaLayer(module, rank, alpha, dropout) + set_module_by_name(model.language_model, name, lora_layer) + + model.config.lora = {} + model.config.lora["rank"] = rank + model.config.lora["alpha"] = alpha + model.config.lora["dropout"] = dropout + + if verbose: + print_trainable_parameters(model.language_model) + + return model + + +def freeze_model(model): + for name, module in model.named_modules(): + name = name.split(".")[0] + if name in [ + "language_model", + "vision_model", + "vision_tower", + "aligner", + "connector", + "multi_modal_projector", + "mm_projector", + ]: + model[f"{name}"].freeze() + + +def find_all_linear_names(model): + cls = nn.Linear + quantized_cls = nn.QuantizedLinear + lora_module_names = set() + multimodal_keywords = [ + "mm_projector", + "vision_tower", + "vision_resampler", + "aligner", + ] + for name, module in model.named_modules(): + if any(mm_keyword in name for mm_keyword in multimodal_keywords): + continue + if isinstance(module, cls) or isinstance(module, quantized_cls): + names = name.split(".") + lora_module_names.add(names[0] if len(names) == 1 else names[-1]) + + if "lm_head" in lora_module_names: # needed for 16-bit + lora_module_names.remove("lm_head") + return list(lora_module_names) + + +def count_parameters(model): + def nparams(m): + if isinstance(m, (nn.QuantizedLinear, nn.QuantizedEmbedding)): + return m.weight.size * (32 // m.bits) + return sum(v.size for _, v in tree_flatten(m.parameters())) + + leaf_modules = tree_flatten( + model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module) + ) + total_p = sum(nparams(m) for _, m in leaf_modules) / 10**6 + + return total_p + + +def print_trainable_parameters(model): + def nparams(m): + if isinstance(m, (nn.QuantizedLinear, nn.QuantizedEmbedding)): + return m.weight.size * (32 // m.bits) + return sum(v.size for _, v in tree_flatten(m.parameters())) + + leaf_modules = tree_flatten( + model.leaf_modules(), 
is_leaf=lambda m: isinstance(m, nn.Module) + ) + total_p = sum(nparams(m) for _, m in leaf_modules) / 10**6 + trainable_p = ( + sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6 + ) + + print( + f"#trainable params: {trainable_p} M || all params: {total_p} M || trainable%: {(trainable_p * 100 / total_p):.3f}%" + ) + + +def apply_lora_layers(model: nn.Module, adapter_path: str) -> nn.Module: + """ + Apply LoRA layers to the model. + + Args: + model (nn.Module): The neural network model. + adapter_path (str): Path to the adapter configuration file. + + Returns: + nn.Module: The updated model with LoRA layers applied. + """ + adapter_path = Path(adapter_path) + + if not adapter_path.exists(): + raise FileNotFoundError(f"The adapter path does not exist: {adapter_path}") + + # TODO: add lora params to the config and load them here + list_of_modules = find_all_linear_names(model.language_model.model) + model = get_peft_model(model, list_of_modules) + + # TODO: Use custom adapter name + model.load_weights(str(adapter_path / "adapters.safetensors"), strict=False) + + return model diff --git a/mlx_vlm/utils.py b/mlx_vlm/utils.py index 9545311..186bd01 100644 --- a/mlx_vlm/utils.py +++ b/mlx_vlm/utils.py @@ -26,7 +26,8 @@ from .models.base import BaseImageProcessor, KVCache from .sample_utils import top_p_sampling -from .tokenizer_utils import TokenizerWrapper, load_tokenizer +from .tokenizer_utils import load_tokenizer +from .trainer import apply_lora_layers # Constants MODEL_REMAPPING = {"llava-qwen2": "llava_bunny", "bunny-llama": "llava_bunny"} @@ -223,6 +224,7 @@ def load_model(model_path: Path, lazy: bool = False) -> nn.Module: def load( path_or_hf_repo: str, processor_config={}, + adapter_path: Optional[str] = None, lazy: bool = False, ) -> Tuple[nn.Module, Union[PreTrainedTokenizer, PreTrainedTokenizerFast]]: """ @@ -247,6 +249,11 @@ def load( model_path = get_model_path(path_or_hf_repo) model = load_model(model_path, lazy) + if adapter_path is not None: + # TODO: Support more modules than just language_model + model = apply_lora_layers(model, adapter_path) + model.eval() + processor = load_processor(model_path, processor_config=processor_config) return model, processor @@ -288,14 +295,15 @@ def load_image_processor(model_path: Union[str, Path]) -> BaseImageProcessor: def load_processor( - model_path, processor_config={"trust_remote_code": True} + model_path, processor_config={"trust_remote_code": True}, add_detokenizer=True ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: processor = AutoProcessor.from_pretrained(model_path, **processor_config) - detokenizer_class = load_tokenizer(model_path, return_tokenizer=False) - if "tokenizer" in processor.__dict__.keys(): - processor.detokenizer = detokenizer_class(processor.tokenizer) - else: - processor.detokenizer = detokenizer_class(processor) + if add_detokenizer: + detokenizer_class = load_tokenizer(model_path, return_tokenizer=False) + if "tokenizer" in processor.__dict__.keys(): + processor.detokenizer = detokenizer_class(processor.tokenizer) + else: + processor.detokenizer = detokenizer_class(processor) return processor @@ -304,8 +312,7 @@ def fetch_from_hub( ) -> Tuple[nn.Module, dict, PreTrainedTokenizer]: model = load_model(model_path, lazy) config = load_config(model_path) - processor = load_processor(model_path) - + processor = load_processor(model_path, add_detokenizer=False) return model, config, processor @@ -637,7 +644,7 @@ def convert( ): print("[INFO] Loading") model_path = get_model_path(hf_path, 
revision=revision)
-    model, config, tokenizer = fetch_from_hub(model_path, lazy=False)
+    model, config, processor = fetch_from_hub(model_path, lazy=False)

     weights = dict(tree_flatten(model.parameters()))
     dtype = mx.float16 if quantize else getattr(mx, dtype)
@@ -666,7 +673,7 @@ def convert(
         for file in py_files:
             shutil.copy(file, mlx_path)

-    tokenizer.save_pretrained(mlx_path)
+    processor.save_pretrained(mlx_path)

     save_config(config, config_path=mlx_path / "config.json")

@@ -704,46 +711,86 @@ def load_image(image_source: Union[str, Path, BytesIO]):
         )


-def prepare_inputs(image_processor, processor, image, prompt, image_token_index):
+def resize_image(img, max_size):
+    ratio = min(max_size[0] / img.width, max_size[1] / img.height)
+    new_size = (int(img.width * ratio), int(img.height * ratio))
+    return img.resize(new_size)
+
+
+def process_image(img, resize_shape):
+    if isinstance(img, str):
+        img = load_image(img)
+    if resize_shape is not None:
+        img = resize_image(img, resize_shape)
+    return img
+
+
+def prepare_inputs(
+    image_processor, processor, images, prompts, image_token_index, resize_shape=None
+):
     from transformers.image_utils import load_image

     mask = None
-    if isinstance(image, str):
-        image = load_image(image)
+    if not isinstance(images, list):
+        images = [images]
+
+    # Process images
+    images = [
+        process_image(img, resize_shape) if isinstance(img, str) else img
+        for img in images
+    ]

     image_grid_thw = None
+    image_sizes = None

     if image_processor is not None:
-        text_chunks = [processor(chunk).input_ids for chunk in prompt.split("<image>")]
-        input_ids = mx.array([text_chunks[0] + [image_token_index] + text_chunks[1]])
-        pixel_values = mx.array(image_processor.preprocess(images=[image])[0])
-        pixel_values = mx.array(mx.expand_dims(pixel_values, axis=0))
+        if not isinstance(prompts, list):
+            prompts = [prompts]
+
+        processor.pad_token = processor.eos_token
+        text_chunks = [
+            [processor(chunk).input_ids for chunk in prompt.split("<image>")]
+            for prompt in prompts
+        ]
+
+        # Find the maximum length for padding
+        max_length = max(
+            sum(len(chunk) for chunk in chunks) + 1 for chunks in text_chunks
+        )
+
+        # Pad and create input_ids
+        input_ids = []
+        for chunks in text_chunks:
+            ids = chunks[0] + [image_token_index] + chunks[1]
+            padding = [processor.pad_token_id] * (max_length - len(ids))
+            input_ids.append(mx.array(ids + padding))
+
+        input_ids = mx.array(input_ids)
+
+        pixel_values = image_processor.preprocess(images=images)
+        pixel_values = mx.array(np.stack(pixel_values))
+
+        mask = mx.array([(ids != processor.pad_token_id) for ids in input_ids]).astype(
+            mx.int32
+        )
     else:
         processor.tokenizer.pad_token = processor.tokenizer.eos_token
-        try:
-            inputs = processor(
-                text=[prompt], images=[image], padding=True, return_tensors="mlx"
-            )
-        except Exception as e:
-            inputs = processor(
-                text=prompt, images=[image], padding=True, return_tensors="mlx"
-            )  # for phi3_v model
-
+        inputs = processor(
+            text=prompts, images=images, padding=True, return_tensors="mlx"
+        )
         if isinstance(inputs["pixel_values"], list):
-            pixel_values = mx.array(inputs["pixel_values"][0][0])[None, :]
-        elif isinstance(inputs["pixel_values"], np.ndarray):
-            pixel_values = mx.array(inputs["pixel_values"])
+            pixel_values = inputs["pixel_values"]
         else:
-            raise ValueError(
-                f"Invalid pixel_values type: {type(inputs['pixel_values'])}"
-            )
-
+            pixel_values = mx.array(inputs["pixel_values"])
         input_ids = mx.array(inputs["input_ids"])
-        mask = inputs["attention_mask"]
+        mask = mx.array(inputs["attention_mask"])
+        image_sizes =
inputs.get("image_sizes", None) + if image_sizes is not None: + image_sizes = mx.array(image_sizes) image_grid_thw = inputs.get("image_grid_thw", None) - if "image_sizes" in inputs: - return input_ids, pixel_values, inputs["image_sizes"], image_grid_thw + if image_grid_thw is not None: + image_grid_thw = mx.array(image_grid_thw) - return input_ids, pixel_values, mask, image_grid_thw + return input_ids, pixel_values, mask, image_grid_thw, image_sizes def generate_step( @@ -878,9 +925,11 @@ def stream_generate( tokenizer = processor.tokenizer image_token_index = model.config.image_token_index - input_ids, pixel_values, mask = prepare_inputs( + inputs = prepare_inputs( image_processor, processor, image, prompt, image_token_index ) + input_ids, pixel_values, mask = inputs[:3] + kwargs = {k: v for k, v in zip(["image_grid_thw", "image_sizes"], inputs[3:])} detokenizer = processor.detokenizer @@ -944,32 +993,32 @@ def generate( tokenizer = processor.tokenizer image_token_index = model.config.image_token_index - input_ids, pixel_values, mask, image_grid_thw = prepare_inputs( + # Prepare inputs + inputs = prepare_inputs( image_processor, processor, image, prompt, image_token_index ) + input_ids, pixel_values, mask = inputs[:3] + kwargs = {k: v for k, v in zip(["image_grid_thw", "image_sizes"], inputs[3:])} - kwargs = { - "image_grid_thw": image_grid_thw, - } - + # Initialize timing and detokenizer tic = time.perf_counter() detokenizer = processor.detokenizer detokenizer.reset() - for (token, prob), n in zip( - generate_step( - input_ids, - model, - pixel_values, - mask, - temp, - repetition_penalty, - repetition_context_size, - top_p, - **kwargs, - ), - range(max_tokens), - ): + # Generate tokens + generator = generate_step( + input_ids, + model, + pixel_values, + mask, + temp, + repetition_penalty, + repetition_context_size, + top_p, + **kwargs, + ) + + for (token, prob), n in zip(generator, range(max_tokens)): if n == 0: prompt_time = time.perf_counter() - tic diff --git a/mlx_vlm/version.py b/mlx_vlm/version.py index 6561790..3dc1f76 100644 --- a/mlx_vlm/version.py +++ b/mlx_vlm/version.py @@ -1 +1 @@ -__version__ = "0.0.15" +__version__ = "0.1.0" diff --git a/requirements.txt b/requirements.txt index d27c1b4..2723e00 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,9 @@ mlx>=0.18.0 -numpy +datasets>=2.19.1 +tqdm>=4.66.2 +numpy>=1.23.4 transformers>=4.45.1 scipy==1.13.1 gradio>=4.44.0 -Pillow -requests +Pillow>=10.3.0 +requests>=2.31.0
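
Taken together, the new trainer code wires LoRA into an MLX training loop: `get_peft_model` wraps the language model's linear layers in `LoRaLayer` adapters and freezes the rest of the model, `Trainer.train_step` computes the masked cross-entropy loss shown above and applies the optimizer update, and `save_adapter` writes only the trainable parameters plus an `adapter_config.json`. The sketch below is one way these pieces could fit together; it is not from this patch. The `mlx_vlm.trainer` import path for `Trainer`/`save_adapter` and the `dataloader` are assumptions, and each batch is simply a dict with the `input_ids`, `pixel_values`, and `attention_mask` keys that `Trainer.loss_fn` reads.

```python
# Minimal fine-tuning sketch under the assumptions stated above.
from pathlib import Path

import mlx.core as mx
import mlx.optimizers as optim

from mlx_vlm import load
from mlx_vlm.trainer import Trainer, save_adapter  # assumed export path
from mlx_vlm.trainer.utils import find_all_linear_names, get_peft_model

# A pre-quantized base model makes this QLoRA rather than plain LoRA.
model, processor = load("mlx-community/Qwen2-VL-2B-Instruct-4bit")

# Wrap the language model's linear layers with LoRA and freeze everything else.
# Keeping the default rank/alpha/dropout matches what apply_lora_layers
# re-creates when the adapter is loaded back later.
linear_layers = find_all_linear_names(model.language_model.model)
model = get_peft_model(model, linear_layers, rank=10, alpha=0.1, dropout=0.1)

optimizer = optim.Adam(learning_rate=1e-4)
trainer = Trainer(model, optimizer)

# `dataloader` is a placeholder for your own batching code; each batch is a dict
# with "input_ids", "pixel_values", and "attention_mask".
for step, batch in enumerate(dataloader):
    loss = trainer.train_step(batch)
    mx.eval(model.parameters(), optimizer.state)
    if step % 10 == 0:
        print(f"step {step}: loss {loss.item():.4f}")

# Persist only the trainable (LoRA) weights and the adapter config.
out_dir = Path("adapters")
out_dir.mkdir(exist_ok=True)
save_adapter(model, out_dir / "adapters.safetensors")
```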
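On the inference side, `load` now accepts an optional `adapter_path`. When it is set, `apply_lora_layers` re-wraps the language model's linear layers (currently with the default LoRA settings, per the TODO above) and loads `adapters.safetensors` from that directory with `strict=False`. A minimal sketch, assuming an adapter directory like the one written in the previous sketch:

```python
from mlx_vlm import load

# The directory must contain adapters.safetensors (written by save_adapter);
# "adapters" is the hypothetical path from the fine-tuning sketch above.
model, processor = load(
    "mlx-community/Qwen2-VL-2B-Instruct-4bit",
    adapter_path="adapters",
)
# From here on, generation works exactly as with the base model.
```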