Merging lora adapter with Llama 3.2 vision #702

Open · tymoma01 opened this issue Oct 3, 2024 · 5 comments

tymoma01 commented Oct 3, 2024

System Info

CUDA Version: 12.4
GPU: A6000

Information

  • The official example scripts
  • My own modified scripts

🐛 Describe the bug

After finetuning Llama 3.2 Vision using FSDP + PEFT LoRA with this command:

torchrun --nnodes 1 --nproc_per_node 4  recipes/quickstart/finetuning/finetuning.py --enable_fsdp --lr 1e-5  --num_epochs 3 --batch_size_training 2 --model_name meta-llama/Llama-3.2-11B-Vision-Instruct --dist_checkpoint_root_folder ./finetuned_model --dist_checkpoint_folder fine-tuned  --use_fast_kernels --dataset "custom_dataset" --custom_dataset.test_split "test" --custom_dataset.file "recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py"  --run_validation True --batching_strategy padding  --use_peft --peft_method lora

A folder at PATH/to/save/PEFT/model is created containing:

ls
README.md  adapter_config.json  adapter_model.safetensors

I want to merge the adapter with the base model for inference.
To do that, I used this code:

import os
import json
from peft import PeftModel
import torch
from transformers import MllamaForConditionalGeneration, BitsAndBytesConfig

BASE_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"
OUTPUT_PATH = "./model"
LORA_PATH = "./PATH/to/save/PEFT/model"
os.makedirs(OUTPUT_PATH, exist_ok=True)

base_model = MllamaForConditionalGeneration.from_pretrained(
    BASE_MODEL,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.float16,
    device_map="cuda"
)


model_to_merge = PeftModel.from_pretrained(base_model, LORA_PATH)
merged_model = model_to_merge.merge_and_unload()
merged_model.save_pretrained(OUTPUT_PATH)

# same behaviour when doing:
# base_model.load_adapter(lora_adapter_path)
# base_model.enable_adapters()

However, this code generates only 3 safetensors files in the output folder, whereas the base model originally had 5:

ls
config.json             model-00001-of-00003.safetensors  model-00003-of-00003.safetensors
generation_config.json  model-00002-of-00003.safetensors  model.safetensors.index.json
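
The shard count by itself is not necessarily a problem, since it depends on the dtype and shard size used when saving; what matters is whether every tensor survived the merge. A quick way to check is a rough sketch like the one below, assuming huggingface_hub access to the gated base repo and the ./model output path used above:

import json
from huggingface_hub import hf_hub_download

# Download the original index of the base model and compare its tensor names
# with the merged checkpoint's index.
orig_index = hf_hub_download(
    "meta-llama/Llama-3.2-11B-Vision-Instruct", "model.safetensors.index.json"
)
with open(orig_index) as f:
    orig_keys = set(json.load(f)["weight_map"])
with open("./model/model.safetensors.index.json") as f:
    merged_keys = set(json.load(f)["weight_map"])

print("missing from merged checkpoint:", sorted(orig_keys - merged_keys))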

Error logs

When trying to run inference on this merged model using:
python multi_modal_infer.py --image_path "<image_path>" --prompt_text "Describe this image" --temperature 0.1 --top_p 0.8 --model_name ./model --hf_token <hf_token>

I encounter the following error:
../aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [92,0,0], thread: [0,0,0] Assertion -sizes[i] <= index && index < sizes[i] && "index out of bounds" failed.

Full Traceback:

Traceback (most recent call last):
  File "/home/finetuning/llama-recipes/inference_finetuned_model.py", line 34, in <module>
    output = model.generate(**inputs, temperature=0.1, top_p=0.8, max_new_tokens=512)
  File "/home/florence2/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/florence2/lib/python3.10/site-packages/transformers/generation/utils.py", line 2048, in generate
    result = self._sample(
  File "/home/florence2/lib/python3.10/site-packages/transformers/generation/utils.py", line 3008, in _sample
    outputs = self(**model_inputs, return_dict=True)
  File "/home/florence2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/florence2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/florence2/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/florence2/lib/python3.10/site-packages/transformers/models/mllama/modeling_mllama.py", line 2171, in forward
    cross_attention_states = self.multi_modal_projector(cross_attention_states).reshape(
  File "/home/florence2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/florence2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/florence2/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/florence2/lib/python3.10/site-packages/bitsandbytes/nn/modules.py", line 1009, in forward
    out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
  File "/home/florence2/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py", line 556, in matmul
    return MatMul8bitLt.apply(A, B, out, bias, state)
  File "/home/florence2/lib/python3.10/site-packages/torch/autograd/function.py", line 574, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/florence2/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py", line 382, in forward
    CAt[:, state.idx.long()] = 0
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.

Expected behavior

Has anyone encountered this error while merging LoRA adapters for inference? Is this a tensor size mismatch issue or a problem with quantization (BitsAndBytes)? What might cause the reduced number of safetensors files, and how could I solve this?

@wukaixingxp (Contributor) commented:

Hi! Can you show me what is inside model.safetensors.index.json? You can compare it to the original one here to see if anything is missing. As long as model.safetensors.index.json contains all the layers, our script should be able to load it. Meanwhile, I think you should use bfloat16 for all the weights during merging. You can use CPU RAM to avoid GPU OOM. Maybe something like:

base_model = MllamaForConditionalGeneration.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.bfloat16,
    device_map="cpu"
)
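
Filled out into a complete merge-and-save pass, that might look roughly like the sketch below. It reuses the paths from the report above and is not a verified script; note that merging into a bitsandbytes 8-bit base model can be lossy or unsupported, which may be related to the behavior reported.

import torch
from peft import PeftModel
from transformers import MllamaForConditionalGeneration

BASE_MODEL = "meta-llama/Llama-3.2-11B-Vision-Instruct"
LORA_PATH = "./PATH/to/save/PEFT/model"  # adapter folder produced by finetuning
OUTPUT_PATH = "./model"                  # destination for the merged weights

# Load the base model unquantized, in bf16, on CPU so the merge happens in
# full precision and cannot hit GPU OOM.
base_model = MllamaForConditionalGeneration.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.bfloat16,
    device_map="cpu",
)

merged_model = PeftModel.from_pretrained(base_model, LORA_PATH).merge_and_unload()
merged_model.save_pretrained(OUTPUT_PATH)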

@wukaixingxp (Contributor) commented:

Lastly, RuntimeError: CUDA error: device-side assert triggered hides the actual error message. You can use CPU inference to get the real error message.
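
For example, a bare-bones CPU run of the merged checkpoint might look like the sketch below. The image path and prompt are placeholders, and the processor is loaded from the base repo since saving the merged model does not write processor files.

import torch
from PIL import Image
from transformers import MllamaForConditionalGeneration, AutoProcessor

model = MllamaForConditionalGeneration.from_pretrained(
    "./model", torch_dtype=torch.bfloat16, device_map="cpu"
)
processor = AutoProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision-Instruct")

image = Image.open("example.jpg")  # placeholder image
prompt = "<|image|><|begin_of_text|>Describe this image"
inputs = processor(image, prompt, return_tensors="pt")

# On CPU a failure surfaces as an ordinary Python traceback instead of an
# asynchronous device-side assert, so the root cause is readable.
output = model.generate(**inputs, max_new_tokens=64)
print(processor.decode(output[0]))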

@wukaixingxp wukaixingxp self-assigned this Oct 4, 2024

marscod commented Oct 6, 2024

I was able to merge it without any issue. Here is my code that might help:

import torch
from peft import PeftModel
from transformers import MllamaForConditionalGeneration, AutoProcessor

# model_id, adapters_name, image and question are defined elsewhere
base_model = MllamaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(model_id)

# Generate with the plain base model first, for comparison
prompt = f"<|image|><|begin_of_text|>question:{question}"
inputs = processor(image, prompt, return_tensors="pt").to(base_model.device)
output = base_model.generate(**inputs, max_new_tokens=100)
print(processor.decode(output[0]))

# Load the adapter and merge it into the base weights
lora_model = PeftModel.from_pretrained(base_model, adapters_name)
model = lora_model.merge_and_unload()
processor = AutoProcessor.from_pretrained(model_id)
processor.bos_token_id = 1

# Generate again with the merged model
inputs = processor(image, prompt, return_tensors="pt").to(model.device)
output2 = model.generate(**inputs, max_new_tokens=100)
print(processor.decode(output2[0]))

However, the inference result is identical to the base model's. @wukaixingxp, any suggestions?

I also checked that there aren't any LoRA parameters in lora_model.named_parameters():

lora_params = {n: p for n, p in lora_model.named_parameters() if "lora" in n}
for n, p in lora_params.items():
    print(n, p.sum())
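
A more direct check is to compare one of the targeted base weights before and after the merge; if the adapter was trained and applied, the delta should be non-zero. A rough sketch, reusing base_model and adapters_name from the snippet above (with base_model freshly loaded) and assuming q_proj as a LoRA target; the real target modules are listed in adapter_config.json:

from peft import PeftModel

# Assumed name of a LoRA-targeted weight; adjust to match adapter_config.json.
name = "language_model.model.layers.0.self_attn.q_proj.weight"

# Capture the weight before wrapping, since PEFT renames targeted modules
# (q_proj.weight becomes q_proj.base_layer.weight inside the PeftModel).
before = dict(base_model.named_parameters())[name].detach().clone()

lora_model = PeftModel.from_pretrained(base_model, adapters_name)
merged = lora_model.merge_and_unload()
after = dict(merged.named_parameters())[name].detach()

# A zero delta here means the adapter did not change this weight at all.
print("max |delta|:", (after.float() - before.float()).abs().max().item())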


tymoma01 commented Oct 9, 2024

(Replying to @wukaixingxp's suggestion above to merge in bfloat16 with device_map="cpu".)

Thank you for getting back to me.
I tried loading with device_map="cpu" instead of "cuda" and it worked well.

Do you happen to have a further explanation of why CPU must be used instead of GPU for this usage?
Is it just to avoid OOM issues (which seem to be silent on GPU)? Or is there something deeper happening?

@wukaixingxp (Contributor) commented:

@tymoma01 I use the CPU because it prints out the error trace easily (CUDA errors are sometimes hard to track back to the exact line) and it also avoids OOM. It's just a personal preference for these one-time operations, to make life easier.
