Docs: https://pytorch.org/docs/2.4/distributed.optim.html#torch.distributed.optim.ZeroRedundancyOptimizer
```diff
- optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
+ optimizer = ZeroRedundancyOptimizer(
+     model.parameters(),
+     optimizer_class=torch.optim.AdamW,
+     lr=args.lr,
+     fused=True
+ )
```
Very easy to use and immediately reduces memory usage.
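For a rough sense of the saving, here is a back-of-the-envelope sketch. It assumes AdamW keeps two fp32 buffers per parameter (`exp_avg` and `exp_avg_sq`), a round 8e9 parameters for an 8B model, and an even split across ranks; the function names are made up for illustration, and the real partitioning is only approximately even.

```python
def adamw_state_bytes(num_params: int, bytes_per_value: int = 4) -> int:
    """Bytes for AdamW's two per-parameter buffers (assumed fp32 by default)."""
    return 2 * num_params * bytes_per_value


def per_rank_state_bytes(num_params: int, world_size: int,
                         bytes_per_value: int = 4) -> float:
    """Approximate optimizer state held by each rank once it is sharded."""
    return adamw_state_bytes(num_params, bytes_per_value) / world_size


if __name__ == "__main__":
    params = 8_000_000_000  # assumption: round figure for an 8B model
    for world_size in (1, 2, 4, 8):
        gb = per_rank_state_bytes(params, world_size) / 1e9
        print(f"world_size={world_size}: ~{gb:.0f} GB of AdamW state per rank")
```

Under these assumptions the optimizer state drops from about 64 GB on every rank to about 8 GB per rank on 8 GPUs. Parameters, gradients, and activations are not affected.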
This also needs some updates to saving checkpoints:
```diff
  if state["global_step"] % args.ckpt_freq == 0:
+     optimizer.consolidate_state_dict(to=0)
      if rank == 0:
          torch.save(optimizer.state_dict(), exp_dir / "optimizer.pt")
```
However, HUGE CAVEAT:
`consolidate_state_dict` transfers state between a single pair of GPUs at a time, so the shards arrive at the target rank one after another. It is VERY slow with Llama 8B (taking minutes per GPU).
Not sure if it should be recommended for this reason.
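To put numbers on the caveat, here is a rough estimate of how much data the target rank has to collect. It assumes fp32 AdamW state (two buffers per parameter), an even split across ranks, and shards sent strictly one at a time. `consolidation_estimate` and the 1 GB/s effective bandwidth are placeholders for illustration, not measured values.

```python
def consolidation_estimate(num_params: int, world_size: int,
                           bytes_per_value: int = 4,
                           bandwidth_bytes_per_s: float = 1e9):
    """Return (bytes the target rank receives, seconds per shard).

    Assumes two AdamW buffers per parameter and evenly sized shards.
    The bandwidth is a hypothetical effective rate; it has to cover
    serialization and host staging, not just the interconnect.
    """
    total_bytes = 2 * num_params * bytes_per_value
    shard_bytes = total_bytes / world_size
    # The target rank already holds its own shard, so it receives the rest.
    incoming_bytes = shard_bytes * (world_size - 1)
    seconds_per_shard = shard_bytes / bandwidth_bytes_per_s
    return incoming_bytes, seconds_per_shard


if __name__ == "__main__":
    incoming, per_shard = consolidation_estimate(8_000_000_000, world_size=8)
    print(f"target rank receives ~{incoming / 1e9:.0f} GB")
    print(f"~{per_shard:.0f} s per shard at the assumed bandwidth")
```

With these assumptions rank 0 pulls in roughly 56 GB on every checkpoint, and because the transfers are sequential the per-shard time is multiplied by the number of other ranks. If the observed time is minutes per GPU, the effective rate is well below the placeholder used here.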
#44 Adding ZeroRedundancyOptimizer to ch 2,3
2c7401e