Skip to content

Commit

Permalink
Set the state before saving "last" or "none" checkpoints (#11481)
Browse files Browse the repository at this point in the history
Co-authored-by: Carlos Mocholi <[email protected]>
  • Loading branch information
2 people authored and lexierule committed Feb 9, 2022
1 parent bcd7c87 commit 9ebdc52
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 10 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Pin sphinx-autodoc-typehints with <v1.15 ([#11400](https://github.com/PyTorchLightning/pytorch-lightning/pull/11400))
- Skip testing with PyTorch 1.7 and Python 3.9 on Ubuntu ([#11217](https://github.com/PyTorchLightning/pytorch-lightning/pull/11217))
- Fixed type promotion when tensors of higher category than float are logged ([#11401](https://github.com/PyTorchLightning/pytorch-lightning/pull/11401))
- Fixed a bug where the path for "last" checkpoints was not saved correctly, which caused newer runs to not remove the previous "last" checkpoint ([#11481](https://github.com/PyTorchLightning/pytorch-lightning/pull/11481))
- Fixed a bug where the path for best checkpoints was not saved correctly when no metric was monitored, which caused newer runs to not use the best checkpoint ([#11481](https://github.com/PyTorchLightning/pytorch-lightning/pull/11481))
- Fixed the format of the configuration saved automatically by the CLI's `SaveConfigCallback` ([#11532](https://github.com/PyTorchLightning/pytorch-lightning/pull/11532))

### Changed
Expand Down
18 changes: 8 additions & 10 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,12 +667,11 @@ def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[
return

filepath = self.format_checkpoint_name(monitor_candidates, self.CHECKPOINT_NAME_LAST)
# set the last model path before saving because it will be part of the state.
previous, self.last_model_path = self.last_model_path, filepath
trainer.save_checkpoint(filepath, self.save_weights_only)

if self.last_model_path and self.last_model_path != filepath:
trainer.training_type_plugin.remove_checkpoint(self.last_model_path)

self.last_model_path = filepath
if previous and previous != filepath:
trainer.training_type_plugin.remove_checkpoint(previous)

def _save_top_k_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
if self.monitor is None or self.save_top_k == 0:
Expand All @@ -692,12 +691,11 @@ def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidate
return

filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer)
# set the best model path before saving because it will be part of the state.
previous, self.best_model_path = self.best_model_path, filepath
trainer.save_checkpoint(filepath, self.save_weights_only)

if self.save_top_k == 1 and self.best_model_path and self.best_model_path != filepath:
trainer.training_type_plugin.remove_checkpoint(self.best_model_path)

self.best_model_path = filepath
if self.save_top_k == 1 and previous and previous != filepath:
trainer.training_type_plugin.remove_checkpoint(previous)

def _is_valid_monitor_key(self, metrics: Dict[str, _METRIC]) -> bool:
return self.monitor in metrics or len(metrics) == 0
Expand Down
27 changes: 27 additions & 0 deletions tests/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -1240,3 +1240,30 @@ def test_model_checkpoint_saveload_ckpt(tmpdir):
assert getattr(cb_restore, key) == val
else:
assert getattr(cb_restore, key) != val


def test_save_last_saves_correct_last_model_path(tmpdir):
    """Check that a saved "last" checkpoint records its own path as ``last_model_path``.

    The checkpoint is written through ``_save_last_checkpoint`` and then loaded
    back from disk, so the assertion is made on the persisted callback state.
    """
    checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, save_last=True)
    checkpoint_callback.CHECKPOINT_NAME_LAST = "{foo}-last"
    trainer = Trainer(callbacks=checkpoint_callback)
    trainer.training_type_plugin.connect(BoringModel())

    checkpoint_callback._save_last_checkpoint(trainer, {"foo": 1})

    # Exactly one file must exist, named after the interpolated "last" template.
    expected_name = "foo=1-last.ckpt"
    assert os.listdir(tmpdir) == [expected_name]

    # The path stored inside the checkpoint must point at the checkpoint itself.
    expected_path = str(tmpdir / expected_name)
    callback_state = torch.load(expected_path)["callbacks"][checkpoint_callback.state_key]
    assert callback_state["last_model_path"] == expected_path


def test_none_monitor_saves_correct_best_model_path(tmpdir):
    """Check that an unmonitored checkpoint records its own path as ``best_model_path``.

    The checkpoint is written through ``_save_none_monitor_checkpoint`` and then
    loaded back from disk, so the assertion is made on the persisted callback state.
    """
    checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, monitor=None)
    trainer = Trainer(callbacks=checkpoint_callback)
    trainer.training_type_plugin.connect(BoringModel())

    checkpoint_callback._save_none_monitor_checkpoint(trainer, {})

    # Exactly one file must exist, carrying the expected epoch/step file name.
    expected_name = "epoch=0-step=0.ckpt"
    assert os.listdir(tmpdir) == [expected_name]

    # The path stored inside the checkpoint must point at the checkpoint itself.
    expected_path = str(tmpdir / expected_name)
    callback_state = torch.load(expected_path)["callbacks"][checkpoint_callback.state_key]
    assert callback_state["best_model_path"] == expected_path
1 change: 1 addition & 0 deletions tests/deprecated_api/test_remove_1-7.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,7 @@ def test_v1_7_0_resume_from_checkpoint_trainer_constructor(tmpdir):
assert trainer.checkpoint_connector.resume_checkpoint_path is None
assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path == ckpt_path
trainer.fit(model)
ckpt_path = trainer.checkpoint_callback.best_model_path # last `fit` replaced the `best_model_path`
assert callback.state == 111
assert trainer.checkpoint_connector.resume_checkpoint_path is None
assert trainer.checkpoint_connector.resume_from_checkpoint_fit_path is None
Expand Down

0 comments on commit 9ebdc52

Please sign in to comment.