Skip to content

Commit

Permalink
Add local type handler registries and refactor PyTreeCheckpointHandler
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 622318163
  • Loading branch information
Orbax Authors committed Apr 6, 2024
1 parent b9593bd commit d5cf598
Show file tree
Hide file tree
Showing 5 changed files with 232 additions and 97 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added
- Local type handler registries.

### Changed
- The PyPI `orbax` package is deprecated in favor of domain-specific namespace
packages, namely `orbax-checkpoint` and `orbax-export`. Imports are unchanged,
Expand Down Expand Up @@ -126,4 +129,3 @@ auto-publish functionality.

- Fix mistaken usages of placeholder "AGGREGATED" where "NOT-AGGREGATED" would
be more appropriate. Ensure backwards compatibility is maintained.

39 changes: 27 additions & 12 deletions checkpoint/orbax/checkpoint/pytree_checkpoint_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
SaveArgs = type_handlers.SaveArgs
ParamInfo = type_handlers.ParamInfo
TypeHandler = type_handlers.TypeHandler
TypeHandlerRegistry = type_handlers.TypeHandlerRegistry
AggregateHandler = aggregate_handlers.AggregateHandler
MsgpackHandler = aggregate_handlers.MsgpackHandler
LegacyTransformFn = Callable[[PyTree, PyTree, PyTree], Tuple[PyTree, PyTree]]
Expand Down Expand Up @@ -132,14 +133,18 @@ def _keypath_from_metadata(keypath_serialized: Tuple[Dict[str, Any]]) -> Any: #
return tuple(keypath)


def _get_value_metadata(value: Any, save_arg: SaveArgs) -> Dict[str, Any]:
def _get_value_metadata(
value: Any,
save_arg: SaveArgs,
registry: TypeHandlerRegistry,
) -> Dict[str, Any]:
"""Gets JSON metadata for a given value."""
if utils.is_supported_empty_aggregation_type(value):
typestr = type_handlers.get_empty_value_typestr(value)
skip_deserialize = True
else:
try:
handler = type_handlers.get_type_handler(type(value))
handler = registry.get(type(value))
typestr = handler.typestr()
skip_deserialize = save_arg.aggregate
except ValueError:
Expand Down Expand Up @@ -497,7 +502,10 @@ def __post_init__(self):


