Fix serialization of state dicts during TRAC/OrthoGrad checkpointing #329

Closed
wants to merge 5 commits into from

Vectorrent
Contributor

Problem (Why?)

Serialization fails during save/load of OrthoGrad:

Traceback (most recent call last):
  File "/home/crow/repos/praxis/run.py", line 1130, in <module>
    trainer.fit(
    ~~~~~~~~~~~^
        train_model,
        ^^^^^^^^^^^^
        datamodule,
        ^^^^^^^^^^^
        ckpt_path=ckpt_path,
        ^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/crow/repos/praxis/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/trainer.py", line 539, in fit
    call._call_and_handle_interrupt(
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
        self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/crow/repos/praxis/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/call.py", line 47, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/crow/repos/praxis/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/trainer.py", line 575, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
    ~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/trainer.py", line 982, in _run
    results = self._run_stage()
  File "/home/crow/repos/praxis/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/trainer.py", line 1026, in _run_stage
    self.fit_loop.run()
    ~~~~~~~~~~~~~~~~~^^
  File "/home/crow/repos/praxis/.venv/lib/python3.13/site-packages/lightning/pytorch/loops/fit_loop.py", line 216, in run
    self.advance()
    ~~~~~~~~~~~~^^
  File "/home/crow/repos/praxis/.venv/lib/python3.13/site-packages/lightning/pytorch/loops/fit_loop.py", line 455, in advance
    self.epoch_loop.run(self._data_fetcher)
    ~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.13/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 150, in run
    self.advance(data_fetcher)
    ~~~~~~~~~~~~^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.13/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 339, in advance
    call._call_callback_hooks(trainer, "on_train_batch_end", batch_output, batch, batch_idx)
    ~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/call.py", line 222, in _call_callback_hooks
    fn(trainer, trainer.lightning_module, *args, **kwargs)
    ~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/run.py", line 879, in on_train_batch_end
    self._save_topk_checkpoint(trainer, monitor_candidates)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.13/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 385, in _save_topk_checkpoint
    self._save_monitor_checkpoint(trainer, monitor_candidates)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.13/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 705, in _save_monitor_checkpoint
    self._update_best_and_save(current, trainer, monitor_candidates)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.13/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 757, in _update_best_and_save
    self._save_checkpoint(trainer, filepath)
    ~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.13/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 390, in _save_checkpoint
    trainer.save_checkpoint(filepath, self.save_weights_only)
    ~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/trainer.py", line 1366, in save_checkpoint
    checkpoint = self._checkpoint_connector.dump_checkpoint(weights_only)
  File "/home/crow/repos/praxis/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py", line 445, in dump_checkpoint
    optimizer_state = trainer.strategy.optimizer_state(optimizer)
  File "/home/crow/repos/praxis/.venv/lib/python3.13/site-packages/lightning/pytorch/strategies/strategy.py", line 190, in optimizer_state
    return optimizer.state_dict()
           ~~~~~~~~~~~~~~~~~~~~^^
  File "/home/crow/repos/praxis/.venv/lib/python3.13/site-packages/torch/_compile.py", line 32, in inner
    return disable_fn(*args, **kwargs)
  File "/home/crow/repos/praxis/.venv/lib/python3.13/site-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
    return fn(*args, **kwargs)
  File "/home/crow/repos/praxis/.venv/lib/python3.13/site-packages/torch/optim/optimizer.py", line 687, in state_dict
    for pre_hook in self._optimizer_state_dict_pre_hooks.values():
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'OrthoGrad' object has no attribute '_optimizer_state_dict_pre_hooks'. Did you mean: '_optimizer_step_pre_hooks'?

Solution (What/How?)

I added the missing hook attributes to the wrapper. They should probably be defined on BaseOptimizer directly, though.
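The failure mode and the hook-based fix can be sketched in plain Python. This is only an illustration: `ToyOptimizer`, `BrokenWrapper`, and `PatchedWrapper` are hypothetical stand-ins, not the real `torch.optim.Optimizer` or `OrthoGrad` classes, and the sketch assumes the wrapper skips the base constructor, which is where the hook registry from the traceback would normally be created.

```python
class ToyOptimizer:
    """Toy stand-in for a base optimizer (hypothetical, NOT torch.optim.Optimizer)."""

    def __init__(self, params):
        self.param_groups = [{"params": list(params)}]
        self.state = {}
        # The base constructor is what creates the hook registries.
        self._optimizer_step_pre_hooks = {}
        self._optimizer_state_dict_pre_hooks = {}

    def state_dict(self):
        # Same shape as the failing line in the traceback: run pre-hooks first.
        for pre_hook in self._optimizer_state_dict_pre_hooks.values():
            pre_hook(self)
        return {"state": self.state, "param_groups": self.param_groups}


class BrokenWrapper(ToyOptimizer):
    """Wraps another optimizer but never calls the base constructor (assumption)."""

    def __init__(self, optimizer):
        self.optimizer = optimizer
        self.param_groups = optimizer.param_groups
        self.state = optimizer.state
        self._optimizer_step_pre_hooks = {}  # only the step hooks are defined


class PatchedWrapper(BrokenWrapper):
    """Hook-based fix: define the missing registry by hand."""

    def __init__(self, optimizer):
        super().__init__(optimizer)
        self._optimizer_state_dict_pre_hooks = {}


inner = ToyOptimizer([1.0, 2.0])

try:
    BrokenWrapper(inner).state_dict()
    error = None
except AttributeError as exc:
    error = str(exc)  # names the missing '_optimizer_state_dict_pre_hooks'

patched_state = PatchedWrapper(inner).state_dict()
```

The patch works, but every wrapper has to repeat it for each registry the base class expects, which is why defining the attributes once in a shared base class looks more attractive.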


codecov bot commented Jan 22, 2025

Codecov Report

Attention: Patch coverage is 37.14286% with 22 lines in your changes missing coverage. Please review.

Project coverage is 99.74%. Comparing base (cdbd9bc) to head (1522dd8).

| Files with missing lines | Patch % | Lines |
|---|---|---|
| pytorch_optimizer/optimizer/trac.py | 29.62% | 19 Missing ⚠️ |
| pytorch_optimizer/optimizer/orthograd.py | 62.50% | 3 Missing ⚠️ |
Additional details and impacted files
@@             Coverage Diff             @@
##              main     #329      +/-   ##
===========================================
- Coverage   100.00%   99.74%   -0.26%     
===========================================
  Files          107      107              
  Lines         8439     8468      +29     
===========================================
+ Hits          8439     8446       +7     
- Misses           0       22      +22     

@Vectorrent changed the title from "Fix saving/loading of OrthoGrad state dicts during checkpointing" to "Fix serialization of state dicts during TRAC/OrthoGrad checkpointing" on Jan 23, 2025
@Vectorrent
Contributor Author

I ran into the exact same problem with TRAC and fixed it the same way. I tested Lookahead as well; it doesn't have this problem.

Owner

@kozistr kozistr left a comment

hi! thanks for your PR.

I think it'd be better to restore __getstate__ (which I deleted by mistake) and add state_dict() and load_state_dict() methods to save/load the optimizer state, instead of adding *_hooks members.

I'll work on it, maybe later next week.
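One possible reading of that suggestion is a wrapper that delegates serialization to the optimizer it wraps, so no hook registries are needed on the wrapper itself. The sketch below is an assumption-laden illustration in plain Python: `ToyOptimizer` and `DelegatingWrapper` are hypothetical names, and this is not necessarily the design that ended up in the library.

```python
class ToyOptimizer:
    """Toy inner optimizer (hypothetical stand-in, not a torch class)."""

    def __init__(self, params, lr=0.1):
        self.param_groups = [{"params": list(params), "lr": lr}]
        self.state = {}

    def state_dict(self):
        # Return copies so the checkpoint is decoupled from live state.
        return {
            "state": dict(self.state),
            "param_groups": [dict(group) for group in self.param_groups],
        }

    def load_state_dict(self, state_dict):
        self.state = dict(state_dict["state"])
        self.param_groups = [dict(group) for group in state_dict["param_groups"]]


class DelegatingWrapper:
    """Wrapper that forwards save/load to the wrapped optimizer."""

    def __init__(self, optimizer):
        self.optimizer = optimizer

    def __getstate__(self):
        # What pickling the wrapper would capture: just the wrapped optimizer.
        return {"optimizer": self.optimizer}

    def state_dict(self):
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        self.optimizer.load_state_dict(state_dict)


# Save from one wrapper, restore into a freshly constructed one.
source = DelegatingWrapper(ToyOptimizer([1.0, 2.0], lr=0.5))
source.optimizer.state[0] = {"step": 3}
checkpoint = source.state_dict()

target = DelegatingWrapper(ToyOptimizer([1.0, 2.0]))
target.load_state_dict(checkpoint)
restored = target.state_dict()
```

With this shape, a trainer that calls `optimizer.state_dict()` on the wrapper never touches the base class's hook machinery, so the missing attribute is no longer reached.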

@Vectorrent
Contributor Author

Based on your advice, I made an attempt to add those methods myself. I tested the saving/loading of both wrappers, and they seem to work. How does it look to you?

Owner

@kozistr kozistr left a comment

umm, I'm thinking of a slightly different approach from yours, and I'm also going to add more test cases to prevent the state_dict issues.

so, I'll make another PR to correct this. thanks for raising the issue and for your work!

@kozistr closed this on Jan 25, 2025
@kozistr mentioned this pull request on Jan 25, 2025