
Gradient checkpointing for grad_weight in LFCE #533

Open
cassanof opened this issue Jan 20, 2025 · 4 comments
Comments


cassanof commented Jan 20, 2025

🚀 The feature, motivation and pitch

The LFCE kernel allocates a grad_weight tensor:

grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None

This tensor is then updated throughout the chunked loss calculation and is finally returned in the backward as the custom weight gradient:
torch.addmm(
    input=grad_weight,
    mat1=logits_chunk.t().to(
        _input_chunk.dtype
    ),  # In an autocast scenario without bias, differing logits_chunk data types will cause an addmm operation error.
    mat2=_input_chunk,
    out=grad_weight,
    alpha=1.0,
    beta=1.0,
)
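
In plain PyTorch terms, the accumulation above is equivalent to (a paraphrase, not the Liger source):

# Each chunk adds its contribution to the full-size weight gradient in place.
grad_weight += logits_chunk.t().to(_input_chunk.dtype) @ _input_chunk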

This tensor has shape [vocab_size, hidden_size]. For big models with large vocabularies it becomes very large, and because it has to stay in memory until the backward pass, it makes large pipeline-parallel microbatches at long sequence lengths impossible. It would be great to have gradient checkpointing here; even full recomputation would work.
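
For a sense of scale, here is a rough back-of-the-envelope estimate (the 128256 × 8192 Llama-3-style head dimensions are illustrative assumptions, not numbers from this issue):

# Size of the grad_weight buffer that must survive from the chunked loss
# computation until the backward pass, for an illustrative large-vocab head.
vocab_size, hidden_size = 128_256, 8_192
print(vocab_size * hidden_size * 4 / 2**30)  # ~3.9 GiB in fp32 (~2 GiB in bf16)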

Alternatives

No response

Additional context

No response


cassanof commented Jan 20, 2025

I wrote a fully recomputed version of LFCE which resolves this issue. Would other people want it? I guess not many are training big models on long context.

Though doing checkpointing would be better for this case and most cases.
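
To make the idea concrete, here is a minimal, self-contained sketch of the fully recomputed variant (this is not the Liger implementation and not my exact code; RecomputedLinearCE, chunk_size, and the plain-PyTorch math are illustrative assumptions). Nothing of shape [vocab_size, hidden_size] is kept between forward and backward; the backward re-derives the per-chunk logit gradients and accumulates grad_weight there.

import torch
import torch.nn.functional as F

class RecomputedLinearCE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, _input, weight, target, chunk_size=1024):
        # Chunked loss only; no [vocab_size, hidden_size] grad buffer is allocated here.
        n_rows = _input.shape[0]
        total_loss = _input.new_zeros((), dtype=torch.float32)
        for start in range(0, n_rows, chunk_size):
            chunk = _input[start : start + chunk_size]
            logits = chunk @ weight.t()
            total_loss += F.cross_entropy(
                logits.float(), target[start : start + chunk_size], reduction="sum"
            )
        ctx.save_for_backward(_input, weight, target)
        ctx.chunk_size = chunk_size
        return total_loss / n_rows

    @staticmethod
    def backward(ctx, grad_output):
        _input, weight, target = ctx.saved_tensors
        n_rows = _input.shape[0]
        grad_input = torch.empty_like(_input)
        grad_weight = torch.zeros_like(weight) if weight.requires_grad else None
        for start in range(0, n_rows, ctx.chunk_size):
            chunk = _input[start : start + ctx.chunk_size]
            target_chunk = target[start : start + ctx.chunk_size]
            # Recompute the logits for this chunk instead of carrying their
            # gradient contribution around since the forward pass.
            logits = chunk @ weight.t()
            grad_logits = torch.softmax(logits.float(), dim=-1)
            grad_logits[torch.arange(chunk.shape[0], device=chunk.device), target_chunk] -= 1.0
            grad_logits = (grad_logits * (grad_output / n_rows)).to(chunk.dtype)
            grad_input[start : start + ctx.chunk_size] = grad_logits @ weight
            if grad_weight is not None:
                # Same accumulation as the addmm above, just done lazily in the backward.
                torch.addmm(input=grad_weight, mat1=grad_logits.t(), mat2=chunk, out=grad_weight)
        return grad_input, grad_weight, None, None

Usage would be something like loss = RecomputedLinearCE.apply(hidden_states.view(-1, hidden_size), lm_head_weight, labels.view(-1)). The trade-off is exactly the one described above: the chunked logits and their softmax are computed twice, in exchange for never materializing a [vocab_size, hidden_size] gradient buffer across the forward/backward boundary.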

ByronHsu (Collaborator) commented

Hi @cassanof, this sounds reasonable, and it seems it can further reduce memory. Have you measured the speed, since it may introduce some regression? We would greatly appreciate it if you are willing to contribute the code. Some users have raised similar requests.


cassanof commented Jan 21, 2025

Hi! Recomputation definitely regresses throughput; it's only useful if you are doing pipeline parallelism at a high degree or you are really desperate for memory.
In my testing it regressed throughput by 15% at the same batch size. However, thanks to the memory savings on the last rank, I can now 4x the batch size, which has given me a very large perf improvement! With pipeline parallelism, I can also rebalance layers from the last rank onto earlier ranks, so the throughput regression is fully amortized. Let me know if this addition is wanted; it would probably need to be a separate loss function altogether.

ByronHsu (Collaborator) commented

Cool! Feel free to share a draft with us. I can find folks to polish it and merge it into Liger. cc @austin362667 @Tcc0403
