run sdpa with dtensor
ghstack-source-id: 33d3d0b6a19c747269aab1a95589bb61bf9c1f51
Pull Request resolved: #180
tianyu-l committed Mar 30, 2024
1 parent dca7657 commit b67c664
Showing 1 changed file with 5 additions and 10 deletions.
15 changes: 5 additions & 10 deletions torchtrain/parallelisms/parallelize_llama.py
@@ -174,12 +174,12 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
         for layer_id, transformer_block in enumerate(model.layers):
             layer_plan = {
                 "attention": PrepareModuleInput(
-                    input_layouts=(Shard(0), None),
-                    desired_input_layouts=(Replicate(), None),
+                    input_layouts=(Shard(0), Replicate()),
+                    desired_input_layouts=(Replicate(), Replicate()),
                 ),
-                "attention.wq": col_parallel_strategy(),
-                "attention.wk": col_parallel_strategy(),
-                "attention.wv": col_parallel_strategy(),
+                "attention.wq": col_parallel_strategy(use_local_output=False),
+                "attention.wk": col_parallel_strategy(use_local_output=False),
+                "attention.wv": col_parallel_strategy(use_local_output=False),
                 "attention.wo": row_parallel_strategy(output_layouts=Shard(0)),
                 "attention_norm": SequenceParallel(sequence_dim=0),
                 "feed_forward": PrepareModuleInput(
@@ -192,11 +192,6 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                 "ffn_norm": SequenceParallel(sequence_dim=0),
             }
 
-            # Adjust attention module to use the local number of heads
-            attn_layer = transformer_block.attention
-            attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
-            attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()
-
             parallelize_module(
                 module=transformer_block,
                 device_mesh=tp_mesh,
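Note on the change: the removed block no longer divides n_heads by the TP degree because, with use_local_output=False, the wq/wk/wv projections return DTensors sharded over the head dimension; the attention module keeps reshaping with the full (logical) n_heads and scaled_dot_product_attention dispatches through DTensor directly. Below is a minimal standalone sketch of that pattern, not part of the commit: names such as ToyAttention are illustrative, and it assumes two GPUs launched via torchrun and a recent PyTorch whose DTensor supports SDPA.

# Illustrative sketch only (not from the commit); assumes 2 GPUs and a recent
# PyTorch with DTensor support for scaled_dot_product_attention.
# Launch with: torchrun --nproc_per_node=2 tp_sdpa_sketch.py
import os

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyAttention(nn.Module):
    def __init__(self, dim: int = 64, n_heads: int = 4):
        super().__init__()
        self.n_heads, self.head_dim = n_heads, dim // n_heads
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wk = nn.Linear(dim, dim, bias=False)
        self.wv = nn.Linear(dim, dim, bias=False)
        self.wo = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        bsz, seqlen, _ = x.shape
        # With use_local_output=False the projections return DTensors sharded on
        # the last dim, so the reshape below can keep using the *full* n_heads.
        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        q = q.view(bsz, seqlen, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(bsz, seqlen, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(bsz, seqlen, self.n_heads, self.head_dim).transpose(1, 2)
        out = F.scaled_dot_product_attention(q, k, v)  # runs on DTensor inputs
        out = out.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(out)  # row-parallel: reduces back to a replicated tensor


local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
tp_mesh = init_device_mesh("cuda", (2,))

model = ToyAttention().cuda()
parallelize_module(
    model,
    tp_mesh,
    {
        # Keep DTensor outputs (sharded over heads) instead of local shards,
        # mirroring col_parallel_strategy(use_local_output=False) in the diff.
        "wq": ColwiseParallel(use_local_output=False),
        "wk": ColwiseParallel(use_local_output=False),
        "wv": ColwiseParallel(use_local_output=False),
        "wo": RowwiseParallel(),
    },
)

torch.manual_seed(0)  # same replicated input on every rank
x = torch.randn(2, 16, 64, device="cuda")
y = model(x)  # no manual n_heads // tp_mesh.size() adjustment needed
dist.destroy_process_group()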
