Skip to content

Commit

Permalink
Add SaveDecisionPolicy to better encapsulate various options around…
Browse files Browse the repository at this point in the history
… choosing whether or not to perform a save at a particular step. Added a policy that allows for checkpointing as often as possible, as long as a save is not already in progress (continuous checkpointing).

PiperOrigin-RevId: 718984622
  • Loading branch information
cpgaffney1 authored and Orbax Authors committed Jan 23, 2025
1 parent 302c2c1 commit d7c1cfd
Show file tree
Hide file tree
Showing 7 changed files with 281 additions and 22 deletions.
4 changes: 4 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,16 @@ properties not included in any tree mapping operations.
- Added github actions CI testing using Python versions 3.10-3.13.
- Standardize naming of the "custom metadata" field (user-supplied metadata) as
`custom_metadata`.
- Add `SaveDecisionPolicy` to better encapsulate various options around choosing
whether or not to perform a save at a particular step.

### Added
- The ability to specify a custom `snapshot_dir` in `checkpoints_iterator`.
- `CommitFutureAwaitDirectorySignals`, `CommitFuture` and
`HandlerAwaitableSignal` for signalling between Checkpointing layers to enable
async directory creation.
- A policy that allows for checkpointing as often as possible, as long as a
save is not already in progress (continuous checkpointing).
- User-provided custom PyTree metadata.

