Merge branch 'pc/tuner' of https://github.com/Blaizzy/mlx-vlm into pc/tuner
Showing 5 changed files with 285 additions and 0 deletions.
@@ -0,0 +1,8 @@
from .lora import LoRaLayer, replace_lora_with_linear
from .utils import (
    collate_fn,
    count_parameters,
    find_all_linear_names,
    get_peft_model,
    print_trainable_parameters,
)
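
For reference, downstream code would pull these re-exports straight from the new package. The diff does not show the package's import path, so the module name below is a guess and should be adjusted to the real location.

# Hypothetical import path -- adjust to wherever this __init__.py actually lives.
from mlx_vlm.trainer import LoRaLayer, find_all_linear_names, get_peft_model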
@@ -0,0 +1,68 @@
import math
from typing import Union

import mlx.core as mx
import mlx.nn as nn


class LoRaLayer(nn.Module):
    def __init__(
        self,
        linear: Union[nn.Linear, nn.QuantizedLinear],
        rank: int,
        alpha: float = 0.1,
        dropout: float = 0.0,
    ):
        super().__init__()

        self.original_layer = linear

        self.dropout = nn.Dropout(p=dropout)

        output_dims, input_dims = linear.weight.shape

        std_dev = 1 / math.sqrt(rank)

        self.A = mx.random.uniform(
            low=-std_dev,
            high=std_dev,
            shape=(input_dims, rank),
        )
        self.B = mx.zeros((rank, output_dims))
        self.alpha = alpha

    def __call__(self, x):
        y = self.original_layer(x)
        lora_update = (self.dropout(x) @ self.A) @ self.B
        return y + (self.alpha * lora_update).astype(x.dtype)


def replace_lora_with_linear(model):
    for i, layer in enumerate(model.layers):
        if isinstance(layer, LoRaLayer):
            # Compute the final merged weight. (A @ B) has shape
            # (input_dims, output_dims), so transpose it to match the
            # (output_dims, input_dims) layout of the Linear weight.
            lora_update = layer.alpha * (layer.A @ layer.B).T
            updated_weight = layer.original_layer.weight + lora_update
            use_bias = "bias" in layer.original_layer

            updated_bias = layer.original_layer.bias if use_bias else None

            # Create a new Linear layer with the updated parameters.
            # MLX arrays expose `shape`, not a `size(dim)` method.
            new_linear_layer = nn.Linear(
                updated_weight.shape[1], updated_weight.shape[0], bias=use_bias
            )

            new_linear_layer.weight = updated_weight

            if use_bias:
                new_linear_layer.bias = updated_bias

            if isinstance(layer.original_layer, nn.QuantizedLinear):
                # Re-quantize using the original layer's quantization settings.
                new_linear_layer = nn.QuantizedLinear.from_linear(
                    new_linear_layer,
                    layer.original_layer.group_size,
                    layer.original_layer.bits,
                )

            # Replace the LoRaLayer with the new Linear layer in the model
            model.layers[i] = new_linear_layer
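
As a sanity check on the intended usage, here is a small runnable sketch that wraps the Linear layers of a toy model, runs a forward pass, and then merges the adapters back. TinyModel is an illustrative stand-in, not part of this commit; the real target is the language model of an mlx-vlm checkpoint.

import mlx.core as mx
import mlx.nn as nn

# LoRaLayer and replace_lora_with_linear are the definitions shown above.


class TinyModel(nn.Module):
    # Minimal container with a `layers` list, mirroring what
    # replace_lora_with_linear expects.
    def __init__(self):
        super().__init__()
        self.layers = [nn.Linear(16, 16), nn.Linear(16, 16)]

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


model = TinyModel()
model.freeze()  # base weights stay frozen; only the LoRA factors will train

# Wrap each Linear with a LoRA adapter (rank 8, scale 0.1, no dropout).
for i, layer in enumerate(model.layers):
    model.layers[i] = LoRaLayer(layer, rank=8, alpha=0.1, dropout=0.0)

y = model(mx.random.normal((2, 16)))  # forward pass = base output + LoRA update

# After training, fold the adapters back into plain Linear layers.
replace_lora_with_linear(model)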
@@ -0,0 +1,70 @@
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Union

import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_flatten


def grad_checkpoint(layer):
    """
    Update all instances of type(layer) to use gradient checkpointing.
    """
    fn = type(layer).__call__

    def checkpointed_fn(model, *args, **kwargs):
        def inner_fn(params, *args, **kwargs):
            model.update(params)
            return fn(model, *args, **kwargs)

        return mx.checkpoint(inner_fn)(model.trainable_parameters(), *args, **kwargs)

    type(layer).__call__ = checkpointed_fn


@dataclass
class TrainingArgs:
    batch_size: int = field(default=4, metadata={"help": "Minibatch size."})
    iters: int = field(default=100, metadata={"help": "Iterations to train for."})
    val_batches: int = field(
        default=25,
        metadata={
            "help": "Number of validation batches, -1 uses the entire validation set."
        },
    )
    steps_per_report: int = field(
        default=10,
        metadata={"help": "Number of training steps between loss reporting."},
    )
    steps_per_eval: int = field(
        default=200, metadata={"help": "Number of training steps between validations."}
    )
    steps_per_save: int = field(
        default=100, metadata={"help": "Save the model every N steps."}
    )
    max_seq_length: int = field(
        default=2048, metadata={"help": "Maximum sequence length."}
    )
    adapter_file: str = field(
        default="adapters.safetensors",
        metadata={"help": "Save/load path for the trained adapter weights."},
    )
    grad_checkpoint: bool = field(
        default=False,
        metadata={"help": "Use gradient checkpointing to reduce memory use."},
    )


def default_loss(model, inputs, targets, lengths):
    logits = model(inputs)
    logits = logits.astype(mx.float32)

    length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]

    ce = nn.losses.cross_entropy(logits, targets) * length_mask
    ntoks = length_mask.sum()
    ce = ce.sum() / ntoks

    return ce, ntoks
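
To show how these pieces could fit together, here is a minimal training-step sketch built around default_loss. The model, the batch arrays, and the Adam optimizer are assumptions on my part and are not defined by this commit.

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim


def train_step(model, optimizer, inputs, targets, lengths):
    # Differentiate default_loss w.r.t. the model's trainable parameters only
    # (the LoRA factors when the base model is frozen).
    loss_and_grad = nn.value_and_grad(model, default_loss)
    (loss, ntoks), grads = loss_and_grad(model, inputs, targets, lengths)
    optimizer.update(model, grads)
    mx.eval(model.parameters(), optimizer.state)
    return loss, ntoks


# Assumed wiring (not part of this diff): `model` is a causal LM whose __call__
# returns logits of shape (batch, seq_len, vocab_size); inputs/targets/lengths
# come from a data loader.
# args = TrainingArgs(batch_size=2, iters=100)
# optimizer = optim.Adam(learning_rate=1e-5)
# for step, (inputs, targets, lengths) in zip(range(args.iters), loader):
#     loss, ntoks = train_step(model, optimizer, inputs, targets, lengths)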
@@ -0,0 +1,138 @@
import mlx.core as mx
import mlx.nn as nn
import numpy as np

from .lora import LoRaLayer


def get_module_by_name(model, name):
    parts = name.split(".")
    module = model
    for part in parts:
        if part.isdigit():
            module = module[int(part)]
        else:
            module = getattr(module, part)
    return module


def set_module_by_name(model, name, new_module):
    parts = name.split(".")
    module = model
    for part in parts[:-1]:
        if part.isdigit():
            module = module[int(part)]
        else:
            module = getattr(module, part)
    if parts[-1].isdigit():
        module[int(parts[-1])] = new_module
    else:
        setattr(module, parts[-1], new_module)


def get_peft_model(model, linear_layers, freeze=True, verbose=True):
    source_model_trainable = count_parameters(
        model.language_model.trainable_parameters()
    )

    if freeze:
        freeze_model(model)

    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and name.split(".")[-1] in linear_layers:
            lora_layer = LoRaLayer(module, 10, 0.1, 0.1)
            set_module_by_name(model, name, lora_layer)

    lora_model_trainable = count_parameters(
        model.language_model.trainable_parameters()
    )
    if verbose:
        print_trainable_parameters(source_model_trainable, lora_model_trainable)

    return model


def freeze_model(model):
    for name, module in model.named_modules():
        if name in [
            "language_model",
            "vision_model",
            "vision_tower",
            "aligner",
            "connector",
            "multi_modal_projector",
            "mm_projector",
        ]:
            model[f"{name}"].freeze()


def find_all_linear_names(model):
    cls = nn.Linear
    lora_module_names = set()
    multimodal_keywords = [
        "mm_projector",
        "vision_tower",
        "vision_resampler",
        "aligner",
    ]
    for name, module in model.named_modules():
        if any(mm_keyword in name for mm_keyword in multimodal_keywords):
            continue
        if isinstance(module, cls):
            names = name.split(".")
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if "lm_head" in lora_module_names:  # needed for 16-bit
        lora_module_names.remove("lm_head")
    return list(lora_module_names)


def collate_fn(processor, examples):
    texts = ["answer " + example["question"] for example in examples]
    labels = [example["multiple_choice_answer"] for example in examples]
    images = [example["image"].convert("RGB") for example in examples]
    tokens = processor(
        text=texts,
        images=images,
        suffix=labels,
        return_tensors="np",
        padding="longest",
        tokenize_newline_separately=False,
    )

    tokens = tokens.to(mx.float16)
    return tokens


def flatten_dict(dd, separator="_", prefix=""):
    return (
        {
            prefix + separator + k if prefix else k: v
            for kk, vv in dd.items()
            for k, v in flatten_dict(vv, separator, kk).items()
        }
        if isinstance(dd, dict)
        else {prefix: dd}
    )


def count_parameters(trainable_params_dict):
    total_params = 0
    for k, v in flatten_dict(trainable_params_dict).items():
        if hasattr(v, "shape"):
            total_params += np.prod(v.shape)

        if isinstance(v, list):
            for v_ in v:
                v_ = flatten_dict(v_)
                if isinstance(v_, dict):
                    total_params += sum(
                        np.prod(p.shape) for p in v_.values() if hasattr(p, "shape")
                    )

    return total_params


def print_trainable_parameters(source_model_trainable, lora_model_trainable):
    lora_trainable_percent = (lora_model_trainable / source_model_trainable) * 100
    print(
        f"#trainable params: {lora_model_trainable} || all params: {source_model_trainable} || trainable%: {lora_trainable_percent}"
    )
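
For illustration, a self-contained sketch of the dotted-path helpers and parameter counting on a toy module. Block and Toy are stand-ins invented for the example, not part of this commit; on a real mlx-vlm model, get_peft_model(model, find_all_linear_names(model)) performs the same Linear-to-LoRA swap across the language model and reports the trainable-parameter ratio.

import mlx.nn as nn

# get_module_by_name, set_module_by_name, and count_parameters are the
# helpers defined above.


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(8, 8)
        self.v_proj = nn.Linear(8, 8)


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [Block(), Block()]


toy = Toy()

# Dotted paths traverse attributes and list indices alike.
q0 = get_module_by_name(toy, "layers.0.q_proj")

# Swap a single submodule in place, the same mechanism get_peft_model uses
# when it replaces a matching nn.Linear with a LoRaLayer.
set_module_by_name(toy, "layers.0.q_proj", nn.Linear(8, 8, bias=False))

# count_parameters flattens the nested parameter tree before summing shapes.
print(count_parameters(toy.trainable_parameters()))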
@@ -1,4 +1,5 @@
mlx>=0.14
numpy
transformers>=4.39.3
scipy==1.13.1