Commit 38ac130: Add random key support

PiperOrigin-RevId: 711866172
ChromeHearts authored and Orbax Authors committed Jan 23, 2025
1 parent 359a0ed · commit 38ac130

Showing 1 changed file with 119 additions and 12 deletions.

checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py
131 changes: 119 additions & 12 deletions
@@ -35,9 +35,10 @@
from jax.experimental import layout
import jax.numpy as jnp
import numpy as np
-from orbax.checkpoint import future
from orbax.checkpoint._src import asyncio_utils
from orbax.checkpoint._src.arrays import subchunking
from orbax.checkpoint._src.arrays import types as arrays_types
from orbax.checkpoint._src.futures import future
from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib
from orbax.checkpoint._src.metadata import empty_values
from orbax.checkpoint._src.metadata import sharding as sharding_metadata
@@ -77,6 +78,8 @@
_SHARDING_SUFFIX_RE = r'/\d+(\.\d+)*$' # /0, /0.0, /1.0.1, etc.
_ZARRAY_SUFFIX_RE = r'/\.zarray$'
_ZARRAY_SUFFIX = '/.zarray'
_ARRAY_EXTRA_METADATA_FILE = '_array_extra_metadata.json'
_RANDOM_KEY_IMPL = 'random_key_impl'


async def _assert_parameter_files_exist(
@@ -238,6 +241,20 @@ def _build_array_write_spec(
)


class _CommitFuture(future.Future):
"""Represents the result of a background commit."""

def __init__(self, coro, name: Optional[str] = None):
self._t = future.ThreadRaisingException(
name=name,
target=lambda: asyncio_utils.run_sync(coro),
)
self._t.start()

def result(self, timeout: Optional[int] = None) -> Any:
return self._t.join(timeout=timeout)
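
# Editor's sketch (hypothetical names, not Orbax API): the minimal pattern
# behind _CommitFuture, namely a coroutine driven to completion on a
# background thread with any exception surfaced when the caller joins.
# Stdlib only; future.ThreadRaisingException packages the same idea.
import asyncio
import threading


class _EditorDemoBackgroundCommit:
  """Illustrative stand-in for _CommitFuture using stdlib threading."""

  def __init__(self, coro):
    self._exc = None
    self._t = threading.Thread(target=self._run, args=(coro,))
    self._t.start()

  def _run(self, coro):
    try:
      asyncio.run(coro)  # Drives the commit coroutine to completion.
    except Exception as e:  # Captured here, re-raised in result().
      self._exc = e

  def result(self, timeout=None):
    self._t.join(timeout=timeout)
    if self._exc is not None:
      raise self._exc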


def check_input_arguments(*args):
l = None
for arg in args:
@@ -635,7 +652,7 @@ async def serialize(
_print_ts_debug_data(self._metadata_key, infos)
copied_values = [copy.deepcopy(v) for v in values]
return [
-        future.CommitFutureAwaitingContractedSignals(
_CommitFuture(
self._background_serialize(copied_values, infos, args),
name='np_type_handler',
)
@@ -1022,11 +1039,62 @@ async def _serialize_sharding(
serialized_sharding
)

async def _serialize_array_extra_metadata(
self, info: types.ParamInfo, metadata: Dict[str, Any]
):
"""Serializes extra array metadata."""

if info.parent_dir is None:
raise ValueError('parent_dir cannot be None')

kvstore_tspec = ts_utils.build_kvstore_tspec(
info.parent_dir.as_posix(),
name=_ARRAY_EXTRA_METADATA_FILE,
use_ocdbt=info.is_ocdbt_checkpoint,
)
tspec = {
'driver': 'json',
'kvstore': kvstore_tspec,
}
logging.info('_serialize_array_extra_metadata: tspec: %s', tspec)
t = await ts.open(
tspec,
open=True,
context=info.ts_context,
)
await t.write(metadata)

async def _deserialize_array_extra_metadata(self, info: types.ParamInfo):
"""Serializes extra array metadata."""

if info.parent_dir is None:
raise ValueError('parent_dir cannot be None')

kvstore_tspec = ts_utils.build_kvstore_tspec(
info.parent_dir.as_posix(),
name=_ARRAY_EXTRA_METADATA_FILE,
use_ocdbt=info.is_ocdbt_checkpoint,
)
tspec = {
'driver': 'json',
'kvstore': kvstore_tspec,
}
try:
t = await ts.open(
tspec,
context=info.ts_context,
)
return await t.read()
except ValueError:
# no extra metadata
return None
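
  # Editor's sketch (illustrative; the path and values are hypothetical):
  # the JSON sidecar round trip the two methods above perform, written
  # against TensorStore's 'json' driver directly, without OCDBT or the
  # build_kvstore_tspec helper.
  #
  #   import tensorstore as ts
  #
  #   async def _editor_demo_metadata_roundtrip():
  #     tspec = {
  #         'driver': 'json',
  #         'kvstore': {
  #             'driver': 'file',
  #             'path': '/tmp/ckpt/_array_extra_metadata.json',
  #         },
  #     }
  #     t = await ts.open(tspec, open=True)
  #     await t.write({'params/rng': {'random_key_impl': 'threefry2x32'}})
  #     return (await t.read()).item()  # .item() unwraps the JSON payload.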

async def _background_serialize(
self,
values: Sequence[replica_slices.ReplicaSlices],
infos: Sequence[types.ParamInfo],
args: Sequence[types.SaveArgs],
array_extra_metadata: Dict[str, Any],
):
"""Runs serialization in a background thread."""
write_coros = []
@@ -1075,6 +1143,12 @@ async def _background_serialize(
process_index=multihost.process_index(),
)
)

if array_extra_metadata:
write_coros.append(
self._serialize_array_extra_metadata(infos[0], array_extra_metadata)
)

await asyncio.gather(*write_coros)
await sharding_metadata_txn.commit_async()
if ocdbt_transaction is not None:
@@ -1087,8 +1161,19 @@ async def serialize(
args: Optional[Sequence[types.SaveArgs]] = None,
) -> Sequence[future.Future]:
"""See superclass documentation."""
-    for v in values:
-      if (

ext_metadata = {}
arrays = []

for v, info in zip(values, infos):
if jax.dtypes.issubdtype(v.dtype, jax.dtypes.prng_key):
# a JAX random key
arrays.append(jax.random.key_data(v))
if multihost.is_primary_host(self._primary_host):
ext_metadata[info.name] = {
_RANDOM_KEY_IMPL: str(jax.random.key_impl(v))
}
elif (
isinstance(v, jax.Array)
and jax.process_count() > 1
and v.is_fully_addressable
@@ -1098,26 +1183,35 @@
' obtained using pmap. Consider using'
' fully_replicated_host_local_array_to_global_array in'
' orbax/checkpoint/utils.py to convert your arrays into'
-          ' serializable objects.'
f' serializable objects. Array.sharding: {v.sharding}'
)
-    args = args or [types.SaveArgs()] * len(values)
-    check_input_arguments(values, infos, args)

else:
# regular array
arrays.append(v)

args = args or [types.SaveArgs()] * len(arrays)
check_input_arguments(arrays, infos, args)

assert all([info.enable_pinned_host_transfer for info in infos]) or all(
[not info.enable_pinned_host_transfer for info in infos]
)

# Complete D2H transfer in parallel for each array.
values_on_host = replica_slices.transfer_arrays_to_host(
-        values,
arrays,
self._replica_id,
self._use_replica_parallel,
enable_pinned_host_transfer=infos[0].enable_pinned_host_transfer,
)

logging.info('extra_metadata: %s', ext_metadata)

return [
future.CommitFutureAwaitingContractedSignals(
-            self._background_serialize(values_on_host, infos, args),
self._background_serialize(
values_on_host, infos, args, ext_metadata
),
name='array_type_handler',
)
]
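
  # Editor's sketch (standalone; public jax.random APIs only): the typed-key
  # round trip this handler now performs. A key is detected by dtype, saved
  # as its raw uint32 key data, and re-wrapped on restore using the impl
  # string recorded in the extra-metadata sidecar, mirroring serialize and
  # deserialize above.
  #
  #   key = jax.random.key(0)
  #   assert jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key)
  #   data = jax.random.key_data(key)        # Raw uint32s, saved as an array.
  #   impl = str(jax.random.key_impl(key))   # Recorded in the JSON sidecar.
  #   restored = jax.random.wrap_key_data(data, impl=impl)
  #   # Both keys now drive identical random streams:
  #   assert bool(jax.random.bits(restored) == jax.random.bits(key))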
Expand Down Expand Up @@ -1221,17 +1315,30 @@ async def deserialize(
strict=arg.strict if hasattr(arg, 'strict') else True,
)
]

deserialize_ops.append(self._deserialize_array_extra_metadata(infos[0]))
ret = await asyncio.gather(*deserialize_ops)
ext_metadata = ret[-1].item() if ret[-1] else None
ret = ret[:-1]

logging.info('ext_metadata: %s, type=%s', ext_metadata, type(ext_metadata))
if ext_metadata:
for i, (info, v) in enumerate(zip(infos, ret)):
logging.info('info.name: %s, i=%s', info.name, i)
if meta := ext_metadata.get(info.name):
if impl := meta.get(_RANDOM_KEY_IMPL):
ret[i] = jax.random.wrap_key_data(v, impl=impl)
logging.info('ret[i] = %s', ret[i])

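    # Editor's note (values illustrative): ext_metadata maps parameter names
    # to per-array extras, e.g. {'params/rng': {'random_key_impl':
    # 'threefry2x32'}}, so only arrays recorded as random keys above are
    # re-wrapped.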
if logging.vlog_is_on(1):
for a in ret:
logging.vlog(
1,
'restored jax.Array.shape = %s, jax.array.dtype = %s,'
-            ' jax.array.layout + %s',
' jax.array.layout = %s',
a.shape,
a.dtype,
-            a.layout,
getattr(a, 'layout', None),
)
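    # Editor's note: getattr(a, 'layout', None) above guards the layout
    # access; restored typed key arrays may not expose .layout the way raw
    # jax.Arrays do (an inference from this change, not documented behavior).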
_print_ts_debug_data(self._metadata_key, infos)

@@ -1561,7 +1668,7 @@ async def serialize(
del args
# Copy is not needed since strings are passed by value.
return [
-      future.CommitFutureAwaitingContractedSignals(
_CommitFuture(
self._background_serialize(values, infos),
name='string_type_handler',
)
