Skip to content

Commit

Permalink
Fix _determine_single_item_mode_from_directory by incorporating ste…
Browse files Browse the repository at this point in the history
…p name format.

PiperOrigin-RevId: 681933790
  • Loading branch information
cpgaffney1 authored and Orbax Authors committed Oct 3, 2024
1 parent 367ad4d commit 8bf99e3
Showing 1 changed file with 6 additions and 11 deletions.
17 changes: 6 additions & 11 deletions checkpoint/orbax/checkpoint/checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,11 +422,8 @@ def _determine_single_item_mode_from_args(
return True


def _determine_single_item_mode_from_directory(
directory: epath.Path,
step: int,
) -> bool:
return (directory / str(step) / DEFAULT_ITEM_NAME).exists()
def _determine_single_item_mode_from_directory(step_path: epath.Path) -> bool:
return (step_path / DEFAULT_ITEM_NAME).exists()


class CheckpointManager(AbstractCheckpointManager, epy.ContextManager):
Expand Down Expand Up @@ -1337,7 +1334,7 @@ def restore(

if self._single_item is None:
self._single_item = _determine_single_item_mode_from_directory(
directory, step
self._get_read_step_directory(step, directory)
)
self._validate_args(items, args)

Expand Down Expand Up @@ -1399,14 +1396,12 @@ def item_metadata(self, step: int) -> Union[Any, args_lib.Composite]:
Composite of metadata for each item.
"""
assert isinstance(self._checkpointer.handler, CompositeCheckpointHandler)
read_step_directory = self._get_read_step_directory(step, self.directory)

result = self._checkpointer.metadata(
self._get_read_step_directory(step, self.directory)
)
result = self._checkpointer.metadata(read_step_directory)
if self._single_item is None:
self._single_item = _determine_single_item_mode_from_directory(
self.directory,
step,
read_step_directory
)
return self._maybe_get_default_item(result)

Expand Down

0 comments on commit 8bf99e3

Please sign in to comment.