Gradient checkpointing for grad_weight in LFCE #533

Comments
I wrote a fully recomputed version of LFCE which resolves this issue. Would other people want it? I guess not many people are training big models on long contexts, though checkpointing would be better both for this case and for most others.
Hi @cassanof, this sounds reasonable, and it seems it can further reduce memory. Have you measured the speed? Recomputation may introduce some regression. We would greatly appreciate it if you are willing to contribute the code; some users have raised similar requests.
Hi! Recomputation definitely regresses throughput; it's only useful if you are doing pipeline parallelism at a high degree or you are really, really desperate for memory.
Cool! Feel free to share a draft with us. I can find folks to polish it and merge it into Liger. cc @austin362667 @Tcc0403
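For concreteness, here is a minimal sketch of the recomputation idea discussed in this thread. It is plain PyTorch rather than the actual Triton kernels, and the class name, chunking scheme, and fp32 softmax math are all assumptions standing in for the real code, not @cassanof's implementation:

```python
import torch

class RecomputedLinearCrossEntropy(torch.autograd.Function):
    """Sketch: the forward saves only (input, weight, target) and discards all
    [*, vocab_size] intermediates; the backward recomputes the logits chunk by
    chunk, so grad_weight is materialized only during the backward pass."""

    @staticmethod
    def forward(ctx, _input, weight, target, chunk_size=4096):
        n = _input.shape[0]
        loss_sum = _input.new_zeros((), dtype=torch.float32)
        for start in range(0, n, chunk_size):
            chunk = _input[start:start + chunk_size]
            logits = chunk @ weight.t()  # [chunk, vocab]; freed as the loop advances
            loss_sum += torch.nn.functional.cross_entropy(
                logits.float(), target[start:start + chunk_size], reduction="sum"
            )
        ctx.save_for_backward(_input, weight, target)
        ctx.chunk_size = chunk_size
        return loss_sum / n

    @staticmethod
    def backward(ctx, grad_output):
        _input, weight, target = ctx.saved_tensors
        n = _input.shape[0]
        grad_input = torch.empty_like(_input)
        grad_weight = torch.zeros_like(weight)  # allocated here, not in forward
        for start in range(0, n, ctx.chunk_size):
            chunk = _input[start:start + ctx.chunk_size]
            tgt = target[start:start + ctx.chunk_size]
            logits = chunk @ weight.t()  # recompute the forward activations
            dlogits = logits.float().softmax(dim=-1)
            dlogits[torch.arange(tgt.shape[0], device=tgt.device), tgt] -= 1.0
            dlogits = (dlogits * (grad_output / n)).to(chunk.dtype)
            grad_input[start:start + ctx.chunk_size] = dlogits @ weight
            grad_weight += dlogits.t() @ chunk
        return grad_input, grad_weight, None, None
```

Note that grad_weight is still [vocab_size, hidden_size], but it now lives only for the duration of the backward instead of being held from each microbatch's forward until its backward; that is exactly the pipeline-parallel win discussed above, paid for by recomputing the logits.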
🚀 The feature, motivation and pitch
The LFCE kernel allocates a grad_weight tensor:

Liger-Kernel/src/liger_kernel/ops/fused_linear_cross_entropy.py, line 47 in a8fa3bb

This tensor then gets updated throughout the chunked loss calculation and is finally used in the backward pass as a custom grad operation:

Liger-Kernel/src/liger_kernel/ops/fused_linear_cross_entropy.py, lines 127 to 136 in a8fa3bb
This tensor has shape [vocab_size, hidden_size]. For big models with big vocabularies it becomes very large, which makes it impossible to run large pipeline-parallel microbatches at long sequence lengths, since I have to keep this tensor in memory until the backward pass. It would be great to have gradient checkpointing here; even full recomputation would work. A back-of-the-envelope size estimate is sketched below.
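To make "very large" concrete, here is a rough estimate; the model dimensions below are illustrative assumptions (roughly Llama-3-8B-like), not figures from this issue:

```python
# Illustrative dimensions (assumptions, roughly Llama-3-8B-like; not from the issue).
vocab_size = 128_256
hidden_size = 4_096
bytes_per_elem = 4  # assuming fp32 accumulation of grad_weight

grad_weight_bytes = vocab_size * hidden_size * bytes_per_elem
print(f"grad_weight: {grad_weight_bytes / 2**30:.2f} GiB")  # -> ~1.96 GiB

# Under pipeline parallelism, one such tensor stays alive per in-flight
# microbatch between its forward and backward, so the cost scales with
# the number of in-flight microbatches.
```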
Alternatives

No response
Additional context
No response