diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 1c7aa467bff01..6ab5a2223cd59 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -56,6 +56,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `ModelCheckpoint` not expanding the `dirpath` if it has the `~` (home) prefix ([#19058](https://github.com/Lightning-AI/lightning/pull/19058)) +- Fixed handling checkpoint dirpath suffix in NeptuneLogger ([#18863](https://github.com/Lightning-AI/lightning/pull/18863)) + + ## [2.1.2] - 2023-11-15 diff --git a/src/lightning/pytorch/loggers/neptune.py b/src/lightning/pytorch/loggers/neptune.py index d8f8b36251c7f..6e8d268dff95a 100644 --- a/src/lightning/pytorch/loggers/neptune.py +++ b/src/lightning/pytorch/loggers/neptune.py @@ -557,13 +557,14 @@ def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None: def _get_full_model_name(model_path: str, checkpoint_callback: Checkpoint) -> str: """Returns model name which is string `model_path` appended to `checkpoint_callback.dirpath`.""" if hasattr(checkpoint_callback, "dirpath"): - expected_model_path = f"{checkpoint_callback.dirpath}{os.path.sep}" + model_path = os.path.normpath(model_path) + expected_model_path = os.path.normpath(checkpoint_callback.dirpath) if not model_path.startswith(expected_model_path): raise ValueError(f"{model_path} was expected to start with {expected_model_path}.") # Remove extension from filepath - filepath, _ = os.path.splitext(model_path[len(expected_model_path) :]) - return filepath - return model_path + filepath, _ = os.path.splitext(model_path[len(expected_model_path) + 1 :]) + return filepath.replace(os.sep, "/") + return model_path.replace(os.sep, "/") @classmethod def _get_full_model_names_from_exp_structure(cls, exp_structure: Dict[str, Any], namespace: str) -> Set[str]: diff --git a/tests/tests_pytorch/loggers/test_neptune.py b/tests/tests_pytorch/loggers/test_neptune.py index 90f162f8c2564..9065a48478da7 100644 --- a/tests/tests_pytorch/loggers/test_neptune.py +++ b/tests/tests_pytorch/loggers/test_neptune.py @@ -284,10 +284,12 @@ def test_get_full_model_name(): os.path.join("foo", "bar", "key/in/parts.ext"), SimpleCheckpoint(dirpath=os.path.join("foo", "bar")), ), + ("key", os.path.join("../foo", "bar", "key.ext"), SimpleCheckpoint(dirpath=os.path.join("../foo", "bar"))), + ("key", os.path.join("foo", "key.ext"), SimpleCheckpoint(dirpath=os.path.join("./foo", "bar/../"))), ] - for expected_model_name, *key_and_path in test_input_data: - assert NeptuneLogger._get_full_model_name(*key_and_path) == expected_model_name + for expected_model_name, model_path, checkpoint in test_input_data: + assert NeptuneLogger._get_full_model_name(model_path, checkpoint) == expected_model_name def test_get_full_model_names_from_exp_structure():