def _batched_serialization_requests(
tree: PyTree, param_infos: PyTree, args: PyTree
tree: PyTree,
param_infos: PyTree,
args: PyTree,
registry: TypeHandlerRegistry,
) -> List[_BatchRequest]:
"""Gets a list of batched serialization or deserialization requests."""
grouped = {}
Expand All @@ -519,9 +527,9 @@ def _group_value(
if arg.restore_type is not None:
# Give user the chance to override restore_type if they want.
restore_type = arg.restore_type
handler = type_handlers.get_type_handler(restore_type)
handler = registry.get(restore_type)
else:
handler = type_handlers.get_type_handler(type(value))
handler = registry.get(type(value))
if handler not in grouped:
grouped[handler] = _BatchRequest(handler, [], [], [], [])
request = grouped[handler]
Expand Down Expand Up @@ -636,6 +644,7 @@ def __init__(
write_tree_metadata: bool = True,
use_zarr3: bool = False,
primary_host: Optional[int] = 0,
type_handler_registry: TypeHandlerRegistry = type_handlers.GLOBAL_TYPE_HANDLER_REGISTRY,
):
"""Creates PyTreeCheckpointHandler.
Expand All @@ -660,6 +669,8 @@ def __init__(
primary_host: the host id of the primary host. Defaults to 0. If it's set
to None, then all hosts will be considered as primary. It's useful in
the case that all hosts are only working with local storage.
type_handler_registry: a type_handlers.TypeHandlerRegistry. If not
specified, the global type handler registry will be used.
"""
self._aggregate_handler = MsgpackHandler(primary_host=primary_host)
if aggregate_filename is None:
Expand All @@ -672,6 +683,7 @@ def __init__(
self._write_tree_metadata = write_tree_metadata
self._use_zarr3 = use_zarr3
self._primary_host = primary_host
self._type_handler_registry = type_handler_registry


if self._use_ocdbt:
Expand Down Expand Up @@ -762,8 +774,9 @@ async def async_save(
"""Saves a PyTree to a given directory.
This operation is compatible with a multi-host, multi-device setting. Tree
leaf values must be supported by type_handlers. Standard supported types
include Python scalars, `np.ndarray`, `jax.Array`, and strings.
leaf values must be supported by the type_handler_registry given in the
constructor. Standard supported types include Python scalars, `np.ndarray`,
`jax.Array`, and strings.
After saving, all files will be located in "directory/". The exact files
that are saved depend on the specific combination of options, including
Expand Down Expand Up @@ -805,6 +818,7 @@ async def async_save(
Returns:
A Future that will commit the data to `directory` when awaited. Copying
the data from its source will be awaited in this function.
"""
if args is None:
args = PyTreeSaveArgs(
Expand Down Expand Up @@ -834,7 +848,7 @@ def _maybe_set_default_save_args(value, args_):
return SaveArgs(aggregate=False)
# Empty values will still raise TypeHandler registry error if _METADATA is
# disabled. We will prompt users to enable _METADATA to avoid this error.
aggregate = not type_handlers.has_type_handler(type(value))
aggregate = not self._type_handler_registry.has(type(value))
return SaveArgs(aggregate=aggregate)

save_args = jax.tree_util.tree_map(
Expand Down Expand Up @@ -871,7 +885,7 @@ def _maybe_set_default_save_args(value, args_):
else:
serialize_ops = []
batch_requests = _batched_serialization_requests(
item, param_infos, save_args
item, param_infos, save_args, self._type_handler_registry,
)
for request in batch_requests:
serialize_ops += [
Expand Down Expand Up @@ -949,7 +963,7 @@ def _process_aggregated_value(info, meta, args):
)

batch_requests = _batched_serialization_requests(
structure, param_infos, restore_args
structure, param_infos, restore_args, self._type_handler_registry,
)
deserialized_batches = []
deserialized_batches_ops = []
Expand Down Expand Up @@ -1227,7 +1241,8 @@ def _write_metadata_file(
tuple_keypath = str(tuple([str(utils.get_key_name(k)) for k in keypath]))
flat_metadata_with_keys[tuple_keypath] = {
_KEY_METADATA_KEY: _get_keypath_metadata(keypath),
_VALUE_METADATA_KEY: _get_value_metadata(value, save_arg),
_VALUE_METADATA_KEY: _get_value_metadata(
value, save_arg, self._type_handler_registry),
}

metadata = {
Expand Down Expand Up @@ -1407,7 +1422,7 @@ def _get_user_metadata(self, directory: epath.Path) -> PyTree:

metadata_ops = []
for restore_type, param_infos in batched_param_infos.items():
handler = type_handlers.get_type_handler(restore_type)
handler = self._type_handler_registry.get(restore_type)
metadata_ops.append(handler.metadata(param_infos))

async def _get_metadata():
Expand Down
20 changes: 19 additions & 1 deletion checkpoint/orbax/checkpoint/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import string
import time
import typing
from typing import List, Optional, Tuple
from typing import Any, List, Optional, Tuple

from absl import logging
from etils import epath
Expand Down Expand Up @@ -366,3 +366,21 @@ def register_type_handler(ty, handler, func):
type_handlers.register_type_handler(
ty, original_handler, func=func, override=True
)


@contextlib.contextmanager
def ocdbt_checkpoint_context(use_ocdbt: bool, ts_context: Any):
  """Use OCDBT driver within context.

  Takes a snapshot of the global type handler registry's entries, optionally
  re-registers the standard handlers with the given OCDBT options, and puts
  the snapshot back when the context exits (normally or via an exception).

  Args:
    use_ocdbt: if True, the standard handlers are re-registered through
      `type_handlers.register_standard_handlers_with_options` for the duration
      of the context. If False, the registry is left untouched on entry.
    ts_context: forwarded unchanged as the `ts_context` argument of
      `register_standard_handlers_with_options`.

  Yields:
    None. The body of the `with` statement runs while the (possibly modified)
    registry is active.
  """
  # `list(...)` makes a shallow copy, so registrations performed inside the
  # context cannot alter the snapshot that is restored afterwards.
  # NOTE(review): assumes `_type_registry` is a flat sequence of entries for
  # which a shallow copy is sufficient -- confirm against TypeHandlerRegistry.
  original_registry = list(
      type_handlers.GLOBAL_TYPE_HANDLER_REGISTRY._type_registry  # pylint: disable=protected-access
  )
  try:
    # Registration happens inside the `try` so that a failure part-way through
    # re-registering handlers still restores the original registry instead of
    # leaving the global state half-modified for subsequent tests.
    if use_ocdbt:
      type_handlers.register_standard_handlers_with_options(
          use_ocdbt=use_ocdbt, ts_context=ts_context
      )
    yield
  finally:
    type_handlers.GLOBAL_TYPE_HANDLER_REGISTRY._type_registry = (  # pylint: disable=protected-access
        original_registry
    )
Loading

0 comments on commit d5cf598

Please sign in to comment.