
Commit

rename sequence_parallel to tensor_parallel (#162)
This PR renames sequence_parallel to tensor_parallel: since sequence parallelism is only applied to the RMSNorm layers, the broader name tensor_parallel is more accurate, possibly with sequence_parallel enabled on top (an illustrative sketch of this split follows the change summary below).

ghstack is broken :( so using a direct branch push instead.
wanchaol authored Mar 25, 2024
1 parent e84bbf4 commit 8dd5798
Showing 8 changed files with 30 additions and 28 deletions.
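
For context on the naming, a minimal illustrative sketch (not code from this PR): with the torch.distributed.tensor.parallel API, linear weights are sharded with colwise/rowwise styles, which is tensor parallelism proper, and only the norm layer gets the SequenceParallel style. The toy block, the 8-way mesh, and the use of LayerNorm as a stand-in for torchtrain's RMSNorm are assumptions; it would need a recent PyTorch and a torchrun launch with 8 ranks.

    import torch.nn as nn
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
        SequenceParallel,
        parallelize_module,
    )

    class ToyBlock(nn.Module):
        def __init__(self, dim: int = 4096):
            super().__init__()
            self.attention_norm = nn.LayerNorm(dim)  # stand-in for torchtrain's RMSNorm
            self.wq = nn.Linear(dim, dim, bias=False)
            self.wo = nn.Linear(dim, dim, bias=False)

    # assumed 8-way tensor parallel mesh; requires an initialized process group
    tp_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("tp",))

    block = parallelize_module(
        ToyBlock(),
        tp_mesh,
        {
            "wq": ColwiseParallel(),               # tensor parallel: shard output columns
            "wo": RowwiseParallel(),               # tensor parallel: shard input rows
            "attention_norm": SequenceParallel(),  # sequence parallel only on the norm
        },
    )
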
4 changes: 2 additions & 2 deletions torchtrain/config_manager.py
@@ -169,10 +169,10 @@ def __init__(self):
             help="Data Parallelism degree. -1 means leftover ranks will be used (After SP/PP). 1 means disabled.",
         )
         self.parser.add_argument(
-            "--training.sequence_parallel_degree",
+            "--training.tensor_parallel_degree",
             type=int,
             default=1,
-            help="Sequence Parallelism degree. 1 means disabled.",
+            help="Tensor Parallelism degree. 1 means disabled.",
         )
         self.parser.add_argument(
             "--training.enable_loss_parallel",
22 changes: 11 additions & 11 deletions torchtrain/parallelisms/__init__.py
@@ -16,7 +16,7 @@
 @dataclass
 class ParallelDims:
     dp: int
-    sp: int
+    tp: int
     pp: int
     world_size: int
     enable_loss_parallel: bool
@@ -25,21 +25,21 @@ def __post_init__(self):
         self._validate()

     def _validate(self):
-        dp, sp, pp = self.dp, self.sp, self.pp
+        dp, tp, pp = self.dp, self.tp, self.pp
         if dp == -1:
-            self.dp = dp = self.world_size // (sp * pp)
+            self.dp = dp = self.world_size // (tp * pp)
         assert dp >= 1, dp
-        assert sp >= 1, sp
+        assert tp >= 1, tp
         assert pp >= 1, pp
         assert (
-            dp * sp * pp == self.world_size
-        ), f"Invalid parallel dims: dp({dp}) * sp({sp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
+            dp * tp * pp == self.world_size
+        ), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"

     def build_mesh(self, device_type):
         dims = []
         names = []
         for d, name in zip(
-            [self.dp, self.sp, self.pp], ["dp", "sp", "pp"], strict=True
+            [self.dp, self.tp, self.pp], ["dp", "tp", "pp"], strict=True
         ):
             if d > 1:
                 dims.append(d)
@@ -53,17 +53,17 @@ def dp_enabled(self):
         return self.dp > 1

     @property
-    def sp_enabled(self):
-        return self.sp > 1
+    def tp_enabled(self):
+        return self.tp > 1

     @property
     def pp_enabled(self):
         return self.pp > 1

     @property
     def loss_parallel_enabled(self):
-        return self.sp > 1 and self.enable_loss_parallel
+        return self.tp > 1 and self.enable_loss_parallel

     @cached_property
     def model_parallel_size(self):
-        return self.sp * self.pp
+        return self.tp * self.pp
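
A hedged usage sketch of the renamed ParallelDims (the degree values and world size below are illustrative, not taken from the PR); it walks through how _validate() resolves a data parallel degree of -1:

    from torchtrain.parallelisms import ParallelDims

    dims = ParallelDims(dp=-1, tp=2, pp=1, world_size=8, enable_loss_parallel=True)

    # __post_init__ calls _validate(): dp == -1, so dp = world_size // (tp * pp) = 8 // (2 * 1) = 4
    assert dims.dp == 4
    assert dims.tp_enabled and not dims.pp_enabled
    assert dims.loss_parallel_enabled        # tp > 1 and enable_loss_parallel
    assert dims.model_parallel_size == 2     # tp * pp

    # With torch.distributed initialized, dims.build_mesh("cuda") would build a 2-D
    # device mesh named ("dp", "tp") of shape (4, 2), since only degrees > 1 are kept.
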
17 changes: 9 additions & 8 deletions torchtrain/parallelisms/parallelize_llama.py
@@ -145,18 +145,19 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     if parallel_dims.pp_enabled:
         raise NotImplementedError("PP not implemented yet.")

-    # First we apply Sequence Parallelism if it's enabled
-    if parallel_dims.sp_enabled:
-        tp_mesh = world_mesh["sp"]
-        sp_degree = job_config.training.sequence_parallel_degree
+    # First we apply Tensor Parallelism if it's enabled
+    if parallel_dims.tp_enabled:
+        tp_mesh = world_mesh["tp"]
+        tp_degree = job_config.training.tensor_parallel_degree

         row_parallel_strategy, col_parallel_strategy = get_tp_parallel_strategy(
             job_config
         )

         # First:
         # 1. parallelize the first embedding and the last linear proj layer
-        # 2. shard the first layer of transformer block
+        # 2. parallelize the root norm layer by sequence dim
+        # 3. shard the first layer of transformer block
         model = parallelize_module(
             model,
             tp_mesh,
@@ -180,7 +181,7 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
             },
         )

-        # apply sequence parallelism to every transformer block
+        # apply tensor + sequence parallelism to every transformer block
         for layer_id, transformer_block in enumerate(model.layers):
             layer_plan = {
                 "attention": PrepareModuleInput(
@@ -204,8 +205,8 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):

             # adjust num_heads in attention layer to local heads
             attn_layer = transformer_block.attention
-            attn_layer.n_heads = attn_layer.n_heads // sp_degree
-            attn_layer.n_kv_heads = attn_layer.n_kv_heads // sp_degree
+            attn_layer.n_heads = attn_layer.n_heads // tp_degree
+            attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_degree

             parallelize_module(
                 module=transformer_block,
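
To see why the head counts are divided by the TP degree, a small worked example (the Llama-7B-style sizes are assumptions, not values from this PR): the attention projections are sharded across the "tp" mesh, so each rank only holds and computes its local slice of the heads.

    n_heads, n_kv_heads, head_dim = 32, 32, 128  # illustrative Llama-7B-like sizes
    tp_degree = 8                                # --training.tensor_parallel_degree

    # wq/wk/wv are sharded column-wise across tp ranks, so each rank owns
    # n_heads // tp_degree heads and attention runs on those local heads only.
    local_heads = n_heads // tp_degree           # 4 heads per rank
    local_kv_heads = n_kv_heads // tp_degree     # 4 kv heads per rank
    local_q_cols = local_heads * head_dim        # 512 of the original 4096 wq columns
    print(local_heads, local_kv_heads, local_q_cols)
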
7 changes: 4 additions & 3 deletions train.py
@@ -14,10 +14,10 @@

 import torch
 import torch.nn.functional as F
+from torch.distributed.elastic.multiprocessing.errors import record
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 from torch.distributed.tensor.parallel import loss_parallel
-from torch.distributed.elastic.multiprocessing.errors import record

 from torchtrain.checkpoint import CheckpointManager, IntervalType
 from torchtrain.config_manager import JobConfig
@@ -98,7 +98,8 @@ def build_grad_scaler(model):

     return ShardedGradScaler(enabled=enable_grad_scaling)

-#Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
+
+# Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html
 @record
 def main(job_config: JobConfig):
     init_logger()
@@ -113,7 +114,7 @@ def main(job_config: JobConfig):
     world_size = int(os.environ["WORLD_SIZE"])
     parallel_dims = ParallelDims(
         dp=job_config.training.data_parallel_degree,
-        sp=job_config.training.sequence_parallel_degree,
+        tp=job_config.training.tensor_parallel_degree,
         pp=job_config.training.pipeline_parallel_degree,
         world_size=world_size,
         enable_loss_parallel=job_config.training.enable_loss_parallel,
2 changes: 1 addition & 1 deletion train_configs/debug_model.toml
@@ -31,7 +31,7 @@ warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps
 max_norm = 1.0 # grad norm clipping
 steps = 10
 data_parallel_degree = -1
-sequence_parallel_degree = 1
+tensor_parallel_degree = 1
 pipeline_parallel_degree = 1
 fp8_linear = ""
 compile = false
2 changes: 1 addition & 1 deletion train_configs/llama_13b.toml
@@ -33,7 +33,7 @@ steps = 1000
 # only dp would be sufficient for 7B
 data_parallel_degree = -1
 # 8-way TP, adjust to 2/4 for local(single host) runs
-sequence_parallel_degree = 8
+tensor_parallel_degree = 8
 pipeline_parallel_degree = 1
 fp8_linear = ""
 compile = false
2 changes: 1 addition & 1 deletion train_configs/llama_70b.toml
@@ -33,7 +33,7 @@ steps = 1000
 # only dp would be sufficient for 7B
 data_parallel_degree = -1
 # 8-way TP
-sequence_parallel_degree = 8
+tensor_parallel_degree = 8
 pipeline_parallel_degree = 1
 fp8_linear = ""
 compile = false
2 changes: 1 addition & 1 deletion train_configs/llama_7b.toml
@@ -32,7 +32,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 1000
 # only dp would be sufficient for 7B
 data_parallel_degree = -1
-sequence_parallel_degree = 1
+tensor_parallel_degree = 1
 pipeline_parallel_degree = 1
 fp8_linear = ""
 compile = false
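
As a worked example of how these configs combine (the 64-GPU world size is an assumption, not something stated in the files): with the renamed tensor_parallel_degree = 8 from llama_13b.toml / llama_70b.toml, data_parallel_degree = -1 resolves to the leftover ranks.

    world_size = 64                      # assumed: e.g. 8 hosts x 8 GPUs
    tensor_parallel_degree = 8           # from llama_13b.toml / llama_70b.toml above
    pipeline_parallel_degree = 1

    # ParallelDims._validate(): dp == -1 means "use the leftover ranks"
    data_parallel_degree = world_size // (tensor_parallel_degree * pipeline_parallel_degree)
    assert data_parallel_degree == 8     # 8-way data parallel on top of 8-way TP
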
