
Commit

feat(jetstream): shard on batch when not possible to shard on other dim
tengomucho committed Jan 16, 2025
1 parent 8b6eb13 commit 0514c0c
Showing 1 changed file with 16 additions and 0 deletions.
@@ -93,7 +93,23 @@ def create_engine_env_data(
    if model_info is None:
        return None

    # Sharding needs to happen evenly across the number of devices, so we first try to shard the KV cache on
    # num_kv_heads, or on head_dim. If that is not possible, we shard on the batch size, which can be adjusted
    # to be a multiple of the number of devices.
    num_devices = jax.device_count()
    num_kv_heads_shardable = model_info.num_kv_heads % num_devices == 0
    head_dim_shardable = model_info.num_kv_heads == 1 and model_info.head_dim % num_devices == 0

    if num_kv_heads_shardable or head_dim_shardable:
        shard_on_batch = False
    else:
        shard_on_batch = True
        aligned_batch_size = (batch_size + num_devices - 1) // num_devices * num_devices
        if aligned_batch_size != batch_size:
            logger.info(
                f"Adjusting batch size to be a multiple of the number of devices: {batch_size} -> {aligned_batch_size}"
            )
            batch_size = aligned_batch_size
    max_cache_length = max_input_tokens + max_output_tokens

    logger.info(f"Creating engine with max_cache_length={max_cache_length} = {max_input_tokens} + {max_output_tokens}")
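A standalone sketch of the decision this commit introduces, for illustration only: the helper name pick_sharding and the example values are made up, but the checks and the rounding formula mirror the diff above. The KV cache is sharded on num_kv_heads, or on head_dim when there is a single KV head; if neither divides evenly across the devices, the batch size is rounded up to the next multiple of the device count and sharding falls back to the batch dimension.

    # Illustrative sketch; pick_sharding is a hypothetical helper, not part of the commit.
    def pick_sharding(num_kv_heads: int, head_dim: int, batch_size: int, num_devices: int):
        # Prefer sharding the KV cache on num_kv_heads, or on head_dim when there is one KV head.
        num_kv_heads_shardable = num_kv_heads % num_devices == 0
        head_dim_shardable = num_kv_heads == 1 and head_dim % num_devices == 0
        if num_kv_heads_shardable or head_dim_shardable:
            return False, batch_size  # shard_on_batch=False, batch size unchanged
        # Otherwise shard on batch: round the batch size up to a multiple of num_devices.
        aligned_batch_size = (batch_size + num_devices - 1) // num_devices * num_devices
        return True, aligned_batch_size

    # Example: 4 KV heads cannot be split evenly across 8 devices, so a batch of 5
    # is padded to 8 and the cache is sharded on the batch dimension instead.
    print(pick_sharding(num_kv_heads=4, head_dim=128, batch_size=5, num_devices=8))  # (True, 8)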
