diff --git a/CHANGELOG.md b/CHANGELOG.md index edc00fe26..73ea6e4c9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added +- Local type handler registries. + ### Changed - The PyPi `orbax` package is deprecated in favor of domain-specific namespace packages, namely `orbax-checkpoint` and `orbax-export`. Imports are unchanged, @@ -126,4 +129,3 @@ auto-publish functionality. - Fix mistaken usages of placeholder "AGGREGATED" where "NOT-AGGREGATED" would be more appropriate. Ensure backwards compatibility is maintained. - diff --git a/checkpoint/orbax/checkpoint/pytree_checkpoint_handler.py b/checkpoint/orbax/checkpoint/pytree_checkpoint_handler.py index 66a55f975..87bfc76ce 100644 --- a/checkpoint/orbax/checkpoint/pytree_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/pytree_checkpoint_handler.py @@ -50,6 +50,7 @@ SaveArgs = type_handlers.SaveArgs ParamInfo = type_handlers.ParamInfo TypeHandler = type_handlers.TypeHandler +TypeHandlerRegistry = type_handlers.TypeHandlerRegistry AggregateHandler = aggregate_handlers.AggregateHandler MsgpackHandler = aggregate_handlers.MsgpackHandler LegacyTransformFn = Callable[[PyTree, PyTree, PyTree], Tuple[PyTree, PyTree]] @@ -132,14 +133,18 @@ def _keypath_from_metadata(keypath_serialized: Tuple[Dict[str, Any]]) -> Any: # return tuple(keypath) -def _get_value_metadata(value: Any, save_arg: SaveArgs) -> Dict[str, Any]: +def _get_value_metadata( + value: Any, + save_arg: SaveArgs, + registry: TypeHandlerRegistry, +) -> Dict[str, Any]: """Gets JSON metadata for a given value.""" if utils.is_supported_empty_aggregation_type(value): typestr = type_handlers.get_empty_value_typestr(value) skip_deserialize = True else: try: - handler = type_handlers.get_type_handler(type(value)) + handler = registry.get(type(value)) typestr = handler.typestr() skip_deserialize = save_arg.aggregate except ValueError: @@ -497,7 +502,10 
@@ def __post_init__(self): def _batched_serialization_requests( - tree: PyTree, param_infos: PyTree, args: PyTree + tree: PyTree, + param_infos: PyTree, + args: PyTree, + registry: TypeHandlerRegistry, ) -> List[_BatchRequest]: """Gets a list of batched serialization or deserialization requests.""" grouped = {} @@ -519,9 +527,9 @@ def _group_value( if arg.restore_type is not None: # Give user the chance to override restore_type if they want. restore_type = arg.restore_type - handler = type_handlers.get_type_handler(restore_type) + handler = registry.get(restore_type) else: - handler = type_handlers.get_type_handler(type(value)) + handler = registry.get(type(value)) if handler not in grouped: grouped[handler] = _BatchRequest(handler, [], [], [], []) request = grouped[handler] @@ -636,6 +644,7 @@ def __init__( write_tree_metadata: bool = True, use_zarr3: bool = False, primary_host: Optional[int] = 0, + type_handler_registry: TypeHandlerRegistry = type_handlers.GLOBAL_TYPE_HANDLER_REGISTRY, ): """Creates PyTreeCheckpointHandler. @@ -660,6 +669,8 @@ def __init__( primary_host: the host id of the primary host. Default to 0. If it's set to None, then all hosts will be considered as primary. It's useful in the case that all hosts are only working with local storage. + type_handler_registry: a type_handlers.TypeHandlerRegistry. If not + specified, the global type handler registry will be used. """ self._aggregate_handler = MsgpackHandler(primary_host=primary_host) if aggregate_filename is None: @@ -672,6 +683,7 @@ def __init__( self._write_tree_metadata = write_tree_metadata self._use_zarr3 = use_zarr3 self._primary_host = primary_host + self._type_handler_registry = type_handler_registry if self._use_ocdbt: @@ -762,8 +774,9 @@ async def async_save( """Saves a PyTree to a given directory. This operation is compatible with a multi-host, multi-device setting. Tree - leaf values must be supported by type_handlers. 
Standard supported types - include Python scalars, `np.ndarray`, `jax.Array`, and strings. + leaf values must be supported by the type_handler_registry given in the + constructor. Standard supported types include Python scalars, `np.ndarray`, + `jax.Array`, and strings. After saving, all files will be located in "directory/". The exact files that are saved depend on the specific combination of options, including @@ -805,6 +818,7 @@ async def async_save( Returns: A Future that will commit the data to `directory` when awaited. Copying the data from its source will be awaited in this function. + """ if args is None: args = PyTreeSaveArgs( @@ -834,7 +848,7 @@ def _maybe_set_default_save_args(value, args_): return SaveArgs(aggregate=False) # Empty values will still raise TypeHandler registry error if _METADATA is # disabled. We will prompt users to enable _METADATA to avoid this error. - aggregate = not type_handlers.has_type_handler(type(value)) + aggregate = not self._type_handler_registry.has(type(value)) return SaveArgs(aggregate=aggregate) save_args = jax.tree_util.tree_map( @@ -871,7 +885,7 @@ def _maybe_set_default_save_args(value, args_): else: serialize_ops = [] batch_requests = _batched_serialization_requests( - item, param_infos, save_args + item, param_infos, save_args, self._type_handler_registry, ) for request in batch_requests: serialize_ops += [ @@ -949,7 +963,7 @@ def _process_aggregated_value(info, meta, args): ) batch_requests = _batched_serialization_requests( - structure, param_infos, restore_args + structure, param_infos, restore_args, self._type_handler_registry, ) deserialized_batches = [] deserialized_batches_ops = [] @@ -1227,7 +1241,8 @@ def _write_metadata_file( tuple_keypath = str(tuple([str(utils.get_key_name(k)) for k in keypath])) flat_metadata_with_keys[tuple_keypath] = { _KEY_METADATA_KEY: _get_keypath_metadata(keypath), - _VALUE_METADATA_KEY: _get_value_metadata(value, save_arg), + _VALUE_METADATA_KEY: _get_value_metadata( + value, 
save_arg, self._type_handler_registry), } metadata = { @@ -1407,7 +1422,7 @@ def _get_user_metadata(self, directory: epath.Path) -> PyTree: metadata_ops = [] for restore_type, param_infos in batched_param_infos.items(): - handler = type_handlers.get_type_handler(restore_type) + handler = self._type_handler_registry.get(restore_type) metadata_ops.append(handler.metadata(param_infos)) async def _get_metadata(): diff --git a/checkpoint/orbax/checkpoint/test_utils.py b/checkpoint/orbax/checkpoint/test_utils.py index 5472ca473..db5800083 100644 --- a/checkpoint/orbax/checkpoint/test_utils.py +++ b/checkpoint/orbax/checkpoint/test_utils.py @@ -20,7 +20,7 @@ import string import time import typing -from typing import List, Optional, Tuple +from typing import Any, List, Optional, Tuple from absl import logging from etils import epath @@ -366,3 +366,34 @@ def register_type_handler(ty, handler, func): type_handlers.register_type_handler( ty, original_handler, func=func, override=True ) + + +@contextlib.contextmanager +def ocdbt_checkpoint_context(use_ocdbt: bool, ts_context: Any): + """Optionally re-registers the standard handlers within the context. + + The global registry's type and typestr tables are both restored on exit. + + Args: + use_ocdbt: if True, re-register the standard handlers on entry. + ts_context: passed to `register_standard_handlers_with_options`. + """ + original_registry = list( + type_handlers.GLOBAL_TYPE_HANDLER_REGISTRY._type_registry # pylint: disable=protected-access + ) + original_typestr_registry = dict( + type_handlers.GLOBAL_TYPE_HANDLER_REGISTRY._typestr_registry # pylint: disable=protected-access + ) + if use_ocdbt: + type_handlers.register_standard_handlers_with_options( + use_ocdbt=use_ocdbt, ts_context=ts_context + ) + try: + yield + finally: + type_handlers.GLOBAL_TYPE_HANDLER_REGISTRY._type_registry = ( # pylint: disable=protected-access + original_registry + ) + type_handlers.GLOBAL_TYPE_HANDLER_REGISTRY._typestr_registry = ( # pylint: disable=protected-access + original_typestr_registry + ) diff --git a/checkpoint/orbax/checkpoint/type_handlers.py b/checkpoint/orbax/checkpoint/type_handlers.py index 8f4e0af7f..05bb3b9fb 100644 --- a/checkpoint/orbax/checkpoint/type_handlers.py +++ b/checkpoint/orbax/checkpoint/type_handlers.py @@ -23,7 +23,7 @@ import os import re import time -from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union, cast +from typing import Any, Callable, Dict, List, Optional, Protocol, Sequence, Tuple, Union, cast 
import warnings from absl import logging @@ -1820,22 +1820,161 @@ async def deserialize( return await asyncio.gather(*read_ops) -_TYPE_REGISTRY = [ - (lambda ty: issubclass(ty, int), ScalarHandler()), - (lambda ty: issubclass(ty, float), ScalarHandler()), - (lambda ty: issubclass(ty, bytes), ScalarHandler()), - (lambda ty: issubclass(ty, np.number), ScalarHandler()), - (lambda ty: issubclass(ty, np.ndarray), NumpyHandler()), - (lambda ty: issubclass(ty, jax.Array), ArrayHandler()), - (lambda ty: issubclass(ty, str), StringHandler()), -] +class TypeHandlerRegistry(Protocol): + """A registry for TypeHandlers. + This interface is implemented by the global registry, which is the default + when no local registry is supplied. The global registry is also accessed + through the module functions register/get/has_type_handler. + """ + + def add( + self, + ty: Any, + handler: TypeHandler, + func: Optional[Callable[[Any], bool]] = None, + override: bool = False, + ): + """Registers a type for serialization/deserialization with a given handler. + + Note that it is possible for a type to match multiple different entries in + the registry, each with a different handler. In this case, only the first + match is used. + + Args: + ty: A type to register. + handler: a TypeHandler capable of reading and writing parameters of type + `ty`. + func: A function that accepts a type and returns True if the type should + be handled by the provided TypeHandler. If this parameter is not + specified, defaults to `lambda t: issubclass(t, ty)`. + override: if True, will override an existing mapping of type to handler. + + Raises: + ValueError if a type is already registered and override is False. + """ + ... + + def get(self, ty: Any) -> TypeHandler: + """Returns the handler registered for a given type, if available. + + Args: + ty: an object type (or string representation of the type). + + Returns: + The TypeHandler that is registered for the given type. 
+ + Raises: + ValueError if the given type has no registered handler. + """ + ... + + def has(self, ty: Any) -> bool: + """Checks if a type is registered. + + Args: + ty: an object type (or string representation of the type.) + + Returns: + A boolean indicating if ty is registered. + """ + ... + + +class _TypeHandlerRegistryImpl(TypeHandlerRegistry): + """The implementation for TypeHandlerRegistry.""" + + def __init__(self, *handlers: Tuple[Any, TypeHandler]): + """Create a type registry. + + Args: + *handlers: an optional list of handlers to initialize with. + """ + self._type_registry: List[Tuple[Callable[[Any], bool], TypeHandler]] = [] + self._typestr_registry: Dict[str, TypeHandler] = {} + if handlers: + for ty, h in handlers: + self.add(ty, h, override=True) + + def add( + self, + ty: Any, + handler: TypeHandler, + func: Optional[Callable[[Any], bool]] = None, + override: bool = False, + ): + if func is None: + func = lambda t: issubclass(t, ty) + + existing_handler_idx = None + for i, (f, _) in enumerate(self._type_registry): + if f(ty): + existing_handler_idx = i + # Ignore the possibility for subsequent matches, as these will not be + # used anyway. + break + + if existing_handler_idx is None: + if handler.typestr() in self._typestr_registry: + if override: + logging.warning('Type handler registry overriding type "%s" ' + 'collision on %s', ty, handler.typestr()) + else: + raise ValueError( + f'Type "{ty}" has a `typestr` ("{handler.typestr()}") which' + ' collides with that of an existing TypeHandler.' 
+ ) + self._type_registry.append((func, handler)) + self._typestr_registry[handler.typestr()] = handler + elif override: + logging.warning('Type handler registry type "%s" overriding %s', + ty, handler.typestr()) + self._type_registry[existing_handler_idx] = (func, handler) + self._typestr_registry[handler.typestr()] = handler + else: + raise ValueError(f'A TypeHandler for "{ty}" is already registered.') -def _make_typestr_registry(type_registry: Any) -> Dict[str, TypeHandler]: - return {h.typestr(): h for _, h in type_registry} + def get(self, ty: Any) -> TypeHandler: + if isinstance(ty, str): + if ty in self._typestr_registry: + return self._typestr_registry[ty] + else: + for func, handler in self._type_registry: + if func(ty): + return handler + raise ValueError(f'Unknown type: "{ty}". Must register a TypeHandler.') + + def has(self, ty: Any) -> bool: + try: + self.get(ty) + return True + except ValueError: + return False + + +GLOBAL_TYPE_HANDLER_REGISTRY = _TypeHandlerRegistryImpl( + (int, ScalarHandler()), + (float, ScalarHandler()), + (bytes, ScalarHandler()), + (np.number, ScalarHandler()), + (np.ndarray, NumpyHandler()), + (jax.Array, ArrayHandler()), + (str, StringHandler()), +) -_TYPESTR_REGISTRY = _make_typestr_registry(_TYPE_REGISTRY) +def create_type_handler_registry( + *handlers: Tuple[Any, TypeHandler] +) -> TypeHandlerRegistry: + """Create a type registry. + + Args: + *handlers: an optional list of handlers to initialize with. + + Returns: + A TypeHandlerRegistry instance with only the specified handlers. + """ + return _TypeHandlerRegistryImpl(*handlers) def register_type_handler( @@ -1854,91 +1993,50 @@ def register_type_handler( ty: A type to register. handler: a TypeHandler capable of reading and writing parameters of type `ty`. - func: A function that accepts a type and returns True if the type should be - handled by the provided TypeHandler. If this parameter is not specified, - defaults to `lambda t: issubclass(t, ty)`. 
+ func: A function that accepts a type and returns True if the type should + be handled by the provided TypeHandler. If this parameter is not + specified, defaults to `lambda t: issubclass(t, ty)`. override: if True, will override an existing mapping of type to handler. Raises: ValueError if a type is already registered and override is False. """ - if func is None: - func = lambda t: issubclass(t, ty) - - existing_handler_idx = None - for i, (f, _) in enumerate(_TYPE_REGISTRY): - if f(ty): - existing_handler_idx = i - # Ignore the possibility for subsequent matches, as these will not be used - # anyway. - break - - if existing_handler_idx is None: - if handler.typestr() in _TYPESTR_REGISTRY: - raise ValueError( - f'Type "{ty}" has a `typestr` ("{handler.typestr()}") which collides' - ' with that of an existing TypeHandler.' - ) - _TYPE_REGISTRY.append((func, handler)) - _TYPESTR_REGISTRY[handler.typestr()] = handler - elif override: - _TYPE_REGISTRY[existing_handler_idx] = (func, handler) - _TYPESTR_REGISTRY[handler.typestr()] = handler - else: - raise ValueError(f'A TypeHandler for "{ty}" is already registered.') + GLOBAL_TYPE_HANDLER_REGISTRY.add(ty, handler, func, override) def get_type_handler(ty: Any) -> TypeHandler: - """Returns the handler registered for a given type, if available. - - Args: - ty: an object type (or string representation of the type.) - - Returns: - The TypeHandler that is registered for the given type. - - Raises: - ValueError if the given type has no registered handler. - """ - if isinstance(ty, str): - if ty in _TYPESTR_REGISTRY: - return _TYPESTR_REGISTRY[ty] - else: - for func, handler in _TYPE_REGISTRY: - if func(ty): - return handler - raise ValueError(f'Unknown type: "{ty}". 
Must register a TypeHandler.') + """Returns the handler registered for a given type, if available.""" + return GLOBAL_TYPE_HANDLER_REGISTRY.get(ty) def has_type_handler(ty: Any) -> bool: - try: - get_type_handler(ty) - return True - except ValueError: - return False + """Returns whether a handler is registered for a given type.""" + return GLOBAL_TYPE_HANDLER_REGISTRY.has(ty) def register_standard_handlers_with_options(**kwargs): - """Re-registers a select set of handlers with the given options.""" + """Re-registers a select set of handlers with the given options. + + Overrides options en masse for the standard TypeHandlers (scalars, bytes, + numpy arrays and jax.Arrays). `use_ocdbt` and `ts_context` are required + keys and are dropped before the rest is forwarded to the handlers. + + Args: + **kwargs: keyword arguments to pass to each of the standard handlers. + """ # TODO(b/314258967): clean those up. del kwargs['use_ocdbt'], kwargs['ts_context'] - register_type_handler(int, ScalarHandler(**kwargs), override=True) - register_type_handler(float, ScalarHandler(**kwargs), override=True) - register_type_handler( - np.number, - ScalarHandler(**kwargs), - override=True, - ) - register_type_handler( - np.ndarray, - NumpyHandler(**kwargs), - override=True, - ) - register_type_handler( - jax.Array, - ArrayHandler(**kwargs), - override=True, - ) + GLOBAL_TYPE_HANDLER_REGISTRY.add(int, ScalarHandler(**kwargs), override=True) + GLOBAL_TYPE_HANDLER_REGISTRY.add(float, ScalarHandler(**kwargs), + override=True) + GLOBAL_TYPE_HANDLER_REGISTRY.add(bytes, ScalarHandler(**kwargs), + override=True) + GLOBAL_TYPE_HANDLER_REGISTRY.add(np.number, ScalarHandler(**kwargs), + override=True) + GLOBAL_TYPE_HANDLER_REGISTRY.add(np.ndarray, NumpyHandler(**kwargs), + override=True) + GLOBAL_TYPE_HANDLER_REGISTRY.add(jax.Array, ArrayHandler(**kwargs), + override=True) # TODO(b/253238305) Deprecate when all checkpoints have saved types. 
diff --git a/docs/api_reference/checkpoint.type_handlers.rst b/docs/api_reference/checkpoint.type_handlers.rst index ea3e2dd9a..affb43875 100644 --- a/docs/api_reference/checkpoint.type_handlers.rst +++ b/docs/api_reference/checkpoint.type_handlers.rst @@ -53,7 +53,9 @@ OCDBT functions TypeHandler registry ------------------------ +.. autoclass:: TypeHandlerRegistry +.. autofunction:: create_type_handler_registry .. autofunction:: register_type_handler .. autofunction:: get_type_handler .. autofunction:: has_type_handler -.. autofunction:: register_standard_handlers_with_options \ No newline at end of file +.. autofunction:: register_standard_handlers_with_options