
act checkpointing OOM, float8 causes CUDA memory allocation retries #56

Open
Niccolo-Ajroldi opened this issue Nov 15, 2024 · 20 comments

@Niccolo-Ajroldi

Niccolo-Ajroldi commented Nov 15, 2024

I am trying to train Llama-7B on 8xH100-80GB (HBM3).

Baseline

When running without activation checkpointing and without fp8, everything runs smoothly:

distributed:
  fsdp_type: full_shard
  compile: true
  model_dtype: bf16
  matmul_allow_tf32: false
  selective_activation_checkpointing: false
  tp_size: 1
  float8_recipe: null

filling up 80% of memory and achieving ~9000 WPS:

0: INFO    24-11-14 19:27:00.646661 - 0:01:24 - step: 10  acc: 0  loss: 11.0584  grad: 2.01e+01  flops: 4.50e+14  wps: 9.08e+03  iter:  0.9013  data: 0.0004  lr: 4.50e-06  mem: 80%  pow: 672.486 W

Activation checkpointing OOM

When setting selective_activation_checkpointing=true, I hit CUDA OOM.
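
For context, my understanding is that selective activation checkpointing saves only a chosen subset of activations and recomputes the rest in the backward pass. In recent PyTorch (2.4+) the mechanism can be sketched roughly as below; this is illustrative only, not necessarily what lingua does internally:

import torch
from torch.utils.checkpoint import (
    checkpoint, create_selective_checkpoint_contexts, CheckpointPolicy,
)

# Save only matmul outputs; recompute everything else during backward.
ops_to_save = [torch.ops.aten.mm.default]

def policy_fn(ctx, op, *args, **kwargs):
    return (CheckpointPolicy.MUST_SAVE if op in ops_to_save
            else CheckpointPolicy.PREFER_RECOMPUTE)

def block(x, w1, w2):
    return torch.relu(x @ w1) @ w2

x = torch.randn(8, 64, requires_grad=True)
w1 = torch.randn(64, 256, requires_grad=True)
w2 = torch.randn(256, 64, requires_grad=True)

out = checkpoint(
    block, x, w1, w2, use_reentrant=False,
    context_fn=lambda: create_selective_checkpoint_contexts(policy_fn),
)
out.sum().backward()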

Float8

When setting float8_recipe: rowwise (and turning off activation checkpointing), I get a CUDA memory allocation retries warning, higher memory usage (87%), and lower throughput of ~2000 WPS:

0: WARNING 24-11-14 19:35:09.477835 - 0:02:39 - 1 CUDA memory allocation retries.
0: INFO    24-11-14 19:35:40.418971 - 0:03:10 - step: 10  acc: 0  loss: 11.0449  grad: 1.94e+01  flops: 9.97e+13  wps: 2.01e+03  iter:  3.8529  data: 0.0008  lr: 4.50e-06  mem: 87%  pow: 468.933 W

Has anyone encountered these errors? I am using apps/main/configs/llama_7B.yaml without any major modification.
The same warning and error also occur with much smaller models on a single GPU, without DDP or FSDP.

I am attaching the traces for rank 0:
baseline.log
act_ckpt.log
fp8.log

System

python: 3.11.10
torch: 2.5.0+cu121

@akhauriyash

akhauriyash commented Nov 15, 2024

This might be a bit off-topic, but I also get this warning:


/home/najroldi/miniconda3/envs/lingua/lib/python3.11/site-packages/torch/autograd/graph.py:825: UserWarning: cuDNN SDPA backward got grad_output.strides() != output.strides(), attempting to materialize a grad_output with matching strides... (Triggered internally at ../aten/src/ATen/native/cudnn/MHA.cpp:674.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass

Should I be worried about execution speed? I would appreciate it if somebody knows.

Also,
[attached plot: gradient norm over training steps]
I trained a much smaller transformer (fineweb-edu dataset, 60M params) and I see that the grad norm falls initially, then starts increasing towards the end. This is reflected in wikitext2 perplexity, which goes down and then starts increasing... I'm wondering if you see the same in your 7B model (once you are able to run it). Also, are you using the default llama_7b config file with only 100k steps, batch size 2, and 4096 tokens per sequence? That's only 819,200,000 tokens (~800M).
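
As a quick sanity check of that token count (assuming batch size 2 and sequence length 4096 are per optimizer step; with data parallelism the tokens per step would additionally scale with the number of replicas):

steps, batch_size, seq_len = 100_000, 2, 4096
tokens = steps * batch_size * seq_len
print(f"{tokens:,} tokens")  # 819,200,000 (~0.8B)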

@Niccolo-Ajroldi
Author

Hi @akhauriyash! Does activation checkpointing increase memory usage in your case? Does FP8 trigger memory reallocation?

@akhauriyash

Apologies for the delayed response; I had some runs I couldn't stop, so I couldn't test it...

[rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89

I don't think I have the hardware support to run fp8, hopefully somebody else helps out! 😅

@Niccolo-Ajroldi
Author

Any update on this?

@mathuvu
Contributor

mathuvu commented Nov 25, 2024

We are aware of this, thanks to @lw, and it’s related to PyTorch. The error has been reported to the PyTorch team, and we’ll try to provide further details as soon as they are available.

@Niccolo-Ajroldi
Author

Niccolo-Ajroldi commented Nov 25, 2024

Thank you @mathuvu. Is this related to any specific PyTorch version? Did it work with previous versions? Did it work on different hardware?

@mathuvu
Contributor

mathuvu commented Nov 25, 2024

To my knowledge, no version of PyTorch currently supports fp8 and selective checkpointing at the same time.

@Niccolo-Ajroldi
Author

Niccolo-Ajroldi commented Nov 25, 2024

@mathuvu But the two errors are independent of each other:

float8 (without activation checkpointing) causes an increase in CUDA memory usage and memory allocation retries.
Activation checkpointing (without fp8) hits CUDA OOM.

Were you able to support either of the two?

@mathuvu
Contributor

mathuvu commented Nov 25, 2024

Ah ok, sorry. I didn't test selective checkpointing much; I mostly use compile. fp8 must be used with compile, otherwise it is slower.

@Niccolo-Ajroldi
Author

Niccolo-Ajroldi commented Nov 25, 2024

@mathuvu thank you for answering.

I am using fp8 along with compile (see my issue, the only variable I changed in the yaml configuration is float8_recipe: rowwise), yet I still get the CUDA memory reallocation, higher memory usage, and lower WPS.

Do you encounter the same drawbacks?

@mathuvu
Contributor

mathuvu commented Nov 25, 2024

Yes, there is some overhead in using fp8 for now. I'm not sure, but if I remember correctly, PyTorch keeps a copy of both the fp8 and bf16 weights of the model, and that's why you get the CUDA reallocation errors. Using more GPUs along with FSDP reduces this issue, but it is clearly suboptimal. So the lower WPS with fp8, in your case, can be explained by the CUDA memory allocation retries.
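
To put rough numbers on that (a back-of-the-envelope estimate for a 7B-parameter model, counting weights only, and assuming the extra fp8 copy is indeed kept alongside the bf16 weights):

params = 7e9
bf16_weights_gib = params * 2 / 2**30  # ~13.0 GiB for the bf16 weights
fp8_copy_gib = params * 1 / 2**30      # ~6.5 GiB for an extra fp8 copy
print(f"bf16: {bf16_weights_gib:.1f} GiB, fp8 copy: {fp8_copy_gib:.1f} GiB")
# With FSDP full-shard over 8 GPUs each rank holds ~1/8 of this, but the extra
# copy still adds allocation pressure on top of gradients, optimizer states
# and activations, which is consistent with the allocator retries reported above.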

@Niccolo-Ajroldi
Author

Niccolo-Ajroldi commented Nov 25, 2024

Thank you, that's very helpful. One more question: is the fp8 strategy in Lingua the same as the one in torchtitan? Using the same hardware and model size in torchtitan, I didn't encounter any memory issues with fp8.

@mathuvu
Contributor

mathuvu commented Nov 25, 2024

I'm not 100% sure, but I think torchtitan uses tensorwise scaling, while we are using rowwise scaling for fp8 with torch._scaled_mm.
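
Roughly, the difference between the two scaling granularities can be sketched like this (illustrative only, not our or torchtitan's actual code):

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

x = torch.randn(1024, 4096) * 5

# Tensorwise: a single scale for the whole tensor (coarser, cheaper to manage).
scale_tw = FP8_MAX / x.abs().max()
x_fp8_tw = (x * scale_tw).to(torch.float8_e4m3fn)

# Rowwise: one scale per row (finer-grained, usually better precision).
scale_rw = FP8_MAX / x.abs().amax(dim=1, keepdim=True)
x_fp8_rw = (x * scale_rw).to(torch.float8_e4m3fn)

# The fp8 matmul (e.g. torch._scaled_mm) then consumes the fp8 data together
# with the corresponding scales to produce a bf16/fp32 output.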

@Niccolo-Ajroldi
Author

Niccolo-Ajroldi commented Nov 25, 2024

Thank you @mathuvu, I'll look into that.

Regarding the issue with fp8, I noticed that I encounter the same issues (CUDA memory allocation retries warning, higher memory usage, and lower throughput) even when using a smaller model and even without DDP/FSDP.

So the issue is not related to running a large model on limited-memory hardware.

@mathuvu
Contributor

mathuvu commented Dec 11, 2024

Concerning fp8 issues, you can find more context in pytorch/pytorch#141881

@lw
Contributor

lw commented Dec 11, 2024

I just submitted some fixes for fp8 in #63. These should help both with memory usage (thanks to a workaround for the PyTorch issue mentioned above) and with precision (because we disable fast_accum in the backward pass). Let me know if you can try them and if they help!
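
For context, fast_accum here refers to the use_fast_accum argument of torch._scaled_mm. Below is a rough sketch of the idea (not the actual diff in #63), assuming an SM89+ GPU:

import torch

# fp8 matmuls need compute capability >= 8.9 (e.g. H100).
assert torch.cuda.get_device_capability() >= (8, 9)

M, K, N = 64, 64, 64
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
scale_a = torch.tensor(1.0, device="cuda")
scale_b = torch.tensor(1.0, device="cuda")

# Forward pass: fast accumulation trades a little precision for speed.
y_fwd = torch._scaled_mm(a, b.t(), scale_a=scale_a, scale_b=scale_b,
                         out_dtype=torch.bfloat16, use_fast_accum=True)

# Backward-pass matmuls: keep full-precision accumulation for better gradients.
y_bwd = torch._scaled_mm(a, b.t(), scale_a=scale_a, scale_b=scale_b,
                         out_dtype=torch.bfloat16, use_fast_accum=False)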

@HaozheLiu-ST

Hi, I encountered the same issue when setting selective_activation_checkpointing to True (with BF16), resulting in an OOM error.

@lw
Contributor

lw commented Dec 16, 2024

Have you tried with the fix I submitted as part of #63?

@HaozheLiu-ST

Hi! Thanks for your prompt reply. I merged this part, but it still failed. My case isn't related to FP8; enabling selective_activation_checkpointing=True alone triggers an OOM error.
My configuration is like this:

distributed:
    fsdp_type: full_shard
    dp_shard: 4
    dp_replicate: 1
    compile: True
    model_dtype: bf16
    matmul_allow_tf32: false
    selective_activation_checkpointing: True
    tp_size: 1
    compile_cache_size_limit: 64

@Niccolo-Ajroldi
Author

Niccolo-Ajroldi commented Dec 17, 2024

I confirm that:

  1. The fix in Update float8 recipe #63 solves the fp8 bug. Now fp8 runs smoothly with higher throughput and the same memory allocation. Thank you @lw! ✅
  2. Activation checkpointing still hits CUDA OOM on the 7B model (after merging Update float8 recipe #63), as also reported by @HaozheLiu-ST. ❌
