Trainer + Multi image v0.1.0 #41

Merged
79 commits merged on Oct 11, 2024
Commits
2f50f10
remove torch and mlx-lm
Blaizzy May 26, 2024
d14849f
remove torch and mlx-lm
Blaizzy May 26, 2024
2c72233
Merge branch 'pc/tuner' of https://github.com/Blaizzy/mlx-vlm into pc…
Blaizzy Jun 11, 2024
2391df4
add peft model creation
Blaizzy Jun 12, 2024
f5613eb
Merge branch 'pc/tuner' of https://github.com/Blaizzy/mlx-vlm into pc…
Blaizzy Jun 23, 2024
5fcaed2
use tree flatten
Blaizzy Jul 7, 2024
a88029f
add dataset loader
Blaizzy Jul 9, 2024
3d29a20
Merge branch 'main' into pc/tuner
Blaizzy Sep 3, 2024
9aa5072
fix dataset
Blaizzy Sep 3, 2024
3c4df2a
Merge branch 'pc/tuner' of https://github.com/Blaizzy/mlx-vlm into pc…
Blaizzy Sep 3, 2024
911eaaa
fix masks and rename dataset
Blaizzy Sep 4, 2024
8fa9bb9
support batch processing and train on completions
Blaizzy Sep 7, 2024
bf9bed6
fix trainer
Blaizzy Sep 16, 2024
f00252d
formatting
Blaizzy Sep 16, 2024
f206ded
add support for none splits and fix assistant id
Blaizzy Sep 28, 2024
dab901c
Add lora script and docs
Blaizzy Sep 28, 2024
607b249
remove torch and mlx-lm
Blaizzy May 26, 2024
5c135ac
add peft model creation
Blaizzy Jun 12, 2024
534f20c
use tree flatten
Blaizzy Jul 7, 2024
c1edc22
add dataset loader
Blaizzy Jul 9, 2024
91e9305
fix dataset
Blaizzy Sep 3, 2024
e5c0424
fix masks and rename dataset
Blaizzy Sep 4, 2024
130d876
support batch processing and train on completions
Blaizzy Sep 7, 2024
5c028d4
fix trainer
Blaizzy Sep 16, 2024
1bc0aa4
formatting
Blaizzy Sep 16, 2024
d62bf63
add support for none splits and fix assistant id
Blaizzy Sep 28, 2024
a6d411e
Add lora script and docs
Blaizzy Sep 28, 2024
c5bff8d
Merge branch 'pc/tuner' of https://github.com/Blaizzy/mlx-vlm into pc…
Blaizzy Sep 29, 2024
c1033b5
remove duplicates
Blaizzy Sep 29, 2024
80cdcd6
fix batch load
Blaizzy Sep 29, 2024
935e1bd
load trained adapters and add super to all models
Blaizzy Sep 29, 2024
8ba507d
fix pixtral quant
Blaizzy Sep 29, 2024
23598ad
speed up qwen batch processing
Blaizzy Sep 29, 2024
dc2226e
fix qlora training
Blaizzy Oct 1, 2024
4cb7956
fix dataloader
Blaizzy Oct 1, 2024
0659d45
formatting
Blaizzy Oct 2, 2024
1880162
fix pixtral pixel loading
Blaizzy Oct 2, 2024
858caab
fix lora and dataset
Blaizzy Oct 2, 2024
5e5ab71
add batch processing suppor for qwen2_vl
Blaizzy Oct 2, 2024
86baba3
update lora docs
Blaizzy Oct 2, 2024
4ce361d
add unit tests
Blaizzy Oct 2, 2024
eac9ee1
set stage for phi3_v support
Blaizzy Oct 2, 2024
4ebac21
update logs and readme
Blaizzy Oct 2, 2024
336f423
add utils tests and remove unused collate fn
Blaizzy Oct 2, 2024
c0cd42d
refactor prompt utils and add multi-image support for pixtral
Blaizzy Oct 4, 2024
5f26374
add llava interleave support
Blaizzy Oct 4, 2024
efa26e6
multi image support
Blaizzy Oct 4, 2024
e744722
add image resizing
Blaizzy Oct 4, 2024
49d16f6
refactor data loading
Blaizzy Oct 4, 2024
df99627
update data procesing and tqdm
Blaizzy Oct 4, 2024
b0a5bda
add llava interleave
Blaizzy Oct 4, 2024
941ebf8
formatting
Blaizzy Oct 4, 2024
7a58b96
add list of models with multi-image support
Blaizzy Oct 4, 2024
ca80b6c
remove trimmed labels
Blaizzy Oct 5, 2024
349e4d1
remove warning
Blaizzy Oct 5, 2024
028e32c
pin reqs
Blaizzy Oct 5, 2024
cd5ecf5
add config dict condition
Blaizzy Oct 5, 2024
a116169
fix pixtral FT prompt
Blaizzy Oct 5, 2024
d791dff
formatting images
Blaizzy Oct 5, 2024
c16c048
remove unused
Blaizzy Oct 5, 2024
5a9c3db
update trainer init
Blaizzy Oct 5, 2024
97a4255
update lora
Blaizzy Oct 5, 2024
0159020
update md and formatting
Blaizzy Oct 5, 2024
0ec2412
bump version
Blaizzy Oct 5, 2024
608adfc
add tests for pixtral and qwen2_vl
Blaizzy Oct 6, 2024
15962ec
add tests for pixtral
Blaizzy Oct 6, 2024
d669fd1
Merge branch 'pc/tuner' of https://github.com/Blaizzy/mlx-vlm into pc…
Blaizzy Oct 6, 2024
b135eea
Merge branch 'pc/tuner' of https://github.com/Blaizzy/mlx-vlm into pc…
Blaizzy Oct 6, 2024
98e0024
Merge branch 'main' into pc/tuner
Blaizzy Oct 6, 2024
b7daf46
fix test
Blaizzy Oct 6, 2024
726faca
remove rope scaling
Blaizzy Oct 6, 2024
a53fa13
remove test args and update MD
Blaizzy Oct 6, 2024
31cdd67
format dataset defaults
Blaizzy Oct 9, 2024
e33c0d2
add dataset formatting info
Blaizzy Oct 9, 2024
1f3eabd
Fix issues with multiple image handling (#78)
hiima234 Oct 11, 2024
a9488bb
fix styling
Blaizzy Oct 11, 2024
87e598f
update model
Blaizzy Oct 11, 2024
dde7390
update default model
Blaizzy Oct 11, 2024
abbe83f
rewrite comments
Blaizzy Oct 11, 2024
1 change: 1 addition & 0 deletions .gitignore
@@ -20,3 +20,4 @@ var/
.installed.cfg
*.egg
.DS_Store
*.log
110 changes: 93 additions & 17 deletions README.md
@@ -1,43 +1,119 @@
# MLX-VLM

MLX-VLM a package for running Vision LLMs on your Mac using MLX.
MLX-VLM is a package for inference and fine-tuning of Vision Language Models (VLMs) on your Mac using MLX.

## Table of Contents
- [Installation](#installation)
- [Usage](#usage)
- [Command Line Interface (CLI)](#command-line-interface-cli)
- [Chat UI with Gradio](#chat-ui-with-gradio)
- [Python Script](#python-script)
- [Multi-Image Chat Support](#multi-image-chat-support)
- [Supported Models](#supported-models)
- [Usage Examples](#usage-examples)
- [Fine-tuning](#fine-tuning)

## Get started
## Installation

The easiest way to get started is to install the `mlx-vlm` package:

**With `pip`**:
The easiest way to get started is to install the `mlx-vlm` package using pip:

```sh
pip install mlx-vlm
```

## Inference
## Usage

### Command Line Interface (CLI)

Generate output from a model using the CLI:

**CLI**
```sh
python -m mlx_vlm.generate --model qnguyen3/nanoLLaVA --max-tokens 100 --temp 0.0
python -m mlx_vlm.generate --model mlx-community/Qwen2-VL-2B-Instruct-4bit --max-tokens 100 --temp 0.0 --image http://images.cocodataset.org/val2017/000000039769.jpg
```

**Chat UI with Gradio**
### Chat UI with Gradio

Launch a chat interface using Gradio:

```sh
python -m mlx_vlm.chat_ui --model qnguyen3/nanoLLaVA
python -m mlx_vlm.chat_ui --model mlx-community/Qwen2-VL-2B-Instruct-4bit
```

**Script**
### Python Script

Here's an example of how to use MLX-VLM in a Python script:

```python
import mlx.core as mx
from mlx_vlm import load, generate
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_config

# Load the model
model_path = "mlx-community/Qwen2-VL-2B-Instruct-4bit"
model, processor = load(model_path)
config = load_config(model_path)

# Prepare input
image = ["http://images.cocodataset.org/val2017/000000039769.jpg"]
prompt = "Describe this image."

# Apply chat template
formatted_prompt = apply_chat_template(
    processor, config, prompt, num_images=len(image)
)

# Generate output
output = generate(model, processor, image, formatted_prompt, verbose=False)
print(output)
```

## Multi-Image Chat Support

MLX-VLM supports analyzing multiple images simultaneously with select models. This feature enables more complex visual reasoning tasks and comprehensive analysis across multiple images in a single conversation.

### Supported Models

The following models support multi-image chat:

1. Idefics 2
2. LLaVA (Interleave)
3. Qwen2-VL
4. Phi3-Vision
5. Pixtral

### Usage Examples

#### Python Script

```python
from mlx_vlm import load, generate
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_config

model_path = "mlx-community/llava-1.5-7b-4bit"
model_path = "mlx-community/Qwen2-VL-2B-Instruct-4bit"
model, processor = load(model_path)
config = load_config(model_path)

prompt = processor.tokenizer.apply_chat_template(
    [{"role": "user", "content": f"<image>\nWhat are these?"}],
    tokenize=False,
    add_generation_prompt=True,
images = ["path/to/image1.jpg", "path/to/image2.jpg"]
prompt = "Compare these two images."

formatted_prompt = apply_chat_template(
    processor, config, prompt, num_images=len(images)
)

output = generate(model, processor, "http://images.cocodataset.org/val2017/000000039769.jpg", prompt, verbose=False)
output = generate(model, processor, images, formatted_prompt, verbose=False)
print(output)
```

#### Command Line

```sh
python -m mlx_vlm.generate --model mlx-community/Qwen2-VL-2B-Instruct-4bit --max-tokens 100 --prompt "Compare these images" --image path/to/image1.jpg path/to/image2.jpg
```

These examples demonstrate how to use multiple images with MLX-VLM for more complex visual reasoning tasks.

# Fine-tuning

MLX-VLM supports fine-tuning models with LoRA and QLoRA.

## LoRA & QLoRA

To learn more about LoRA, please refer to the [LoRA.md](./mlx_vlm/LoRA.md) file.
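
As a quick orientation, a typical LoRA run with the training script added in this PR looks roughly like the sketch below; the flags follow the argument list in LoRA.md, and the dataset path is a placeholder.

```sh
# A minimal sketch, assuming lora.py is invoked as documented in LoRA.md
# and that the dataset follows the images/messages format described there.
python lora.py \
  --dataset /path/to/your/dataset \
  --model_path mlx-community/Qwen2-VL-2B-Instruct-bf16 \
  --epochs 1 \
  --batch_size 1 \
  --output_path adapters.safetensors
```
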
77 changes: 77 additions & 0 deletions mlx_vlm/LORA.MD
@@ -0,0 +1,77 @@
# LoRA Training Script

## Overview

`lora.py` is a Python script for fine-tuning vision language models (VLMs) using Low-Rank Adaptation (LoRA) or its quantized variant, QLoRA. The script lets you train a model on your own dataset and adjust the main training parameters through command-line arguments.

## Requirements

- Python 3.7+
- Required Python packages: `mlx-vlm`, `numpy`, `transformers`, `datasets`, `PIL`

## Supported Models
- Qwen2
- LLaVA (except for LLaVA-Next)
- Pixtral
- Idefics 2
- Deepseek-VL
- Paligemma

## Coming Soon
- LLaVA-Next
- Phi3_vision

## Usage

To use the script, run it from the command line with the desired arguments:

```
python lora.py --dataset /path/to/your/dataset [other options]
```

## Dataset format

The dataset should be a Hugging Face dataset with an `images` column and a `messages` column.

```
{
    "images": ...,
    "messages": ...,
}
```

Support for other formats and column names will be added soon.
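
For concreteness, here is a minimal sketch of building a compatible dataset with the Hugging Face `datasets` library. The exact image representation the loader expects (file paths vs. decoded `PIL` images) and whether `--dataset` takes a Hub repo id or a local directory are assumptions here; check `lora.py` for the authoritative behavior.

```python
# Minimal sketch: a one-example dataset with "images" and "messages" columns.
# The chat-style message structure mirrors the conversation format used by
# apply_chat_template; adjust the contents (and image type) to your own data.
from datasets import Dataset

data = {
    "images": ["path/to/image1.jpg"],  # assumption: one image path per example
    "messages": [
        [
            {"role": "user", "content": "Describe this image."},
            {"role": "assistant", "content": "A cat sitting on a red sofa."},
        ]
    ],
}

dataset = Dataset.from_dict(data)
dataset.save_to_disk("my_vlm_dataset")  # then point --dataset at this directory
```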

## Arguments

The script accepts the following command-line arguments:

- `--model_path`: Path to the pre-trained model (default: "mlx-community/Qwen2-VL-2B-Instruct-bf16")
- `--dataset`: Path to your dataset (required)
- `--learning_rate`: Learning rate for the optimizer (default: 1e-4)
- `--batch_size`: Batch size for training (default: 1)
- `--epochs`: Number of epochs to train (default: 1)
- `--steps`: Number of steps per epoch (default: 0)
- `--print_every`: Print loss every n steps (default: 10)
- `--output_path`: Path to save the trained adapter (default: "adapters.safetensors")

## Example

Here's an example of how to run the script with custom parameters:

```
python lora.py --dataset /path/to/your/dataset --model_path /path/to/your/model --epochs 2 --batch_size 4 --learning_rate 5e-5
```

## Output

The script will print the training loss at regular intervals (defined by `--print_every`). After training, it will save the LoRA adapter to the specified output path.
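
To run inference with the saved adapter, the generation CLI in this PR adds an `--adapter-path` flag, so a command along these lines should work (model, prompt, and image are placeholders):

```sh
python -m mlx_vlm.generate \
  --model mlx-community/Qwen2-VL-2B-Instruct-bf16 \
  --adapter-path adapters.safetensors \
  --prompt "Describe this image." \
  --image path/to/image1.jpg
```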

## Note

If you want to use QLoRA, you need to pass a pre-quantized model to the script via the `--model_path` argument (e.g. `mlx-community/Qwen2-VL-2B-Instruct-4bit`).
Make sure you have the necessary permissions to read the dataset and write the output file. Also, ensure that your system has sufficient computational resources to handle the specified batch size and model.
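
For example, a QLoRA run differs from the earlier example only in pointing `--model_path` at a quantized checkpoint (a sketch, not an exact recipe):

```sh
python lora.py \
  --dataset /path/to/your/dataset \
  --model_path mlx-community/Qwen2-VL-2B-Instruct-4bit
```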

## Contributing

Feel free to submit issues or pull requests if you find any bugs or have suggestions for improvements.
3 changes: 2 additions & 1 deletion mlx_vlm/__init__.py
@@ -1,2 +1,3 @@
from .utils import convert, generate, load
from .prompt_utils import apply_chat_template, get_message_json
from .utils import convert, generate, load, prepare_inputs
from .version import __version__
27 changes: 19 additions & 8 deletions mlx_vlm/generate.py
@@ -1,13 +1,11 @@
import argparse
import codecs

import mlx.core as mx

from .prompt_utils import apply_chat_template
from .utils import generate, get_model_path, load, load_config, load_image_processor

DEFAULT_MODEL_PATH = "mlx-community/nanoLLaVA-1.5-8bit"
DEFAULT_IMAGE = "http://images.cocodataset.org/val2017/000000039769.jpg"
DEFAULT_IMAGE = ["http://images.cocodataset.org/val2017/000000039769.jpg"]
DEFAULT_PROMPT = "What are these?"
DEFAULT_MAX_TOKENS = 100
DEFAULT_TEMP = 0.5
@@ -25,9 +23,16 @@ def parse_arguments():
        default=DEFAULT_MODEL_PATH,
        help="The path to the local model directory or Hugging Face repo.",
    )
    parser.add_argument(
        "--adapter-path",
        type=str,
        default=None,
        help="The path to the adapter weights.",
    )
    parser.add_argument(
        "--image",
        type=str,
        nargs="+",
        default=DEFAULT_IMAGE,
        help="URL or path of the image to process.",
    )
@@ -50,22 +55,28 @@ def parse_arguments():
    return parser.parse_args()


def get_model_and_processors(model_path):
def get_model_and_processors(model_path, adapter_path):
    model_path = get_model_path(model_path)
    config = load_config(model_path)
    model, processor = load(model_path, {"trust_remote_code": True})
    model, processor = load(
        model_path, {"trust_remote_code": True}, adapter_path=adapter_path
    )
    image_processor = load_image_processor(model_path)
    return model, processor, image_processor, config


def main():
    args = parse_arguments()
    model, processor, image_processor, config = get_model_and_processors(args.model)
    if isinstance(args.image, str):
        args.image = [args.image]

    model, processor, image_processor, config = get_model_and_processors(
        args.model, args.adapter_path
    )

    prompt = codecs.decode(args.prompt, "unicode_escape")

    if model.config.model_type != "paligemma":
        prompt = apply_chat_template(processor, config, prompt)
        prompt = apply_chat_template(processor, config, prompt, num_images=len(args.image))

    output = generate(
        model,