Recipe for fine-tuning with lower resolution images #714

Open
avaradarajanfigma opened this issue Oct 8, 2024 · 8 comments
@avaradarajanfigma

🚀 The feature, motivation and pitch

Is fine-tuning the Vision models at a lower resolution supported? If so, can you please add a recipe for that (or add a note in recipes/quickstart/finetuning/finetune_vision_model.md on how to do it)? I tried setting the size param in the processor as

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained(
    train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name,
    size={"height": 336, "width": 336},
)

But this results in a tensor size mismatch error

RuntimeError: The size of tensor a (577) must match the size of tensor b (1601) at non-singleton dimension 2
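
Presumably the mismatch comes from the vision patch count. A minimal sketch of the arithmetic, assuming a 14-pixel patch size and one extra class token per tile (these values are inferred, not confirmed from the model config):

# Hedged sketch: where 577 and 1601 likely come from.
num_tokens = lambda side, patch=14: (side // patch) ** 2 + 1
print(num_tokens(336))  # 577  -> tokens produced for a 336x336 input
print(num_tokens(560))  # 1601 -> tokens the pretrained positional embeddings expect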

Alternatives

No response

Additional context

No response

@wukaixingxp
Contributor

@avaradarajanfigma Thanks for your question. I believe the processor will automatically do the resizing, as shown here. You do not need to modify any code to get our vision model working with lower resolution images. Of course, you can run fine-tuning on your own lower resolution images using finetune_vision_model.md if you want. Let me know if you have any questions.
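
For illustration, a minimal sketch of passing a low-resolution image straight through the processor; the exact output shape depends on the processor's default config (the 560x560 tile target is an assumption):

from PIL import Image
from transformers import AutoProcessor

# The processor resizes/pads whatever resolution it receives, so a small
# image can be fed as-is without changing the size parameter.
processor = AutoProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision-Instruct")
low_res_image = Image.new("RGB", (128, 96))  # stand-in for a small training image
inputs = processor(images=low_res_image, text="<|image|>Describe this image.", return_tensors="pt")
print(inputs["pixel_values"].shape)  # tiles resized to the model's expected resolution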

@wukaixingxp wukaixingxp self-assigned this Oct 8, 2024
@avaradarajanfigma
Author

My question is: is there a way to initialize the model/processor such that it doesn't scale up the lower resolution images first (resulting in fewer visual tokens), so that it is possible to fine-tune with less memory?

@wukaixingxp
Copy link
Contributor

I do not think it is possible to feed lower resolution images into our vision model. (1) The processor will always resize, as the # do_resize=False is not supported comment stated here indicates. (2) Even if you somehow bypass the processor, our model still cannot take lower resolution image input directly. As for saving memory during fine-tuning, I recommend using FSDP + PEFT. Let me know if you have more questions.
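
For reference, a minimal sketch of the PEFT/LoRA side of that recommendation; rank, alpha, and target modules here are illustrative assumptions rather than the recipe's exact defaults:

from transformers import MllamaForConditionalGeneration
from peft import LoraConfig, get_peft_model

# Only the small LoRA adapter weights are trained; the frozen base weights
# can then be sharded across GPUs by FSDP, which is what keeps memory low.
model = MllamaForConditionalGeneration.from_pretrained(
    "meta-llama/Llama-3.2-11B-Vision-Instruct", torch_dtype="bfloat16"
)
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],  # illustrative choice of modules
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()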

@avaradarajanfigma
Author

Thanks. Which GPUs was the recipes/quickstart/finetuning/finetune_vision_model.md recipe tested on? I tried on 4 A10 GPUs and it OOMs.

@wukaixingxp
Copy link
Contributor

What is your command to run the fine-tune?

@avaradarajanfigma
Author

I am using the same command as in recipes/quickstart/finetuning/finetune_vision_model.md, except that I changed the batch size to 1:

torchrun --nnodes 1 --nproc_per_node 4  recipes/quickstart/finetuning/finetuning.py --enable_fsdp --lr 1e-5  --num_epochs 3 --batch_size_training 1 --model_name meta-llama/Llama-3.2-11B-Vision-Instruct --dist_checkpoint_root_folder ./finetuned_model --dist_checkpoint_folder fine-tuned  --use_fast_kernels --dataset "custom_dataset" --custom_dataset.test_split "test" --custom_dataset.file "recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py"  --run_validation True --batching_strategy padding  --use_peft --peft_method lora

@wukaixingxp
Contributor

Can I see the error message? I think LoRA should work, based on the estimates in this table. I was using an H100 to test the fine-tuning, but I will investigate this problem further.

@avaradarajanfigma
Author

Below are the OOM errors. Based on the model sizes in that table, it does not appear to cover the vision models.

torchrun --nnodes 1 --nproc_per_node 4  recipes/quickstart/finetuning/finetuning.py --enable_fsdp --lr 1e-5  --num_epochs 3 --batch_size_training 1 --model_name meta-llama/Llama-3.2-11B-Vision-Instruct --dist_checkpoint_root_folder ./finetuned_model --dist_checkpoint_folder fine-tuned  --use_fast_kernels --dataset "custom_dataset" --custom_dataset.test_split "test" --custom_dataset.file "recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py"  --run_validation True --batching_strategy padding  --use_peft --peft_method lora
W1008 13:22:01.883000 140676129310528 torch/distributed/run.py:757] 
W1008 13:22:01.883000 140676129310528 torch/distributed/run.py:757] *****************************************
W1008 13:22:01.883000 140676129310528 torch/distributed/run.py:757] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W1008 13:22:01.883000 140676129310528 torch/distributed/run.py:757] *****************************************
Clearing GPU cache for all ranks
--> Running with torch dist debug set to detail
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  6.66it/s]
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  6.70it/s]
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  6.47it/s]
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  6.45it/s]
--> Model meta-llama/Llama-3.2-11B-Vision-Instruct

--> meta-llama/Llama-3.2-11B-Vision-Instruct has 10670.220835 Million params

trainable params: 5,898,240 || all params: 10,676,119,075 || trainable%: 0.0552
bFloat16 enabled for mixed precision - using bfSixteen policy
trainable params: 5,898,240 || all params: 10,676,119,075 || trainable%: 0.0552
trainable params: 5,898,240 || all params: 10,676,119,075 || trainable%: 0.0552
trainable params: 5,898,240 || all params: 10,676,119,075 || trainable%: 0.0552
--> applying fsdp activation checkpointing...
--> applying fsdp activation checkpointing...
--> applying fsdp activation checkpointing...
--> applying fsdp activation checkpointing...
--> Training Set Length = 1800
--> Validation Set Length = 200
length of dataset_train 1800
custom_data_collator is used
--> Num of Training Set Batches loaded = 450
length of dataset_train 1800
custom_data_collator is used
--> Num of Training Set Batches loaded = 450
length of dataset_train 1800
custom_data_collator is used
--> Num of Training Set Batches loaded = 450
--> Num of Validation Set Batches loaded = 50
--> Num of Validation Set Batches loaded = 50
Starting epoch 0/3
train_config.max_train_step: 0
--> Num of Validation Set Batches loaded = 50
--> Num of Validation Set Batches loaded = 50
Starting epoch 0/3
train_config.max_train_step: 0
length of dataset_train 1800
custom_data_collator is used
--> Num of Training Set Batches loaded = 450
--> Num of Validation Set Batches loaded = 50
--> Num of Validation Set Batches loaded = 50
Starting epoch 0/3
train_config.max_train_step: 0
/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/torch/cuda/memory.py:330: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/torch/cuda/memory.py:330: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1:   0%|                                                                                                                                                                  | 0/450 [00:00<?, ?it/s]/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/torch/cuda/memory.py:330: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1:   0%|                                                                                                                                                                  | 0/450 [00:00<?, ?it/s]--> Num of Validation Set Batches loaded = 50
--> Num of Validation Set Batches loaded = 50
Starting epoch 0/3
train_config.max_train_step: 0
/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/torch/cuda/memory.py:330: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
Training Epoch: 1:   0%|                                                                                                                                                                  | 0/450 [00:00<?, ?it/s]/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/torch/utils/checkpoint.py:91: UserWarning: None of the inputs have requires_grad=True. Gradients will be None
  warnings.warn(
/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/torch/utils/checkpoint.py:91: UserWarning: None of the inputs have requires_grad=True. Gradients will be None
  warnings.warn(
/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/torch/utils/checkpoint.py:91: UserWarning: None of the inputs have requires_grad=True. Gradients will be None
  warnings.warn(
/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/torch/utils/checkpoint.py:91: UserWarning: None of the inputs have requires_grad=True. Gradients will be None
  warnings.warn(
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/ubuntu/llama-recipes-github-repo/llama-recipes/recipes/quickstart/finetuning/finetuning.py", line 8, in <module>
[rank0]:     fire.Fire(main)
[rank0]:   File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/fire/core.py", line 135, in Fire
[rank0]:     component_trace = _Fire(component, args, parsed_flag_args, context, name)
[rank0]:   File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/fire/core.py", line 468, in _Fire
[rank0]:     component, remaining_args = _CallAndUpdateTrace(
[rank0]:   File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/fire/core.py", line 684, in _CallAndUpdateTrace
[rank0]:     component = fn(*varargs, **kwargs)
[rank0]:   File "/home/ubuntu/llama-recipes-github-repo/llama-recipes/src/llama_recipes/finetuning.py", line 313, in main
[rank0]:     results = train(
[rank0]:   File "/home/ubuntu/llama-recipes-github-repo/llama-recipes/src/llama_recipes/utils/train_utils.py", line 175, in train
[rank0]:     loss.backward()
[rank0]:   File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/torch/_tensor.py", line 525, in backward
[rank0]:     torch.autograd.backward(
[rank0]:   File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/torch/autograd/__init__.py", line 267, in backward
[rank0]:     _engine_run_backward(
[rank0]:   File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
[rank0]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank0]:   File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/torch/autograd/function.py", line 301, in apply
[rank0]:     return user_fn(self, *args)
[rank0]:   File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 320, in backward
[rank0]:     torch.autograd.backward(outputs_with_grad, args_with_grad)
[rank0]:   File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/torch/autograd/__init__.py", line 267, in backward
[rank0]:     _engine_run_backward(
[rank0]:   File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
[rank0]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank0]: torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 19.93 GiB. GPU 
[rank2]: Traceback (most recent call last):
[rank2]:   File "/home/ubuntu/llama-recipes-github-repo/llama-recipes/recipes/quickstart/finetuning/finetuning.py", line 8, in <module>
[rank2]:     fire.Fire(main)
[rank2]:   File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/fire/core.py", line 135, in Fire
[rank2]:     component_trace = _Fire(component, args, parsed_flag_args, context, name)
[rank2]:   File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/fire/core.py", line 468, in _Fire
[rank2]:     component, remaining_args = _CallAndUpdateTrace(
[rank2]:   File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/fire/core.py", line 684, in _CallAndUpdateTrace
[rank2]:     component = fn(*varargs, **kwargs)
[rank2]:   File "/home/ubuntu/llama-recipes-github-repo/llama-recipes/src/llama_recipes/finetuning.py", line 313, in main
[rank2]:     results = train(
[rank2]:   File "/home/ubuntu/llama-recipes-github-repo/llama-recipes/src/llama_recipes/utils/train_utils.py", line 175, in train
[rank2]:     loss.backward()
[rank2]:   File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/torch/_tensor.py", line 525, in backward
[rank2]:     torch.autograd.backward(
[rank2]:   File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/torch/autograd/__init__.py", line 267, in backward
[rank2]:     _engine_run_backward(
[rank2]:   File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
[rank2]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank2]:   File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/torch/autograd/function.py", line 301, in apply
[rank2]:     return user_fn(self, *args)
[rank2]:   File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 320, in backward
[rank2]:     torch.autograd.backward(outputs_with_grad, args_with_grad)
[rank2]:   File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/torch/autograd/__init__.py", line 267, in backward
[rank2]:     _engine_run_backward(
[rank2]:   File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
[rank2]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank2]: torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 19.93 GiB. GPU  has a total capacity of 21.99 GiB of which 5.56 GiB is free. Including non-PyTorch memory, this process has 16.42 GiB memory in use. Of the allocated memory 10.76 GiB is allocated by PyTorch, and 5.21 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
[rank1]: Traceback (most recent call last):
[rank1]:   File "/home/ubuntu/llama-recipes-github-repo/llama-recipes/recipes/quickstart/finetuning/finetuning.py", line 8, in <module>
[rank1]:     fire.Fire(main)
[rank1]:   File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/fire/core.py", line 135, in Fire
[rank1]:     component_trace = _Fire(component, args, parsed_flag_args, context, name)
[rank1]:   File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/fire/core.py", line 468, in _Fire
[rank1]:     component, remaining_args = _CallAndUpdateTrace(
[rank1]:   File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/fire/core.py", line 684, in _CallAndUpdateTrace
[rank1]:     component = fn(*varargs, **kwargs)
[rank1]:   File "/home/ubuntu/llama-recipes-github-repo/llama-recipes/src/llama_recipes/finetuning.py", line 313, in main
[rank1]:     results = train(
[rank1]:   File "/home/ubuntu/llama-recipes-github-repo/llama-recipes/src/llama_recipes/utils/train_utils.py", line 175, in train
[rank1]:     loss.backward()
[rank1]:   File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/torch/_tensor.py", line 525, in backward
[rank1]:     torch.autograd.backward(
[rank1]:   File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/torch/autograd/__init__.py", line 267, in backward
[rank1]:     _engine_run_backward(
[rank1]:   File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
[rank1]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank1]:   File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/torch/autograd/function.py", line 301, in apply
[rank1]:     return user_fn(self, *args)
[rank1]:   File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 320, in backward
[rank1]:     torch.autograd.backward(outputs_with_grad, args_with_grad)
[rank1]:   File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/torch/autograd/__init__.py", line 267, in backward
[rank1]:     _engine_run_backward(
[rank1]:   File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
[rank1]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank1]: torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 19.93 GiB. GPU  has a total capacity of 21.99 GiB of which 5.56 GiB is free. Including non-PyTorch memory, this process has 16.42 GiB memory in use. Of the allocated memory 10.76 GiB is allocated by PyTorch, and 5.21 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
[rank3]: Traceback (most recent call last):
[rank3]:   File "/home/ubuntu/llama-recipes-github-repo/llama-recipes/recipes/quickstart/finetuning/finetuning.py", line 8, in <module>
[rank3]:     fire.Fire(main)
[rank3]:   File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/fire/core.py", line 135, in Fire
[rank3]:     component_trace = _Fire(component, args, parsed_flag_args, context, name)
[rank3]:   File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/fire/core.py", line 468, in _Fire
[rank3]:     component, remaining_args = _CallAndUpdateTrace(
[rank3]:   File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/fire/core.py", line 684, in _CallAndUpdateTrace
[rank3]:     component = fn(*varargs, **kwargs)
[rank3]:   File "/home/ubuntu/llama-recipes-github-repo/llama-recipes/src/llama_recipes/finetuning.py", line 313, in main
[rank3]:     results = train(
[rank3]:   File "/home/ubuntu/llama-recipes-github-repo/llama-recipes/src/llama_recipes/utils/train_utils.py", line 175, in train
[rank3]:     loss.backward()
[rank3]:   File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/torch/_tensor.py", line 525, in backward
[rank3]:     torch.autograd.backward(
[rank3]:   File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/torch/autograd/__init__.py", line 267, in backward
[rank3]:     _engine_run_backward(
[rank3]:   File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
[rank3]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank3]:   File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/torch/autograd/function.py", line 301, in apply
[rank3]:     return user_fn(self, *args)
[rank3]:   File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 320, in backward
[rank3]:     torch.autograd.backward(outputs_with_grad, args_with_grad)
[rank3]:   File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/torch/autograd/__init__.py", line 267, in backward
[rank3]:     _engine_run_backward(
[rank3]:   File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
[rank3]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank3]: torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 19.93 GiB. GPU  has a total capacity of 21.99 GiB of which 5.56 GiB is free. Including non-PyTorch memory, this process has 16.42 GiB memory in use. Of the allocated memory 10.76 GiB is allocated by PyTorch, and 5.21 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
Training Epoch: 1:   0%|                                                                                                                                                                  | 0/450 [00:15<?, ?it/s]
Training Epoch: 1:   0%|                                                                                                                                                                  | 0/450 [00:14<?, ?it/s]
Training Epoch: 1:   0%|                                                                                                                                                                  | 0/450 [00:14<?, ?it/s]
Training Epoch: 1:   0%|                                                                                                                                                                  | 0/450 [00:15<?, ?it/s]
E1008 13:22:46.931000 140676129310528 torch/distributed/elastic/multiprocessing/api.py:826] failed (exitcode: 1) local_rank: 0 (pid: 11317) of binary: /home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/bin/python
Traceback (most recent call last):
  File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 347, in wrapper
    return f(*args, **kwargs)
  File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/torch/distributed/run.py", line 879, in main
    run(args)
  File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/torch/distributed/run.py", line 870, in run
    elastic_launch(
  File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 132, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/ubuntu/.cache/pypoetry/virtualenvs/test-llama-recipes-K8dcJ3Te-py3.10/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 263, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
recipes/quickstart/finetuning/finetuning.py FAILED
------------------------------------------------------------
Failures:
[1]:
  time      : 2024-10-08_13:22:46
  host      : ip-10-142-7-161.us-west-2.compute.internal
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 11318)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[2]:
  time      : 2024-10-08_13:22:46
  host      : ip-10-142-7-161.us-west-2.compute.internal
  rank      : 2 (local_rank: 2)
  exitcode  : 1 (pid: 11319)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
[3]:
  time      : 2024-10-08_13:22:46
  host      : ip-10-142-7-161.us-west-2.compute.internal
  rank      : 3 (local_rank: 3)
  exitcode  : 1 (pid: 11320)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-10-08_13:22:46
  host      : ip-10-142-7-161.us-west-2.compute.internal
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 11317)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