### Fixed
Expand Down
1 change: 1 addition & 0 deletions checkpoint/orbax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from orbax.checkpoint import args
from orbax.checkpoint import checkpoint_utils
from orbax.checkpoint import checkpointers
from orbax.checkpoint import checkpoint_managers
from orbax.checkpoint import handlers
from orbax.checkpoint import logging
from orbax.checkpoint import metadata
Expand Down
12 changes: 12 additions & 0 deletions checkpoint/orbax/checkpoint/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ py_library(
":args",
":arrays",
":checkpoint_manager",
":checkpoint_managers",
":checkpoint_utils",
":checkpointers",
":future",
Expand Down Expand Up @@ -130,6 +131,7 @@ py_library(
"//checkpoint/orbax/checkpoint/_src/path:step",
"//checkpoint/orbax/checkpoint/_src/path:utils",
"//third_party/py/jax/experimental/array_serialization:serialization",
"//orbax/checkpoint/_src/checkpoint_managers:save_decision_policy",
"//orbax/checkpoint/_src/checkpointers:abstract_checkpointer",
"//orbax/checkpoint/_src/checkpointers:async_checkpointer",
"//orbax/checkpoint/_src/handlers:handler_registration",
Expand Down Expand Up @@ -350,3 +352,13 @@ py_library(
"//checkpoint/orbax/checkpoint/_src/arrays:types",
],
)

py_library(
name = "checkpoint_managers",
srcs = ["checkpoint_managers.py"],
deps = [
":abstract_checkpoint_manager",
":checkpoint_manager",
"//orbax/checkpoint/_src/checkpoint_managers:save_decision_policy",
],
)
1 change: 1 addition & 0 deletions checkpoint/orbax/checkpoint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from orbax.checkpoint import args
from orbax.checkpoint import checkpoint_utils
from orbax.checkpoint import checkpointers
from orbax.checkpoint import checkpoint_managers
from orbax.checkpoint import handlers
from orbax.checkpoint import logging
from orbax.checkpoint import metadata
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
# Copyright 2024 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines policies for when a checkpoint is saved."""

import dataclasses
import typing
from typing import Container, Protocol, Sequence


@dataclasses.dataclass(kw_only=True)
class StepInfo:
"""Relevant information about a checkpoint step."""
step: int


@dataclasses.dataclass(kw_only=True)
class DecisionContext:
"""Additional properties for making a save decision."""

is_saving_in_progress: bool
reached_preemption: bool


@typing.runtime_checkable
class SaveDecisionPolicy(Protocol):
"""A policy that defines when to save a checkpoint.
Implementations should return True from `should_save` when saving a checkpoint
is desired at the given step.
"""

def should_save(
self,
step: StepInfo,
previous_steps: Sequence[StepInfo],
*,
context: DecisionContext
) -> bool:
...


@dataclasses.dataclass
class FixedIntervalPolicy(SaveDecisionPolicy):
"""Checkpoint at a fixed interval."""

interval: int

def should_save(
self,
step: StepInfo,
previous_steps: Sequence[StepInfo],
*,
context: DecisionContext
) -> bool:
del previous_steps
del context
return step.step % self.interval == 0


@dataclasses.dataclass
class SpecificStepsPolicy(SaveDecisionPolicy):
"""Checkpoint at specific steps."""

steps: Container[int]

def should_save(
self,
step: StepInfo,
previous_steps: Sequence[StepInfo],
*,
context: DecisionContext
) -> bool:
del previous_steps
del context
return step.step in self.steps


class ContinuousCheckpointingPolicy(SaveDecisionPolicy):
"""Checkpoint as often as possible, as long as a save is not ongoing."""

def should_save(
self,
step: StepInfo,
previous_steps: Sequence[StepInfo],
*,
context: DecisionContext
) -> bool:
del step
del previous_steps
return not context.is_saving_in_progress


class PreemptionCheckpointingPolicy(SaveDecisionPolicy):
"""Save a checkpoint when a preemption is detected."""

def should_save(
self,
step: StepInfo,
previous_steps: Sequence[StepInfo],
*,
context: DecisionContext
) -> bool:
del step
del previous_steps
return context.reached_preemption


class InitialSavePolicy(SaveDecisionPolicy):
"""Checkpoint as soon as possible if no checkpoints already exist."""

def should_save(
self,
step: StepInfo,
previous_steps: Sequence[StepInfo],
*,
context: DecisionContext
) -> bool:
del step
del context
return not previous_steps


@dataclasses.dataclass
class AnySavePolicy(SaveDecisionPolicy):
"""Evaluates all policies and saves if any of them returns True.
Each policy is evaluated in order, and if all are False, the final result is
False. If at least one is True, the final result is True.
"""

policies: Sequence[SaveDecisionPolicy]

def should_save(
self,
step: StepInfo,
previous_steps: Sequence[StepInfo],
*,
context: DecisionContext
) -> bool:
return any(
policy.should_save(step, previous_steps=previous_steps, context=context)
for policy in self.policies
)
97 changes: 75 additions & 22 deletions checkpoint/orbax/checkpoint/checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from orbax.checkpoint import checkpoint_args
from orbax.checkpoint import options as options_lib
from orbax.checkpoint import utils
from orbax.checkpoint._src.checkpoint_managers import save_decision_policy as save_decision_policy_lib
from orbax.checkpoint._src.checkpointers import abstract_checkpointer
from orbax.checkpoint._src.checkpointers import async_checkpointer
from orbax.checkpoint._src.checkpointers import checkpointer as checkpointer_lib
Expand Down Expand Up @@ -157,6 +158,52 @@ def join(self, *args, **kwargs):
raise self.exception


@dataclasses.dataclass
class _ShouldSaveFnPolicy(save_decision_policy_lib.SaveDecisionPolicy):
"""A policy that uses a user-provided should_save_fn."""

should_save_fn: Callable[[int, Optional[int]], bool]

def should_save(
self,
step: save_decision_policy_lib.StepInfo,
previous_steps: Sequence[save_decision_policy_lib.StepInfo],
*,
context: save_decision_policy_lib.DecisionContext,
) -> bool:
return self.should_save_fn(
step.step, previous_steps[-1].step if previous_steps else None
)


def _get_default_save_decision_policy(
options: CheckpointManagerOptions,
) -> save_decision_policy_lib.SaveDecisionPolicy:
"""Creates a default policy from CheckpointManagerOptions."""
save_interval_policies = []
if options.should_save_fn is not None:
save_interval_policies.append(_ShouldSaveFnPolicy(options.should_save_fn))
save_interval_policies.append(
save_decision_policy_lib.PreemptionCheckpointingPolicy()
)
else:
if options.save_interval_steps is not None:
save_interval_policies.append(
save_decision_policy_lib.FixedIntervalPolicy(
options.save_interval_steps
)
)
if options.save_on_steps is not None:
save_interval_policies.append(
save_decision_policy_lib.SpecificStepsPolicy(options.save_on_steps)
)
save_interval_policies.append(
save_decision_policy_lib.PreemptionCheckpointingPolicy()
)
save_interval_policies.append(save_decision_policy_lib.InitialSavePolicy())
return save_decision_policy_lib.AnySavePolicy(save_interval_policies)


# TODO(b/268051457) Clean up when no longer depended upon by internal users.
def is_async_checkpointer(checkpointer: AbstractCheckpointer):
return isinstance(
Expand Down Expand Up @@ -264,6 +311,12 @@ class CheckpointManagerOptions:
temporary_path_class:
Optional. The concrete `atomicity_types.TemporaryPath` class to be used by
the underlying `Checkpointer`.
save_decision_policy: An object used to determine when a checkpoint should be
saved. If provided, overrides any other options dealing with this subject,
including `save_interval_steps`, `save_on_steps`, and `should_save_fn`, and
is the sole means of determining when a checkpoint should be saved. If not
provided, these other options are used instead. Prefer to use this option
over others.
"""

save_interval_steps: int = 1
Expand Down Expand Up @@ -293,6 +346,9 @@ class CheckpointManagerOptions:
file_options: FileOptions = dataclasses.field(default_factory=FileOptions)
save_root_metadata: bool = True
temporary_path_class: Optional[Type[atomicity_types.TemporaryPath]] = None
save_decision_policy: Optional[
save_decision_policy_lib.SaveDecisionPolicy
] = None

def __post_init__(self):
if self.best_mode not in ('min', 'max'):
Expand Down Expand Up @@ -576,6 +632,10 @@ def __init__(

self._options = options or CheckpointManagerOptions()
self._multiprocessing_options = self._options.multiprocessing_options
self._save_decision_policy = (
self._options.save_decision_policy
or _get_default_save_decision_policy(self._options)
)

if self._options.best_mode not in ['min', 'max']:
raise ValueError('`best_mode` must be one of: "min", "max"')
Expand Down Expand Up @@ -992,32 +1052,25 @@ def should_save(self, step: int) -> bool:
if self._options.read_only:
logging.warning('%s is read only, save will be skipped', self.directory)
return False
if self.reached_preemption(step):
return True
last_checkpoint_step = self.latest_step()
# Ensure current step is between the last step and next step (accounting for
# save interval). The `last_checkpoint_step` may not be initialized, in
# which case we should save. Otherwise, step must fall on the specified
# save interval. This condition accounts for the possibility of saving
# on preemption, in which case we want to maintain the same save period as
# if preemption had not happened.
# save interval).
if last_checkpoint_step is not None and last_checkpoint_step >= step:
return False
# If present then prefer should_save_fn over other 'save_*' options.
if self._options.should_save_fn is not None:
logging.log_every_n(
logging.INFO,
'CheckpointManagerOptions.should_save_fn is available, following save'
' options will be ignored: save_interval_steps=%s and'
' save_on_steps=%s',
500,
self._options.save_interval_steps,
self._options.save_on_steps,
)
return self._options.should_save_fn(step, last_checkpoint_step)
return last_checkpoint_step is None or (
step % self._options.save_interval_steps == 0
or step in self._options.save_on_steps

is_saving_in_progress = self.is_saving_in_progress()
reached_preemption = self.reached_preemption(step)
previous_step_infos = [
save_decision_policy_lib.StepInfo(step=ckpt.step)
for ckpt in self._checkpoints
]
current_step_info = save_decision_policy_lib.StepInfo(step=step)
context = save_decision_policy_lib.DecisionContext(
is_saving_in_progress=is_saving_in_progress,
reached_preemption=reached_preemption,
)
return self._save_decision_policy.should_save(
current_step_info, previous_steps=previous_step_infos, context=context
)

def _get_save_directory(
Expand Down
33 changes: 33 additions & 0 deletions checkpoint/orbax/checkpoint/checkpoint_managers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright 2024 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Public symbols for checkpoint_managers module."""

# pylint: disable=g-importing-member, g-bad-import-order, unused-import, g-multiple-import

from orbax.checkpoint._src.checkpoint_managers import save_decision_policy
from orbax.checkpoint._src.checkpoint_managers.save_decision_policy import (
StepInfo,
SaveDecisionPolicy,
FixedIntervalPolicy,
InitialSavePolicy,
SpecificStepsPolicy,
ContinuousCheckpointingPolicy,
AnySavePolicy,
)

from orbax.checkpoint.checkpoint_manager import CheckpointManagerOptions
from orbax.checkpoint.checkpoint_manager import CheckpointManager

from orbax.checkpoint.abstract_checkpoint_manager import AbstractCheckpointManager

0 comments on commit d7c1cfd

Please sign in to comment.