act checkpointing OOM, float8 causes CUDA memory allocation retries #56
Comments
Hi @akhauriyash! Does activation checkpointing increase memory usage in your case? Does FP8 trigger memory reallocation?
Apologies for the delayed response. I had some runs I couldn't stop, so I couldn't test it.
[rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
I don't think I have the hardware support to run fp8, hopefully somebody else helps out! 😅
Any update on this?
We are aware of this, thanks to @lw, and it’s related to PyTorch. The error has been reported to the PyTorch team, and we’ll try to provide further details as soon as they are available.
Thank you @mathuvu. Is this related to any specific PyTorch version? Did it work with previous versions? Did it work on different hardware?
To my knowledge, there is currently no version of PyTorch that supports fp8 and selective checkpointing at the same time.
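For context, a minimal sketch of what selective activation checkpointing looks like with PyTorch's SAC API (`create_selective_checkpoint_contexts`, available in recent releases). The ops-to-save policy here is purely illustrative, not lingua's actual one; per the comment above, it is the combination of such a policy with fp8 matmuls that is currently problematic.

```python
import functools
import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# Illustrative policy (not lingua's): keep matmul outputs, recompute the rest.
ops_to_save = {
    torch.ops.aten.mm.default,
    torch.ops.aten.addmm.default,
}

def policy_fn(ctx, op, *args, **kwargs):
    if op in ops_to_save:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

def block(x, w1, w2):
    # Toy transformer-ish block: two matmuls with a nonlinearity in between.
    return torch.relu(x @ w1) @ w2

x = torch.randn(8, 1024, requires_grad=True)
w1 = torch.randn(1024, 4096, requires_grad=True)
w2 = torch.randn(4096, 1024, requires_grad=True)

out = checkpoint(
    block, x, w1, w2,
    use_reentrant=False,
    context_fn=functools.partial(create_selective_checkpoint_contexts, policy_fn),
)
out.sum().backward()
```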
@mathuvu But the errors are independent of each other: float8 (without act ckpt) causes an increase in CUDA memory usage and memory allocation retries. Were you able to support either of the two?
Ah ok, sorry. I haven't tested selective checkpointing much; I mostly use compile. As for fp8, it must be used with compile, otherwise it is slower.
@mathuvu Thank you for answering. I am using fp8 along with compile (see my issue; the only variable I changed in the yaml configuration is float8_recipe: rowwise). Do you encounter the same drawbacks?
Yes, there is some overhead in using fp8 for now. I'm not sure, but if I remember correctly, PyTorch keeps a copy of both the fp8 and bf16 weights of the model, and that is why you get the CUDA reallocation error. Using more GPUs along with FSDP reduces this issue, but it is clearly suboptimal. So the smaller wps in fp8, in your case, can be explained by this overhead.
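As a rough back-of-the-envelope illustration of that overhead (assuming a dense 7B-parameter model and counting weights only, not gradients, optimizer state, or activations):

```python
# Weight memory if both a bf16 master copy and an fp8 copy are resident,
# before any FSDP sharding. Numbers are approximate.
n_params = 7e9
bf16_gib = n_params * 2 / 2**30   # ~13 GiB
fp8_gib = n_params * 1 / 2**30    # ~6.5 GiB

print(f"bf16 weights:                  {bf16_gib:.1f} GiB")
print(f"extra fp8 copy:                {fp8_gib:.1f} GiB")
print(f"total per (unsharded) replica: {bf16_gib + fp8_gib:.1f} GiB")
print(f"per GPU with FSDP over 8 GPUs: {(bf16_gib + fp8_gib) / 8:.1f} GiB")
```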
Thank you, that's very helpful. One more question. Is the fp8 strategy in Lingua the same as the one in torchtitan? Using the same hardware and model size in titan, I didn't encounter any memory issue with fp8.
I'm not 100% sure, but I think torchtitan uses tensorwise scaling, while we are using rowwise scaling for fp8 with torch._scaled_mm.
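For readers unfamiliar with the distinction, here is a minimal sketch of tensorwise versus rowwise fp8 scaling around torch._scaled_mm. It is illustrative only (not lingua's or torchtitan's actual code), needs a GPU with compute capability >= 8.9, and torch._scaled_mm is a private API whose signature and rowwise support vary across recent PyTorch releases.

```python
import torch

FP8 = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8).max

M, K, N = 128, 256, 64
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)   # activations
w = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)   # weight, row-major

# Tensorwise: a single (dequantization) scale per tensor.
x_scale_tw = (x.abs().amax().float() / FP8_MAX).clamp(min=1e-12)
w_scale_tw = (w.abs().amax().float() / FP8_MAX).clamp(min=1e-12)
out_tw = torch._scaled_mm(
    (x / x_scale_tw).to(FP8),
    (w / w_scale_tw).to(FP8).t(),          # second operand must be column-major
    scale_a=x_scale_tw, scale_b=w_scale_tw,
    out_dtype=torch.bfloat16,
)

# Rowwise: one scale per row of x (shape (M, 1)) and one per output column
# (shape (1, N)); finer-grained, but more scale tensors to carry around.
x_scale_rw = (x.abs().amax(dim=1, keepdim=True).float() / FP8_MAX).clamp(min=1e-12)
w_scale_rw = (w.abs().amax(dim=1, keepdim=True).float() / FP8_MAX).clamp(min=1e-12)
out_rw = torch._scaled_mm(
    (x / x_scale_rw).to(FP8),
    (w / w_scale_rw).to(FP8).t(),
    scale_a=x_scale_rw, scale_b=w_scale_rw.t(),
    out_dtype=torch.bfloat16,
)
```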
Thank you @mathuvu, I'll look into that. Regarding the issue with selective_activation_checkpointing, I see it with much smaller models as well, so the issue is not related to using a large model on limited-memory hardware.
Concerning fp8 issues, you can find more context in pytorch/pytorch#141881
I just submitted some fixes for fp8 in #63. These should help both with memory usage (thanks to a workaround to the PyTorch issue mentioned above) and with precision (because we disable fast_accum in the bwd). Let me know if you can try those and if they help!
Hi, I encountered the same issue when setting selective_activation_checkpointing to True (with BF16), resulting in an OOM error.
Have you tried with the fix I submitted as part of #63?
Hi! Thanks for your prompt reply. I merged this part, but it still failed. My case isn't related to FP8; enabling selective_activation_checkpointing alone still triggers the OOM.
I confirm that:
I am trying to train LLama-7B on 8xH100-80GB (HBM3).

Baseline
When running without activation checkpointing and without fp8, everything runs smoothly, filling up 80% of memory and achieving ~9000 WPS.

Activation checkpointing OOM
When setting selective_activation_checkpointing=true, I hit CUDA OOM.

Float8
When setting float8_recipe: rowwise (and turning off act ckpting), I get a CUDA memory allocation retries warning, higher memory usage (87%), and lower throughput of 2000 WPS.

Has someone encountered these errors? I am using apps/main/configs/llama_7B.yaml without any major modification. The same warning and error also occur in much smaller models on a single GPU, without DDP nor FSDP.
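As a side note on that warning: the retry counter it refers to can be read directly from PyTorch's caching allocator statistics. A minimal sketch (stat key names assume a recent PyTorch release):

```python
import torch

# "CUDA memory allocation retries" means cudaMalloc failed at least once and
# the caching allocator had to flush its cached blocks and try again.
stats = torch.cuda.memory_stats()
print("alloc retries:     ", stats.get("num_alloc_retries", 0))
print("peak allocated GiB:", stats.get("allocated_bytes.all.peak", 0) / 2**30)
print("peak reserved GiB: ", stats.get("reserved_bytes.all.peak", 0) / 2**30)
```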
I am attaching the traces for rank 0:
baseline.log
act_ckpt.log
fp8.log
System
python: 3.11.10
torch: 2.5.0+cu121