Skip to content

Commit

Permalink
Handle checkpoint dirpath suffix in NeptuneLogger (#18863)
Browse files Browse the repository at this point in the history
Co-authored-by: Siddhant Sadangi <[email protected]>
Co-authored-by: Sabine <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: awaelchli <[email protected]>
  • Loading branch information
5 people authored Nov 25, 2023
1 parent 1fcb4ae commit af852ff
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 6 deletions.
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `ModelCheckpoint` not expanding the `dirpath` if it has the `~` (home) prefix ([#19058](https://github.com/Lightning-AI/lightning/pull/19058))


- Fixed handling checkpoint dirpath suffix in NeptuneLogger ([#18863](https://github.com/Lightning-AI/lightning/pull/18863))



## [2.1.2] - 2023-11-15

Expand Down
9 changes: 5 additions & 4 deletions src/lightning/pytorch/loggers/neptune.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,13 +557,14 @@ def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None:
def _get_full_model_name(model_path: str, checkpoint_callback: Checkpoint) -> str:
"""Returns model name which is string `model_path` appended to `checkpoint_callback.dirpath`."""
if hasattr(checkpoint_callback, "dirpath"):
expected_model_path = f"{checkpoint_callback.dirpath}{os.path.sep}"
model_path = os.path.normpath(model_path)
expected_model_path = os.path.normpath(checkpoint_callback.dirpath)
if not model_path.startswith(expected_model_path):
raise ValueError(f"{model_path} was expected to start with {expected_model_path}.")
# Remove extension from filepath
filepath, _ = os.path.splitext(model_path[len(expected_model_path) :])
return filepath
return model_path
filepath, _ = os.path.splitext(model_path[len(expected_model_path) + 1 :])
return filepath.replace(os.sep, "/")
return model_path.replace(os.sep, "/")

@classmethod
def _get_full_model_names_from_exp_structure(cls, exp_structure: Dict[str, Any], namespace: str) -> Set[str]:
Expand Down
6 changes: 4 additions & 2 deletions tests/tests_pytorch/loggers/test_neptune.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,10 +284,12 @@ def test_get_full_model_name():
os.path.join("foo", "bar", "key/in/parts.ext"),
SimpleCheckpoint(dirpath=os.path.join("foo", "bar")),
),
("key", os.path.join("../foo", "bar", "key.ext"), SimpleCheckpoint(dirpath=os.path.join("../foo", "bar"))),
("key", os.path.join("foo", "key.ext"), SimpleCheckpoint(dirpath=os.path.join("./foo", "bar/../"))),
]

for expected_model_name, *key_and_path in test_input_data:
assert NeptuneLogger._get_full_model_name(*key_and_path) == expected_model_name
for expected_model_name, model_path, checkpoint in test_input_data:
assert NeptuneLogger._get_full_model_name(model_path, checkpoint) == expected_model_name


def test_get_full_model_names_from_exp_structure():
Expand Down

0 comments on commit af852ff

Please sign in to comment.