Add Pipeline Parallel support
- uses the pipeline tracer frontend to extract a graph and partition it into
  per-stage chunks (see the sketch below)
- hardcodes one schedule (1F1B) for now
- currently supports only 1D parallelism

WIP: support 2D/3D parallelism and clean up the seed-checkpoint UX

ghstack-source-id: 7055ffe515b79fa6edad58a72543d9bc8e866f80
Pull Request resolved: #161
wconstab committed Apr 5, 2024
1 parent 66cc578 commit 0a6e841
Showing 4 changed files with 92 additions and 14 deletions.
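For orientation, below is an editor's sketch of the tracer-frontend split flow this commit wires up, using the same pippy calls the diff imports (annotate_split_points, pipeline, SplitPoint) and the same get_stage_module lookup. The ToyModel module, its sizes, and the pipeline degree are illustrative assumptions, not part of the commit.

# Editor's sketch, not part of the commit: the pippy tracer-frontend flow in isolation.
# ToyModel, the layer sizes, and pp_degree are made-up placeholders.
import torch
import torch.nn as nn
from pippy import annotate_split_points, pipeline, SplitPoint

class ToyModel(nn.Module):
    def __init__(self, n_layers=4, dim=16):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(n_layers)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

pp_degree = 2
model = ToyModel()

# mark stage boundaries at the start of every layers.{i * layers_per_rank}
layers_per_rank = len(model.layers) // pp_degree
for i in range(1, pp_degree):
    annotate_split_points(model, {f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING})

# trace the model into a pipeline representation, then pull out one stage's submodule
example_input = torch.randn(8, 16)
pipe = pipeline(model, pp_degree, example_args=(example_input,))
stage_module = pipe.get_stage_module(0)  # the submodule this rank would run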
2 changes: 1 addition & 1 deletion run_llama_train.sh
@@ -11,7 +11,7 @@ TRAINER_DIR=${1:-/home/$USER/local/torchtrain}
 # e.g.
 # LOG_RANK=0,1 NGPU=4 ./run_llama_train.sh

-NGPU=${NGPU:-"8"}
+NGPU=${NGPU:-"2"}

 # by default log just rank 0 output,
 LOG_RANK=${LOG_RANK:-0}
30 changes: 28 additions & 2 deletions torchtrain/parallelisms/parallelize_llama.py
@@ -8,7 +8,7 @@
 from typing import Tuple

 import torch
-
+from pippy import annotate_split_points, pipeline, SplitPoint
 from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
 from torch.distributed._tensor import Replicate, Shard
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
@@ -134,7 +134,29 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     the model must fit on GPU or CPU memory.
     """
     if parallel_dims.pp_enabled:
-        raise NotImplementedError("PP not implemented yet.")
+        pp_mesh = world_mesh["pp"]
+        stage_idx = pp_mesh.get_local_rank()
+        layers_per_rank = len(model.layers) // parallel_dims.pp
+        for i in range(1, parallel_dims.pp):
+            annotate_split_points(
+                model,
+                {f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING},
+            )
+
+        # Get example input
+        label_shape = input_shape = (8, 2048)  # TODO
+        input_ids = torch.randint(
+            model.vocab_size, input_shape, dtype=torch.int64, device="meta"
+        )
+        labels = torch.randint(
+            model.vocab_size, label_shape, dtype=torch.int64, device="meta"
+        )
+        print("input_ids: ", input_ids.shape, input_ids.dtype)
+        print("labels: ", labels.shape, labels.dtype)
+
+        # Create a pipeline representation from the model
+        pipe = pipeline(model, parallel_dims.pp, example_args=(input_ids,))
+        model = pipe.get_stage_module(stage_idx)

     if parallel_dims.tp_enabled:
         tp_mesh = world_mesh["tp"]
@@ -233,4 +255,8 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
             logger.info(f"Applied {ac_mode} activation checkpointing to the model")
         logger.info("Applied FSDP to the model")

+    if parallel_dims.pp_enabled:
+        setattr(pipe.split_gm, f"submod_{stage_idx}", model)
+        return pipe
+
     return model
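As a quick reading of the split-point arithmetic in the hunk above, with hypothetical numbers (a 32-layer model and pipeline_parallel_degree 4, neither taken from this diff), the loop marks three boundaries and yields four equal stages:

# Editor's illustration of the split-point math; 32 layers and pp=4 are assumed values.
n_layers, pp = 32, 4
layers_per_rank = n_layers // pp                       # 8 layers per stage
split_points = [f"layers.{i * layers_per_rank}" for i in range(1, pp)]
print(split_points)                                    # ['layers.8', 'layers.16', 'layers.24']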
72 changes: 62 additions & 10 deletions train.py
@@ -15,6 +15,8 @@

 import torch
 import torch.nn.functional as F
+from pippy.PipelineSchedule import PipelineScheduleGPipe
+from pippy.PipelineStage import PipelineStage
 from torch.distributed.elastic.multiprocessing.errors import record
 from torch.distributed.tensor.parallel import loss_parallel

@@ -120,7 +122,9 @@ def main(job_config: JobConfig):
         world_size=world_size,
         enable_loss_parallel=job_config.training.enable_loss_parallel,
     )
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
+    # torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    torch.cuda.set_device(device)
     init_distributed(job_config)

     world_mesh = parallel_dims.build_mesh(device_type="cuda")
@@ -139,6 +143,14 @@ def main(job_config: JobConfig):
         dp_rank = dp_mesh.get_local_rank()
     else:
         dp_degree, dp_rank = 1, 0
+
+    if parallel_dims.pp_enabled:
+        pp_mesh = world_mesh["pp"]
+        pp_degree = pp_mesh.size()
+        pp_rank = pp_mesh.get_local_rank()
+    else:
+        pp_degree, pp_rank = 1, 0
+
     data_loader = build_dataloader_fn(
         job_config.training.dataset,
         job_config.training.dataset_path,
@@ -197,14 +209,38 @@ def loss_fn(pred, labels):
     model = models_parallelize_fns[model_name](
         model, world_mesh, parallel_dims, job_config
     )
-    # allocate sharded model on GPU and initialize weights via DTensor
+    if parallel_dims.pp_enabled:
+        pipe_meta = model
+        model = pipe_meta.get_stage_module(pp_rank)
+
     model.to_empty(device="cuda")
-    model.init_weights()
-
+    # TODO(whc) everything below needs to become a function that can be applied to each 'virtual stage' of PP, if
+    # there are virtual stages
+    if parallel_dims.pp_enabled:
+        stage = PipelineStage(
+            pipe=pipe_meta,
+            stage_index=pp_rank,
+            device=device,
+            group=pp_mesh.get_group(),
+        )
+        pp_schedule = PipelineScheduleGPipe(
+            stage,
+            n_microbatches=parallel_dims.pp,
+            loss_fn=loss_fn,
+        )
+    else:
+        # if PP is enabled, we can't use init_weights. Instead, we have to rely on creating an initial checkpoint
+        # offline and loading it to get initialization values. This is because the init_weights functions are written
+        # assuming the whole model (all of its weights, or FQNs) exists on one rank. In PP, init_weights on stage 1
+        # might crash because it can't find the "embedding" layer, for example.
+
+        # allocate sharded model on GPU and initialize weights via DTensor
+        model.init_weights()
+
     # build optimizer after applying parallelisms to the model
     optimizer = build_optimizer(model, job_config)
     scheduler = get_lr_scheduler(optimizer, job_config)

     metric_logger = build_metric_logger(job_config)

     # torch.compile model for improved performance
@@ -274,13 +310,30 @@ def loss_fn(pred, labels):

             input_ids = input_ids.cuda()
             labels = labels.cuda()

             optimizer.zero_grad()
-
-            # forward / backward
-            pred = model(input_ids)
-            loss = loss_fn(pred, labels)
-            loss.backward()
+            if parallel_dims.pp_enabled:
+                # pipeline parallel forward / backward inside step() call
+                is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
+
+                if pp_mesh.get_local_rank() == 0:
+                    pp_schedule.step(input_ids)
+                elif is_last_stage:
+                    losses = []
+                    pp_schedule.step(target=labels, losses=losses)
+                else:
+                    pp_schedule.step()
+
+                # accumulate losses across pipeline microbatches
+                current_loss = (
+                    torch.mean(torch.stack(losses)).item() if is_last_stage else -1.0
+                )
+            else:
+                # forward / backward
+                pred = model(input_ids)
+                loss = loss_fn(pred, labels)
+                loss.backward()
+                current_loss = loss.item()

             # clip gradients
             torch.nn.utils.clip_grad_norm_(
@@ -291,7 +344,6 @@ def loss_fn(pred, labels):
             optimizer.step()
             scheduler.step()

-            current_loss = loss.item()
             losses_since_last_log.append(current_loss)

             # log metrics
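Restated outside the loop, the per-rank calling convention that the training-loop hunk above follows is roughly the sketch below; pp_schedule, pp_mesh, input_ids, and labels are assumed to be set up exactly as in train.py, and the helper name is the editor's, not part of the commit.

# Editor's restatement of the per-stage step() convention used in the loop above.
import torch

def pp_forward_backward(pp_schedule, pp_mesh, input_ids, labels):
    pp_rank, pp_size = pp_mesh.get_local_rank(), pp_mesh.size()
    is_last_stage = pp_rank == pp_size - 1
    if pp_rank == 0:
        pp_schedule.step(input_ids)                      # first stage feeds the inputs
    elif is_last_stage:
        losses = []                                      # per-microbatch losses land here
        pp_schedule.step(target=labels, losses=losses)
        return torch.mean(torch.stack(losses)).item()    # reduce to one scalar for logging
    else:
        pp_schedule.step()                               # middle stages only relay activations
    return -1.0                                          # non-last stages report a placeholder loss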
2 changes: 1 addition & 1 deletion train_configs/debug_model.toml
@@ -38,7 +38,7 @@ max_norm = 1.0  # grad norm clipping
 steps = 10
 data_parallel_degree = -1
 tensor_parallel_degree = 1
-pipeline_parallel_degree = 1
+pipeline_parallel_degree = 2
 fp8_linear = ""
 compile = false
 dataset = "alpaca"  # supported datasets = alpaca (52K), minipile (1M), c4 (177M)
