
Add ZeroRedundancyOptimizer to chapters 2 & 3 #44

Open
corey-lambda opened this issue Oct 21, 2024 · 1 comment · May be fixed by #45

Comments

@corey-lambda
Contributor

Docs: https://pytorch.org/docs/2.4/distributed.optim.html#torch.distributed.optim.ZeroRedundancyOptimizer

+from torch.distributed.optim import ZeroRedundancyOptimizer

-    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
+    optimizer = ZeroRedundancyOptimizer(
+        model.parameters(),
+        optimizer_class=torch.optim.AdamW,
+        lr=args.lr,
+        fused=True
+    )

Very easy to use and immediately reduces memory usage.
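For context, here is a minimal sketch of how the wrapped optimizer could slot into a single-node DDP training script. The model class, dataloader, and hyperparameters are placeholders for illustration, not code from the chapters:

import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumes torchrun launched one process per GPU and set the usual env vars.
dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)  # rank == local rank on a single node

model = MyModel().to(rank)                  # hypothetical model class
model = DDP(model, device_ids=[rank])

# Each rank keeps optimizer state only for its own shard of the parameters
# (ZeRO stage 1), which is where the memory savings come from.
optimizer = ZeroRedundancyOptimizer(
    model.parameters(),
    optimizer_class=torch.optim.AdamW,
    lr=3e-4,
    fused=True,
)

for batch in dataloader:                    # hypothetical dataloader
    optimizer.zero_grad()
    loss = model(batch["input"], batch["target"])  # hypothetical forward returning the loss
    loss.backward()
    optimizer.step()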

@corey-lambda
Contributor Author

corey-lambda commented Oct 21, 2024

This also needs some updates to saving checkpoints:

 if state["global_step"] % args.ckpt_freq == 0:
+    optimizer.consolidate_state_dict(to=0)
     if rank == 0:
         torch.save(optimizer.state_dict(), exp_dir / "optimizer.pt")

However, HUGE CAVEAT:

consolidate_state_dict transfers the shards between a single pair of GPUs at a time, so it is VERY slow with Llama 8B (it takes minutes per GPU).

Not sure if it should be recommended for this reason.
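For completeness, a sketch of the matching load path, based on the documented ZeroRedundancyOptimizer.load_state_dict behavior (exp_dir is the same assumed checkpoint directory as above): every rank loads the full consolidated state dict, and load_state_dict keeps only the shard belonging to that rank.

# Rank 0 saved the consolidated state dict; every rank reads it back and
# ZeroRedundancyOptimizer.load_state_dict() extracts that rank's shard.
full_state = torch.load(exp_dir / "optimizer.pt", map_location="cpu")
optimizer.load_state_dict(full_state)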
