-
Notifications
You must be signed in to change notification settings - Fork 25
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix serialization of state dicts during TRAC/OrthoGrad checkpointing #329
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #329 +/- ##
===========================================
- Coverage 100.00% 99.74% -0.26%
===========================================
Files 107 107
Lines 8439 8468 +29
===========================================
+ Hits 8439 8446 +7
- Misses 0 22 +22 ☔ View full report in Codecov by Sentry. |
I ran into the exact same problem with TRAC, and fixed it in the same way. Tested Lookahead as well, it doesn't have this problem. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hi! thanks for your PR.
I think it'd be good to create __getstate__
, which I deleted by mistake, state_dict()
and load_state_dict()
methods to save/load the optimizer state instead of adding *_hooks
members.
I'll work on it maybe later next week.
Based on your advice, I made an attempt to add those methods myself. I tested the saving/loading of both wrappers, and they seem to work. How does it look to you? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
umm, I'm thinking of a kinda slightly different way with yours, and also gonna add more test cases to prevent the state_dict
issues.
so, I'll make another PR to correct this, and thanks for raising the issue and your work!
Problem (Why?)
Serialization fails during save/load of OrthoGrad:
Solution (What/How?)
I added the missing hooks. Though, they should probably be added to
BaseOptimizer
directly.