Add data input distribution for model+data parallel. (#18721)
* Add data input distribution for model+data parallel.

* Address review comments.
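
For context, the user-facing setup that exercises this code path might look like the sketch below. This is an editor's illustration, assuming the public `keras.distribution` API of Keras 3; the mesh shape, the layout rule, and the exact `DeviceMesh`/`LayoutMap`/`ModelParallel` signatures are illustrative rather than taken from this commit.

import keras
from keras.distribution import DeviceMesh, LayoutMap, ModelParallel

# Hypothetical 2D mesh: 4-way data parallel x 2-way model parallel.
devices = keras.distribution.list_devices()
mesh = DeviceMesh(shape=(4, 2), axis_names=["batch", "model"], devices=devices)

# Shard Dense kernels along the model axis; other weights stay replicated.
layout_map = LayoutMap(mesh)
layout_map["dense.*kernel"] = (None, "model")

keras.distribution.set_distribution(
    ModelParallel(mesh, layout_map, batch_dim_name="batch")
)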
qlzh727 authored Nov 2, 2023
1 parent 9197591 commit cf9b2f3
Showing 1 changed file with 45 additions and 15 deletions.
60 changes: 45 additions & 15 deletions keras/backend/jax/distribution_lib.py
@@ -108,9 +108,8 @@ def distribute_tensor(tensor, layout):
def distribute_data_input(inputs, layout):
    """Distribute the input data with the corresponding layout.

    Note that the inputs here are a local worker batch. Within the local
    worker, the data needs to be further partitioned to map to each of the
    devices.

    Args:
        inputs: `jax.Array` that is already sharded to a local process size.
@@ -125,22 +124,53 @@ def distribute_data_input(inputs, layout):
    # In the single-process (fully addressable) case, all the devices are
    # local, so a plain device_put is enough.
    if layout.is_fully_addressable:
        return jax.device_put(inputs, layout)

    # We need the jax mesh information to determine how to place the data
    # on each of the workers.
    jax_mesh = layout.mesh
    mesh_rank = len(jax_mesh.shape)
    per_process_batch_size = inputs.shape[0]

    if mesh_rank == 1:
        # This is a data-parallel-only mesh. Split the full data across the
        # batch dim.
        num_split = jax.local_device_count()
        if per_process_batch_size % num_split != 0:
            raise ValueError(
                f"The local batch size {per_process_batch_size} is not "
                f"divisible by the number of local replicas {num_split}"
            )
        global_batch_size = per_process_batch_size * jax.process_count()
        per_replica_batches = np.split(inputs, num_split, axis=0)
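        # Editor's note, a worked example with hypothetical numbers: with 8
        # local devices and a per-process batch of 32, the batch is split
        # into 8 per-replica batches of 4; with 2 processes, the global
        # batch size is 32 * 2 = 64.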
    elif mesh_rank == 2:
        # Data+model parallel.
        # In this case, we need to check whether the mesh batch dim size is
        # larger than the number of local devices, so that we can decide
        # whether the data needs to be split, or repeated/copied to each of
        # the devices.
        # TODO(scottzhu): The mesh batch dim name is not available here,
        # since we only have the jax Mesh. We assume the first dim is for
        # batch and the second dim is for model, for now.
        mesh_batch_dim_size = list(jax_mesh.shape.values())[0]
        local_device_count = jax.local_device_count()
        if mesh_batch_dim_size < local_device_count:
            # No split is needed; we only need to repeat here.
            global_batch_size = per_process_batch_size
            per_replica_batches = [inputs for _ in range(local_device_count)]
        else:
            # Note that the global batch size is not simply
            # per_process_batch_size * num_process. It actually depends on
            # the model dim size.
            global_batch_size = per_process_batch_size * (
                mesh_batch_dim_size // local_device_count
            )
            per_replica_batches = np.split(
                inputs, local_device_count, axis=0
            )
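        # Editor's note, worked examples with hypothetical numbers: on a
        # (2, 4) mesh with 8 local devices, the batch dim size (2) is
        # smaller than the local device count (8), so the local batch is
        # repeated and global_batch_size stays per_process_batch_size. On a
        # (4, 2) mesh with 4 local devices per process, the local batch is
        # split 4 ways and global_batch_size = per_process_batch_size *
        # (4 // 4).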
    else:
        raise ValueError(
            "Only 1D or 2D mesh is supported at the moment. "
            f"Received mesh shape = {jax_mesh.shape}"
        )

    global_shape = (global_batch_size,) + inputs.shape[1:]
    global_batch_array = jax.make_array_from_single_device_arrays(
        global_shape,
        layout,
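The hunk cuts off mid-call above. For readers who want to see the function in action, below is a minimal single-process usage sketch (editor's illustration): the mesh shape, array sizes, and `PartitionSpec` are hypothetical, it assumes 8 local devices, and in the single-process case the fully-addressable early return is the path that actually runs.

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

from keras.backend.jax.distribution_lib import distribute_data_input

# Hypothetical 2D mesh over 8 devices: 4-way batch, 2-way model.
devices = np.array(jax.devices()).reshape(4, 2)
mesh = Mesh(devices, axis_names=("batch", "model"))

# Shard the inputs along the batch dimension only.
layout = NamedSharding(mesh, PartitionSpec("batch", None))

local_batch = np.ones((8, 16), dtype="float32")
global_batch = distribute_data_input(local_batch, layout)
print(global_batch.shape)  # (8, 16) in a single process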
