Commit

Merge branch 'pc/tuner' of https://github.com/Blaizzy/mlx-vlm into pc/tuner
Blaizzy committed Jun 23, 2024
2 parents 7798682 + 2391df4 commit f5613eb
Showing 5 changed files with 285 additions and 0 deletions.
8 changes: 8 additions & 0 deletions mlx_vlm/trainer/__init__.py
@@ -0,0 +1,8 @@
from .lora import LoRaLayer, replace_lora_with_linear
from .utils import (
collate_fn,
count_parameters,
find_all_linear_names,
get_peft_model,
print_trainable_parameters,
)
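
These re-exports form the trainer package's public surface; a typical (hypothetical) call-site import would look like:

from mlx_vlm.trainer import LoRaLayer, get_peft_model, collate_fn
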
68 changes: 68 additions & 0 deletions mlx_vlm/trainer/lora.py
@@ -0,0 +1,68 @@
import math
from typing import Union

import mlx.core as mx
import mlx.nn as nn


class LoRaLayer(nn.Module):
def __init__(
self,
linear: Union[nn.Linear, nn.QuantizedLinear],
rank: int,
alpha: float = 0.1,
dropout: float = 0.0,
):
super().__init__()

self.original_layer = linear

self.dropout = nn.Dropout(p=dropout)

output_dims, input_dims = linear.weight.shape

std_dev = 1 / math.sqrt(rank)

self.A = mx.random.uniform(
low=-std_dev,
high=std_dev,
shape=(input_dims, rank),
)
self.B = mx.zeros((rank, output_dims))
self.alpha = alpha

def __call__(self, x):
y = self.original_layer(x)
lora_update = (self.dropout(x) @ self.A) @ self.B
return y + (self.alpha * lora_update).astype(x.dtype)


def replace_lora_with_linear(model):
for i, layer in enumerate(model.layers):
if isinstance(layer, LoRaLayer):
            # Compute the final merged weight: A is (input_dims, rank) and B is
            # (rank, output_dims), so the update is transposed to match the
            # (output_dims, input_dims) layout of the Linear weight
            lora_update = layer.alpha * (layer.A @ layer.B).T
            updated_weight = layer.original_layer.weight + lora_update
            # MLX Linear modules only carry a "bias" parameter when created with bias=True
            use_bias = "bias" in layer.original_layer

            updated_bias = layer.original_layer.bias if use_bias else None

# Create a new Linear layer with the updated parameters
            new_linear_layer = nn.Linear(
                updated_weight.shape[1], updated_weight.shape[0], bias=use_bias
            )

new_linear_layer.weight = updated_weight

if use_bias:
new_linear_layer.bias = updated_bias

            if isinstance(layer.original_layer, nn.QuantizedLinear):
                # Re-quantize using the original layer's quantization settings
                new_linear_layer = nn.QuantizedLinear.from_linear(
                    new_linear_layer,
                    layer.original_layer.group_size,
                    layer.original_layer.bits,
                )

# Replace the LoRaLayer with the new Linear layer in the model
model.layers[i] = new_linear_layer
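
A minimal usage sketch of LoRaLayer follows; the 512-wide layer and input shape are illustrative assumptions. Because B is initialized to zeros, the wrapped layer reproduces the original output before any training.

# Usage sketch; the layer width and batch shape are assumptions, not from the commit.
import mlx.core as mx
import mlx.nn as nn

from mlx_vlm.trainer import LoRaLayer

linear = nn.Linear(512, 512)
lora = LoRaLayer(linear, rank=10, alpha=0.1, dropout=0.0)

x = mx.random.normal(shape=(1, 8, 512))
# B starts as zeros, so the LoRA-wrapped layer initially matches the original layer.
assert mx.allclose(lora(x), linear(x))
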
70 changes: 70 additions & 0 deletions mlx_vlm/trainer/trainer.py
@@ -0,0 +1,70 @@
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Union

import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_flatten


def grad_checkpoint(layer):
"""
Update all instances of type(layer) to use gradient checkpointing.
"""
fn = type(layer).__call__

def checkpointed_fn(model, *args, **kwargs):
def inner_fn(params, *args, **kwargs):
model.update(params)
return fn(model, *args, **kwargs)

return mx.checkpoint(inner_fn)(model.trainable_parameters(), *args, **kwargs)

    type(layer).__call__ = checkpointed_fn


@dataclass
class TrainingArgs:
batch_size: int = field(default=4, metadata={"help": "Minibatch size."})
iters: int = field(default=100, metadata={"help": "Iterations to train for."})
val_batches: int = field(
default=25,
metadata={
"help": "Number of validation batches, -1 uses the entire validation set."
},
)
steps_per_report: int = field(
default=10,
metadata={"help": "Number of training steps between loss reporting."},
)
steps_per_eval: int = field(
default=200, metadata={"help": "Number of training steps between validations."}
)
steps_per_save: int = field(
        default=100, metadata={"help": "Save the model every N training steps."}
)
max_seq_length: int = field(
default=2048, metadata={"help": "Maximum sequence length."}
)
adapter_file: str = field(
default="adapters.safetensors",
metadata={"help": "Save/load path for the trained adapter weights."},
)
grad_checkpoint: bool = field(
default=False,
metadata={"help": "Use gradient checkpointing to reduce memory use."},
)


def default_loss(model, inputs, targets, lengths):
logits = model(inputs)
logits = logits.astype(mx.float32)

length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]

ce = nn.losses.cross_entropy(logits, targets) * length_mask
ntoks = length_mask.sum()
ce = ce.sum() / ntoks

return ce, ntoks
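
To illustrate the length masking in default_loss, here is a toy sketch with made-up shapes; the random logits stand in for a real model forward pass.

# Toy shapes; the random "logits" replace an actual model call.
import mlx.core as mx
import mlx.nn as nn

batch, seq_len, vocab = 2, 5, 11
logits = mx.random.normal(shape=(batch, seq_len, vocab))
targets = mx.random.randint(0, vocab, (batch, seq_len))
lengths = mx.array([5, 3])  # second sequence has 2 padding positions

# Same mask construction as default_loss: True for real tokens, False for padding.
length_mask = mx.arange(seq_len)[None, :] < lengths[:, None]

ce = nn.losses.cross_entropy(logits, targets) * length_mask
ntoks = length_mask.sum()   # 8 real tokens out of 10 positions
loss = ce.sum() / ntoks     # mean loss over real tokens only
print(loss.item(), ntoks.item())
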
138 changes: 138 additions & 0 deletions mlx_vlm/trainer/utils.py
@@ -0,0 +1,138 @@
import mlx.core as mx
import mlx.nn as nn
import numpy as np

from .lora import LoRaLayer


def get_module_by_name(model, name):
parts = name.split(".")
module = model
for part in parts:
if part.isdigit():
module = module[int(part)]
else:
module = getattr(module, part)
return module


def set_module_by_name(model, name, new_module):
parts = name.split(".")
module = model
for part in parts[:-1]:
if part.isdigit():
module = module[int(part)]
else:
module = getattr(module, part)
if parts[-1].isdigit():
module[int(parts[-1])] = new_module
else:
setattr(module, parts[-1], new_module)


def get_peft_model(model, linear_layers, freeze=True, verbose=True):
source_model_trainable = count_parameters(
model.language_model.trainable_parameters()
)

if freeze:
freeze_model(model)

for name, module in model.named_modules():
if isinstance(module, nn.Linear) and name.split(".")[-1] in linear_layers:
            lora_layer = LoRaLayer(module, rank=10, alpha=0.1, dropout=0.1)
set_module_by_name(model, name, lora_layer)

lora_model_trainable = count_parameters(model.language_model.trainable_parameters())
if verbose:
print_trainable_parameters(source_model_trainable, lora_model_trainable)

return model


def freeze_model(model):
for name, module in model.named_modules():
if name in [
"language_model",
"vision_model",
"vision_tower",
"aligner",
"connector",
"multi_modal_projector",
"mm_projector",
]:
model[f"{name}"].freeze()


def find_all_linear_names(model):
cls = nn.Linear
lora_module_names = set()
multimodal_keywords = [
"mm_projector",
"vision_tower",
"vision_resampler",
"aligner",
]
for name, module in model.named_modules():
if any(mm_keyword in name for mm_keyword in multimodal_keywords):
continue
if isinstance(module, cls):
names = name.split(".")
lora_module_names.add(names[0] if len(names) == 1 else names[-1])

if "lm_head" in lora_module_names: # needed for 16-bit
lora_module_names.remove("lm_head")
return list(lora_module_names)


def collate_fn(processor, examples):
texts = ["answer " + example["question"] for example in examples]
labels = [example["multiple_choice_answer"] for example in examples]
images = [example["image"].convert("RGB") for example in examples]
tokens = processor(
text=texts,
images=images,
suffix=labels,
return_tensors="np",
padding="longest",
tokenize_newline_separately=False,
)

tokens = tokens.to(mx.float16)
return tokens


def flatten_dict(dd, separator="_", prefix=""):
return (
{
prefix + separator + k if prefix else k: v
for kk, vv in dd.items()
for k, v in flatten_dict(vv, separator, kk).items()
}
if isinstance(dd, dict)
else {prefix: dd}
)


def count_parameters(trainable_params_dict):
total_params = 0
for k, v in flatten_dict(trainable_params_dict).items():
if hasattr(v, "shape"):
total_params += np.prod(v.shape)

if isinstance(v, list):
for v_ in v:
v_ = flatten_dict(v_)
if isinstance(v_, dict):
total_params += sum(
np.prod(p.shape) for p in v_.values() if hasattr(p, "shape")
)

return total_params


def print_trainable_parameters(source_model_trainable, lora_model_trainable):
lora_trainable_percent = (lora_model_trainable / source_model_trainable) * 100
print(
f"#trainable params: {lora_model_trainable} || all params: {source_model_trainable} || trainable%: {lora_trainable_percent}"
)
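
A minimal sketch (toy model; the module names and sizes are illustrative) of how get_module_by_name and set_module_by_name walk dotted paths to swap a nested nn.Linear for a LoRaLayer, which is what get_peft_model does for every matching projection.

# Toy two-block model; "q_proj" and the 64-wide layers are assumptions.
import mlx.nn as nn

from mlx_vlm.trainer.lora import LoRaLayer
from mlx_vlm.trainer.utils import get_module_by_name, set_module_by_name

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(64, 64)

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [Block(), Block()]

model = ToyModel()
target = get_module_by_name(model, "layers.0.q_proj")   # digit parts index into lists
set_module_by_name(model, "layers.0.q_proj", LoRaLayer(target, rank=10))
print(type(model.layers[0].q_proj).__name__)            # -> LoRaLayer
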
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,4 +1,5 @@
mlx>=0.14
numpy
transformers>=4.39.3
scipy==1.13.1