-
Notifications
You must be signed in to change notification settings - Fork 25
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature] Store SOAP condition matrices as the dtype of their parameters #335
Conversation
should now comply with the formatting requirements :) |
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #335 +/- ##
==========================================
Coverage 100.00% 100.00%
==========================================
Files 108 110 +2
Lines 8509 8731 +222
==========================================
+ Hits 8509 8731 +222 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for your fix! I left some minor reviews, and others look great to me :)
--- updated
I just committed them to merge this PR
After #333 SOAP functions correctly, but it has significant excess VRAM usage when training models with reduced precision weights (e.g.
bfloat16
).This PR initializes and updates the condition matrices based on the precision of the parameters themselves rather than defaulting to
float32
for everything.Note that the only exception to this is the QR factorization to get the orthogonal Q is done in
float32
, regardless of the underlying matrix precision, asfloat32
has CUDA kernel support as of PyTorch 2.5.1