FSDP finetuned model inference question #634

Open

mathmax12 opened this issue Aug 15, 2024 · 19 comments
@mathmax12

mathmax12 commented Aug 15, 2024

🚀 The feature, motivation and pitch

Fine-tuning with FSDP only (no PEFT) works well, and sharded checkpoints are saved as __0_*.distcp, .metadata, and train_params.yaml. I can see the loss drop reasonably. Here is the training command:
torchrun --nnodes 1 --nproc_per_node 8 ./recipes/quickstart/finetuning/finetuning.py --model_name /tmp/llama-recipes/Meta-Llama-3.1-8B --output_dir ./fsdp_fine_tune_results/output_model_1_8 --dist_checkpoint_root_folder ./fsdp_fine_tune_results/fsdp_model_finetuned_1_8 --enable_fsdp --num_epochs 1 --batch_size_training 2 --dataset alpaca_dataset

Then I tried to run inference with the FSDP checkpoints by:

  1. convert FSDP checkpoint to hf:
    python ./src/llama_recipes/inference/checkpoint_converter_fsdp_hf.py --fsdp_checkpoint_path ./fsdp_fine_tune_results/fsdp_model_finetuned_1_8/fine-tuned-/tmp/llama-recipes/Meta-Llama-3.1-8B --consolidated_model_path ./fsdp_fine_tune_results/fsdp_model_finetuned_1_8_hf --HF_model_path_or_name /tmp/llama-recipes/Meta-Llama-3.1-8B
  2. inference with the inference.py:
    python ./recipes/quickstart/inference/local_inference/inference.py --model_name ./fsdp_fine_tune_results/fsdp_model_finetuned_1_8_hf --prompt_file prompt_for_test.txt

But I got empty output (the model only echoes the prompt back):
"llama-recipes# python ./recipes/quickstart/inference/local_inference/inference.py --model_name ./fsdp_fine_tune_results/fsdp_model_finetuned_1_8_hf --prompt_file prompt_for_test.txt

/root/llama-recipes/src/llama_recipes/model_checkpointing/checkpoint_handler.py:17: DeprecationWarning: torch.distributed._shard.checkpoint will be deprecated, use torch.distributed.checkpoint instead
from torch.distributed._shard.checkpoint import (
use_fast_kernelsFalse
Using the SDPA attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends.
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:14<00:00, 2.13s/it]
User prompt deemed safe.
User prompt:
I have tomatoes, basil and cheese at home. What can I cook for dinner?\n
Setting pad_token_id to eos_token_id:128001 for open-end generation.
the inference time is 286928.2311500283 ms
User input and model output deemed safe.
Model output:
I have tomatoes, basil and cheese at home. What can I cook for dinner?\n`
"

If I use the original Meta-Llama-3.1-8B model for inference, the output is fine. Likewise, when using the checkpoint from fine-tuning with FSDP + PEFT LoRA, the inference looks fine.

Could someone let me know if I missed anything? Or is there a way/tool to check whether the FSDP-to-HF checkpoint conversion went well?

Thanks!
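One illustrative way to sanity-check the conversion (not an official llama-recipes tool; it simply loads both models with transformers and compares tensors, reusing the paths from the commands above; loading two 8B models in bf16 on CPU needs roughly 32 GB of RAM):

import torch
from transformers import AutoModelForCausalLM

# Load the original base model and the consolidated (converted) checkpoint.
base = AutoModelForCausalLM.from_pretrained(
    "/tmp/llama-recipes/Meta-Llama-3.1-8B", torch_dtype=torch.bfloat16
)
tuned = AutoModelForCausalLM.from_pretrained(
    "./fsdp_fine_tune_results/fsdp_model_finetuned_1_8_hf", torch_dtype=torch.bfloat16
)

base_sd = base.state_dict()
for name, param in tuned.state_dict().items():
    # NaNs or all-zero tensors usually mean the sharded weights were not loaded.
    if torch.isnan(param).any() or param.abs().sum() == 0:
        print(f"suspicious tensor: {name}")
    # Mean absolute drift from the base weights; zero drift everywhere would
    # suggest the conversion silently fell back to the base weights.
    drift = (param.float() - base_sd[name].float()).abs().mean().item()
    print(f"{name}: mean |delta| = {drift:.3e}")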

Alternatives

No response

Additional context

No response

@mathmax12
Author

I also noticed that, because of the all-reduce before the forward pass, it's not recommended to use FSDP for inference. Does this mean FSDP inference isn't supported by PyTorch so far, or is it just not recommended because the all-reduce makes FSDP inference inefficient? If it's the latter, is there an example I can follow to use FSDP checkpoints for inference?
Thanks.
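For context, the converter script essentially loads the sharded .distcp files into an unwrapped model on a single process. A rough sketch of that pattern follows (illustrative only: the torch.distributed.checkpoint API names vary across PyTorch versions, load_state_dict is deprecated in favor of load in newer releases, and the "model" state-dict key is assumed to match what the training run saved):

import torch.distributed.checkpoint as dist_cp
from transformers import AutoModelForCausalLM

# Build the model skeleton, then fill it from the sharded DCP checkpoint.
model = AutoModelForCausalLM.from_pretrained("/tmp/llama-recipes/Meta-Llama-3.1-8B")
state_dict = {"model": model.state_dict()}  # key assumed to match the saved checkpoint

dist_cp.load_state_dict(  # older API; newer PyTorch prefers dist_cp.load()
    state_dict=state_dict,
    storage_reader=dist_cp.FileSystemReader(
        "./fsdp_fine_tune_results/fsdp_model_finetuned_1_8/fine-tuned-/tmp/llama-recipes/Meta-Llama-3.1-8B"
    ),
    no_dist=True,  # read the shards without initializing a process group
)
model.load_state_dict(state_dict["model"])
model.save_pretrained("./fsdp_fine_tune_results/fsdp_model_finetuned_1_8_hf")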

@mreso
Contributor

mreso commented Aug 19, 2024

Hi @mathmax12, thanks for reporting this. I was able to reproduce it and will have a look.

@mathmax12
Author

@mreso Thank you for looking into this issue. In the meantime, is there a workaround for using finetuned FSDP checkpoints for inference? Thanks.

@mathmax12
Author

@mreso Is there an update on this? Thanks

@mreso
Contributor

mreso commented Aug 28, 2024

Sorry @mathmax12, I haven't had the chance to look deeper into this yet.

@mathmax12
Author

Could we prioritize this? If the checkpoints don't work, how can we use the fine-tuned FSDP checkpoints for inference?

@bigtree2020

Facing a similar issue. Is there a solution for this?

@mathmax12
Author

Hey @mreso, I found this only happens for Llama 3 and 3.1 models; inference with checkpoints from FSDP-finetuned Llama 2 is OK.
The architectures of Llama 3 and Llama 2 are pretty similar. Do you have any idea what may cause this issue?

@HamidShojanazeri
Contributor

cc @mreso @wukaixingxp

@wukaixingxp
Contributor

wukaixingxp commented Oct 8, 2024

@mreso @HamidShojanazeri I noticed that the alpaca dataset has not been updated: the label token is still -1 instead of -100, as shown here. Edit: the actual cause is that we did not add the Llama 3 special tokens.
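(For readers following along: -100 matters because Hugging Face causal-LM models compute cross-entropy with ignore_index=-100, so prompt tokens must be masked with that value; -1 would be treated as a real target. A minimal illustration with made-up token ids, not the actual alpaca_dataset.py code:)

IGNORE_INDEX = -100  # positions with this label are skipped by the loss

prompt_ids = [128000, 9906, 11, 1917]     # hypothetical prompt tokens (not trained on)
answer_ids = [40, 649, 1520, 13, 128001]  # hypothetical answer tokens (trained on)

input_ids = prompt_ids + answer_ids
labels = [IGNORE_INDEX] * len(prompt_ids) + answer_ids
# Using -1 here instead of -100 would make the loss treat -1 as a target class,
# which corrupts training.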

@mathmax12
Author

Is there a fix for this?

@aishwaryap

@mreso any updates on this? I am also facing a similar issue.
In my case I started with the meta-llama/Llama-3.1-8B-Instruct model, which gives reasonable output when I run HF inference on it directly.
Similar to @mathmax12, I notice that when I convert the FSDP checkpoint to HF format using checkpoint_converter_fsdp_hf.py and try HF inference with it, I get empty output.
A PEFT checkpoint trained on the same data, starting from the same checkpoint, gives reasonable output.

@wukaixingxp
Contributor

@aishwaryap @mathmax12 Thanks for reporting this bug. We just added a PR that adds Llama 3 support for alpaca dataset fine-tuning. Please give it a try and let me know if it helps.

@mathmax12
Author

Thanks @wukaixingxp for fixing this.
For some reason I got this error:

Preprocessing dataset:   0%|          | 0/49402 [00:00<?, ?it/s]
0: [rank1]: Traceback (most recent call last):
0: [rank1]:   File "/root/./recipes/quickstart/finetuning/finetuning.py", line 8, in <module>
0: [rank1]:     fire.Fire(main)
0: [rank1]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fire/core.py", line 135, in Fire
0: [rank1]:     component_trace = _Fire(component, args, parsed_flag_args, context, name)
0: [rank1]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fire/core.py", line 468, in _Fire
0: [rank1]:     component, remaining_args = _CallAndUpdateTrace(
0: [rank1]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/fire/core.py", line 684, in _CallAndUpdateTrace
0: [rank1]:     component = fn(*varargs, **kwargs)
0: [rank1]:   File "/root/src/llama_recipes/finetuning.py", line 323, in main
0: [rank1]:     dataset_train = ConcatDataset(
0: [rank1]:   File "/root/src/llama_recipes/data/concatenator.py", line 23, in __init__
0: [rank1]:     for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True):
0: [rank1]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/tqdm/std.py", line 1181, in __iter__
0: [rank1]:     for obj in iterable:
0: [rank1]:   File "/root/src/llama_recipes/datasets/alpaca_dataset.py", line 58, in __getitem__
0: [rank1]:     dialog_tokens = self.tokenizer.apply_chat_template(dialog)
0: [rank1]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/transformers/tokenization_utils_base.py", line 1801, in apply_chat_template
0: [rank1]:     chat_template = self.get_chat_template(chat_template, tools)
0: [rank1]:   File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/transformers/tokenization_utils_base.py", line 1962, in get_chat_template
0: [rank1]:     raise ValueError(
0: [rank1]: ValueError: Cannot use chat template functions because tokenizer.chat_template is not set and no template argument was passed! For information about writing templates and setting the tokenizer.chat_template attribute, please see the documentation at https://huggingface.co/docs/transformers/main/en/chat_templating
0: --> applying fsdp activation checkpointing...

Here is what I did:
git clone https://github.com/meta-llama/llama-recipes.git
git checkout fix_alpaca
installation .....
torchrun --nnodes 1 --nproc_per_node 8 ./recipes/quickstart/finetuning/finetuning.py --model_name ./Meta-Llama-3.1-8B --output_dir ./fsdp_fine_tune_results/output_model_1_8_3.1-8B --dist_checkpoint_root_folder ./fsdp_fine_tune_results/fsdp_model_finetuned_1_8_3.1-8B --enable_fsdp --num_epochs 1 --batch_size_training 16 --dataset alpaca_dataset --save_model True --context_length 4096
Please let me know if I missed anything.
Thanks

@wukaixingxp
Contributor

@mathmax12 Please use meta-llama/Llama-3.1-8B-Instruct instead of the base model ./Meta-Llama-3.1-8B. Let me know if you have more questions.
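(A quick illustrative check of why the base model fails at apply_chat_template while the Instruct model works: only tokenizers that ship a chat template can be used with apply_chat_template, which is exactly what the ValueError above complains about.)

from transformers import AutoTokenizer

for name in ["meta-llama/Llama-3.1-8B", "meta-llama/Llama-3.1-8B-Instruct"]:
    tok = AutoTokenizer.from_pretrained(name)
    # Base repos may not define a chat template; Instruct repos generally do.
    print(name, "chat_template set:", tok.chat_template is not None)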

@mathmax12
Author

@wukaixingxp
I tried both meta-llama/Meta-Llama-3-8B-Instruct and meta-llama/Llama-3.1-8B-Instruct; both have the same issue as before. Here are my steps:

python ./src/llama_recipes/inference/checkpoint_converter_fsdp_hf.py --fsdp_checkpoint_path ./fsdp_fine_tune_results/fsdp_model_finetuned_1_8_8B/fine-tuned-Meta-Llama-3-8B-Instruct --consolidated_model_path ./fsdp_fine_tune_results/fsdp_model_finetuned_1_8_hf 

 python ./recipes/quickstart/inference/local_inference/inference.py --model_name ./fsdp_fine_tune_results/fsdp_model_finetuned_1_8_hf --prompt_file prompt_for_test.txt

root@0aed2c5417f3:~/workspace/blog/llama-recipes# git branch
* fix_alpaca
  main

I am also curious what the difference is between Meta-Llama-3-8B-Instruct and Meta-Llama-3-8B.
Thanks

@wukaixingxp
Contributor

wukaixingxp commented Oct 24, 2024

@mathmax12 Meta-Llama-3-8B-Instruct is the one you should use for chat; see the comment here. Your model actually outputs something, but somehow it is not related to the topic. Running the following steps may help:

torchrun --nnodes 1 --nproc_per_node 8    ./recipes/quickstart/finetuning/finetuning.py  --model_name meta-llama/Meta-Llama-3.1-8B-Instruct --output_dir ./fsdp_fine_tune_results/output_model_1_8 --dist_checkpoint_root_folder ./fsdp_fine_tune_results/fsdp_model_finetuned_1_8  --enable_fsdp  --num_epochs 1 --batch_size_training 2 --dataset alpaca_dataset

python ./src/llama_recipes/inference/checkpoint_converter_fsdp_hf.py --fsdp_checkpoint_path ./fsdp_fine_tune_results/fsdp_model_finetuned_1_8/fine-tuned-meta-llama/Meta-Llama-3.1-8B-Instruct/ --consolidated_model_path ./fsdp_fine_tune_results/fsdp_model_finetuned_1_8_hf

python ./recipes/quickstart/inference/local_inference/inference.py --model_name ./fsdp_fine_tune_results/fsdp_model_finetuned_1_8_hf --prompt_file prompt_for_test.txt
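(Optionally, a quick illustrative check that the consolidated checkpoint generates more than the echoed prompt; this assumes the converter saved a tokenizer into the consolidated folder, otherwise point AutoTokenizer at the original model repo.)

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

path = "./fsdp_fine_tune_results/fsdp_model_finetuned_1_8_hf"
tok = AutoTokenizer.from_pretrained(path)
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16).to("cuda")

prompt = "I have tomatoes, basil and cheese at home. What can I cook for dinner?"
inputs = tok(prompt, return_tensors="pt").to("cuda")
out = model.generate(**inputs, max_new_tokens=64, pad_token_id=tok.eos_token_id)
# Print only the newly generated tokens, not the echoed prompt.
print(tok.decode(out[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True))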

@mathmax12
Author

@wukaixingxp Thanks for the update. I can't find the meta-llama/Meta-Llama-3.1-8B-Instruct model card on Hugging Face, but there is https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct. I wonder, are they the same model with different names?

@wukaixingxp
Contributor

Yes, meta-llama/Meta-Llama-3.1-8B-Instruct has been renamed to meta-llama/Llama-3.1-8B-Instruct.
