Skip to content

Commit

Permalink
Improve the step time after emergency local restoration
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 675334608
  • Loading branch information
ChromeHearts authored and Orbax Authors committed Sep 16, 2024
1 parent ccdd58b commit 7cf465e
Showing 1 changed file with 59 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,12 @@ def _maybe_save_process_metadata(
return True


def should_restore_mesh_from_metadata(path: epath.Path) -> bool:
def _should_restore_mesh_from_metadata(path: epath.Path) -> bool:
  """Reports whether the process-metadata folder exists under `path`.

  Args:
    path: Directory that is checked for a `_PROCESS_METADATA_FOLDER` child.

  Returns:
    True if `path / _PROCESS_METADATA_FOLDER` exists, False otherwise.
  """
  return (path / _PROCESS_METADATA_FOLDER).exists()


def consistent_restore_mesh_from_metadata(
def _consistent_restore_mesh_from_metadata(
path: epath.Path, global_mesh: jax.sharding.Mesh
) -> jax.sharding.Mesh:
"""Create a mesh consistent with the saved metadata."""
Expand Down Expand Up @@ -948,8 +948,17 @@ def _restore_from_local(
step_stats.is_restoring_slice = is_restoring_slice
step_stats.in_primary_slice = self.in_primary_slice

restore_mesh = self._global_mesh
if _should_restore_mesh_from_metadata(self._persistent_directory):
logging.info(
'Found consistent_restore_mesh, using it for local restoration'
)
restore_mesh = _consistent_restore_mesh_from_metadata(
self._persistent_directory, self._global_mesh
)

slice_devices = multislice.slice_devices(
self._global_mesh,
restore_mesh,
replica_id=self._slice_id,
replica_axis_index=self._replica_axis_index,
)
Expand All @@ -961,7 +970,7 @@ def _get_single_slice_sharding(
):
ss_mesh_shape = [
1 if i == self._replica_axis_index else d
for i, d in enumerate(self._global_mesh.devices.shape)
for i, d in enumerate(restore_mesh.devices.shape)
]
slice_mesh = jax.sharding.Mesh(
slice_devices.reshape(ss_mesh_shape), mesh.axis_names
Expand Down Expand Up @@ -1039,7 +1048,7 @@ def create_zeros(shape_dtype_tup):
start_broadcast = time.time()
shared_states, _ = multislice.broadcast_one_replica_to_all(
in_tree,
self._global_mesh,
restore_mesh,
replica_axis_index=self._replica_axis_index,
is_source=is_restoring_slice,
)
Expand All @@ -1056,7 +1065,51 @@ def create_zeros(shape_dtype_tup):
self._logger.log_entry(dataclasses.asdict(step_stats))

logging.info('Finished broadcasting in %.2f', broadcast_elapsed_s)
return jax.tree.unflatten(tree_defs, shared_states)

if np.array_equal(restore_mesh.device_ids, self._global_mesh.device_ids):
finalized_shared_states = shared_states
else:
finalized_shared_states = self._consistent_restore_mesh_to_global_mesh(
shared_states
)

return jax.tree.unflatten(tree_defs, finalized_shared_states)

def _consistent_restore_mesh_to_global_mesh(self, shared_states) -> Any:
  """Transfers from consistent restore mesh to global mesh.

  Every leaf of `shared_states` is re-placed onto a `NamedSharding` that keeps
  the leaf's own partition spec but is built on `self._global_mesh`. After a
  leaf has been transferred, its source array is deleted right away to keep
  peak device memory down. The wall time of the whole transfer is logged and
  recorded as a monitoring event.

  Args:
    shared_states: Pytree of arrays. Each leaf must expose `sharding.spec`
      (NOTE(review): presumably a `NamedSharding` on the restore mesh --
      confirm against the caller).

  Returns:
    A pytree with the same structure as `shared_states` whose leaves are
    placed on `self._global_mesh`. Leaves that were already laid out
    equivalently to the target sharding are returned as-is and not deleted.
  """

  # transfer to global_mesh
  def transfer_to_global_mesh(x):
    target_sharding = jax.sharding.NamedSharding(
        self._global_mesh, x.sharding.spec
    )
    if x.sharding.is_equivalent_to(target_sharding, x.ndim):
      # Nothing to move. `device_put` may hand back (an alias of) `x` when
      # the layout already matches, so running the delete below would
      # invalidate the very array being returned.
      return x

    # TODO(b/367435655) add donate to device_put instead of block+delete
    y = jax.device_put(x, device=target_sharding)
    # Wait for the copy to finish before releasing the source buffers.
    y.block_until_ready()

    # delete immediately to conserve memory
    x.delete()
    return y

  logging.info('Transferring from consistent restore mesh to global mesh')

  start_transfer = time.time()
  finalized_shared_states = jax.tree.map(
      transfer_to_global_mesh,
      shared_states,
  )
  transfer_elapsed_s = time.time() - start_transfer
  logging.info(
      'Finished transferring from consistent restore mesh to global mesh'
      ' in %.2fs',
      transfer_elapsed_s,
  )
  jax.monitoring.record_event_duration_secs(
      '/orbax/emergency/checkpoint/read/transfer_global_shard_duration_secs',
      transfer_elapsed_s,
  )

  return finalized_shared_states

def _restore_from_persistent(
self,
Expand Down

0 comments on commit 7cf465e

Please sign in to comment.