
Commit

run sdpa with dtensor
ghstack-source-id: 43941c1ca0dfc7a04589a7513a110b877c217917
Pull Request resolved: #180
tianyu-l committed Mar 30, 2024
1 parent dca7657 commit df14507
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions torchtrain/parallelisms/parallelize_llama.py
@@ -19,6 +19,7 @@
     ColwiseParallel,
     parallelize_module,
     PrepareModuleInput,
+    PrepareModuleOutput,
     RowwiseParallel,
     SequenceParallel,
 )
@@ -143,15 +144,21 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
         )
 
         # 1. Parallelize the first embedding and the last linear proj layer
-        # 2. Parallelize the root norm layer over the sequence dim
-        # 3. Shard the first transformer block's inputs
+        # 2. Prepare the freq_cis in rotary embedding as dtensor
+        # 3. Parallelize the root norm layer over the sequence dim
+        # 4. Shard the first transformer block's inputs
         model = parallelize_module(
             model,
             tp_mesh,
             {
                 "embeddings.tok_embeddings": RowwiseParallel(
                     input_layouts=Replicate(),
                 ),
+                "embeddings": PrepareModuleOutput(
+                    output_layouts=(None, Replicate()),
+                    desired_output_layouts=(None, Replicate()),
+                    use_local_output=False,
+                ),
                 "output": col_parallel_strategy(
                     input_layouts=Shard(0),
                     output_layouts=(
@@ -177,9 +184,9 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                     input_layouts=(Shard(0), None),
                     desired_input_layouts=(Replicate(), None),
                 ),
-                "attention.wq": col_parallel_strategy(),
-                "attention.wk": col_parallel_strategy(),
-                "attention.wv": col_parallel_strategy(),
+                "attention.wq": col_parallel_strategy(use_local_output=False),
+                "attention.wk": col_parallel_strategy(use_local_output=False),
+                "attention.wv": col_parallel_strategy(use_local_output=False),
                 "attention.wo": row_parallel_strategy(output_layouts=Shard(0)),
                 "attention_norm": SequenceParallel(sequence_dim=0),
                 "feed_forward": PrepareModuleInput(
@@ -192,11 +199,6 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                 "ffn_norm": SequenceParallel(sequence_dim=0),
             }
 
-            # Adjust attention module to use the local number of heads
-            attn_layer = transformer_block.attention
-            attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
-            attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()
-
             parallelize_module(
                 module=transformer_block,
                 device_mesh=tp_mesh,
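As context for the change above, here is a minimal, self-contained sketch (not part of this commit) of the pattern it enables: with use_local_output=False, the column-parallel wq/wk/wv projections hand head-sharded DTensors straight to scaled_dot_product_attention, which dispatches through DTensor's sharding rules, so the old manual n_heads // tp_mesh.size() adjustment is unnecessary. The TinyAttention module, its sizes, and the 2-GPU torchrun launch are illustrative assumptions; it also presumes a PyTorch build whose DTensor supports the SDPA op.

# Launch with: torchrun --nproc_per_node=2 sdpa_dtensor_sketch.py
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class TinyAttention(nn.Module):
    def __init__(self, dim: int = 64, n_heads: int = 4):
        super().__init__()
        self.head_dim = dim // n_heads
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wk = nn.Linear(dim, dim, bias=False)
        self.wv = nn.Linear(dim, dim, bias=False)
        self.wo = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        bsz, seqlen, _ = x.shape
        # With use_local_output=False these are DTensors sharded on the last dim.
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        # -1 infers the head count, so no manual n_heads // tp_degree adjustment.
        xq = xq.view(bsz, seqlen, -1, self.head_dim).transpose(1, 2)
        xk = xk.view(bsz, seqlen, -1, self.head_dim).transpose(1, 2)
        xv = xv.view(bsz, seqlen, -1, self.head_dim).transpose(1, 2)
        # SDPA runs directly on the head-sharded DTensors.
        out = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
        out = out.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(out)


if __name__ == "__main__":
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    tp_mesh = init_device_mesh("cuda", (2,))

    attn = TinyAttention().cuda()
    attn = parallelize_module(
        attn,
        tp_mesh,
        {
            # Keep q/k/v outputs as DTensors instead of converting to local tensors.
            "wq": ColwiseParallel(use_local_output=False),
            "wk": ColwiseParallel(use_local_output=False),
            "wv": ColwiseParallel(use_local_output=False),
            # Row-parallel output projection all-reduces back to a replicated tensor.
            "wo": RowwiseParallel(),
        },
    )

    torch.manual_seed(0)  # same input on every rank (replicated activation)
    x = torch.randn(2, 16, 64, device="cuda")
    print(attn(x).shape)  # torch.Size([2, 16, 64]) on each rank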
