[T170073014] Rewrite distributed examples for Tensor Parallel, Sequence Parallel, 2D (FSDP + TP) (#1201)

* update requirements.txt

* add torchrun support, move to init_device_mesh

* update twod fully working

* ensure proper dp group seeding for synth data

* swiglu model added

* sequential running of custom, auto, seq parallel models

* streamline to 2D TP only for two_d_parallel example

* sequence parallel working...needs init_device_mesh update

* seq parallel now using init_device_mesh

* tp and sp examples all working and updated

* updates from code review

* remove utils.py. Sample models created in example files

* remove originals.py, leftover imports, various updates from code review feedback.

* code linting via ruff

* code formatting via ruff

* move rank_log to utils.py, update example files

* move logging imports and config to log_utils, update examples with new import

* add gpu verification, update run_python_examples.sh

* update min gpu = 4 for fsdp+tp

* move gpu check to top of examples, but before import init_device_mesh to clear CI
lessw2020 authored Nov 22, 2023
1 parent f0d6fc9 commit c4dc481
Showing 9 changed files with 385 additions and 278 deletions.
170 changes: 170 additions & 0 deletions distributed/tensor_parallelism/fsdp_tp_example.py
@@ -0,0 +1,170 @@
import sys
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor.parallel import (
    parallelize_module,
    ColwiseParallel,
    RowwiseParallel,
)

import os
from log_utils import rank_log, get_logger, verify_min_gpu_count


# ---- GPU check ------------
_min_gpu_count = 4

if not verify_min_gpu_count(min_gpus=_min_gpu_count):
    print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.")
    sys.exit()
# ---------------------------

from torch.distributed._tensor.device_mesh import init_device_mesh


"""
This is the script to test 2D Parallel which combines Tensor/Sequence
parallel with Fully Sharded Data Parallel (TP/SP + FSDP) on a toy model
in the SPMD style. We show an E2E working flow from forward, backward
and optimization.
We enabled Fully Sharded Data Parallel + Tensor Parallel in
separate parallel dimensions:
Data Parallel ("dp") across hosts
Tensor Parallel ("tp") within each host
We use a simple diagram to illustrate below:
======================================================================
------------ ------------ ------------ ------------
| Host 1 | | Host 2 | | | | Host N |
| 8 GPUs | | 8 GPUs | | | | 8 GPUs |
| | | | | ... | | |
| (TP) | | (TP) | | | | (TP) |
|[0,1,..,7]| |[8,9..,15]| | | |[8N-8,8N-7|
| | | | | | | .., 8N-1]|
| | | | | | | |
------------ ------------ ------------ ------------
FSDP:
[0, 8, ..., 8N-8], [1, 9, ..., 8N-7], ..., [7, 15, ..., 8N-1]
======================================================================
More details can be seen in the slide:
https://docs.google.com/presentation/d/17g6WqrO00rP3MsxbRENsPpjrlSkwiA_QB4r93_eB5is/
"""


def find_multiple(n: int, k: int) -> int:
    """function to find resizing multiple for SwiGLU MLP"""
    if n % k == 0:
        return n
    return n + k - (n % k)


class MLP_swiglu(nn.Module):
    """SwiGLU to showcase a Llama style MLP model"""

    def __init__(self, mlp_dim: int = 1024) -> None:
        super().__init__()
        hidden_dim = 4 * mlp_dim
        scaled_hidden = int(2 * hidden_dim / 3)
        rounded_hidden = find_multiple(scaled_hidden, 256)

        self.in_proj = nn.Linear(mlp_dim, rounded_hidden, bias=False)
        self.gate_proj = nn.Linear(mlp_dim, rounded_hidden, bias=False)
        self.out_proj = nn.Linear(rounded_hidden, mlp_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.silu(self.in_proj(x)) * self.gate_proj(x)
        x = self.out_proj(x)
        return x
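# --- Illustration (editorial, not part of this file): worked sizing for the
# default mlp_dim=1024 using find_multiple above: hidden_dim = 4 * 1024 = 4096,
# scaled_hidden = int(2 * 4096 / 3) = 2730, rounded up to the next multiple of
# 256 -> 2816. So in_proj/gate_proj map 1024 -> 2816 and out_proj maps 2816 -> 1024.
assert find_multiple(int(2 * (4 * 1024) / 3), 256) == 2816
# --- end illustration ---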


"""
Main body of the demo of a basic version of tensor parallel by using
PyTorch native APIs.
"""
tp_size = 2
logger = get_logger()

# understand world topology
_rank = int(os.environ["RANK"])
_world_size = int(os.environ["WORLD_SIZE"])


print(f"Starting PyTorch 2D (FSDP + TP) example on rank {_rank}.")
assert (
    _world_size % tp_size == 0
), f"World size {_world_size} needs to be divisible by TP size {tp_size}"


# create a sharding plan based on the given world_size.
dp_size = _world_size // tp_size

# Create a device mesh with 2 dimensions.
# First dim is the data parallel dimension
# Second dim is the tensor parallel dimension.
device_mesh = init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp"))

rank_log(_rank, logger, f"Device Mesh created: {device_mesh=}")
tp_mesh = device_mesh["tp"]
dp_mesh = device_mesh["dp"]

# To support identical inputs for TP groups, we need the dp process group
dp_pg = device_mesh.get_dim_groups()[0]

# For TP, input needs to be same across all TP ranks.
# while for SP, input can be different across all ranks.
# We will use dp_rank for setting the random seed
# to mimic the behavior of the dataloader.
dp_rank = dist.get_rank(dp_pg)


# create model and move it to GPU with id rank
_mlp_dim = 1024
base_model_swiglu = MLP_swiglu(mlp_dim=_mlp_dim).to("cuda")


# Custom parallelization plan for the swiglu MLP model
custom_tp_model = parallelize_module(
    module=base_model_swiglu,
    device_mesh=tp_mesh,
    parallelize_plan={
        "in_proj": ColwiseParallel(),
        "gate_proj": ColwiseParallel(),
        "out_proj": RowwiseParallel(),
    },
)
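# --- Illustration (editorial, not part of this file): local weight shard
# shapes under this plan, assuming tp_size=2 and the 2816-wide hidden dim
# worked out above. ColwiseParallel shards nn.Linear output features
# (weight dim 0) and RowwiseParallel shards input features (weight dim 1),
# so the colwise -> rowwise pair composes without regathering the hidden
# activation in between.
_example_hidden, _example_tp = 2816, 2
print("in_proj/gate_proj local weight:", (_example_hidden // _example_tp, 1024))
print("out_proj local weight:", (1024, _example_hidden // _example_tp))
# --- end illustration ---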

rank_log(_rank, logger, f"Model after parallelization {custom_tp_model=}\n")

# Init FSDP using the dp device mesh
sharded_model = FSDP(custom_tp_model, device_mesh=dp_mesh, use_orig_params=True)

# Create an optimizer for the parallelized and sharded model.
lr = 3e-3
rank_log(_rank, logger, f"Creating AdamW optimizer with learning rate {lr}")
optimizer = torch.optim.AdamW(sharded_model.parameters(), lr=lr, foreach=True)

# Training loop:
# Perform a num of iterations of forward/backward
# and optimizations for the sharded module.
rank_log(_rank, logger, "\nStarting 2D training...")
num_iterations = 10
batch_size = 2

for i in range(num_iterations):
    # seeding with dp_rank to ensure identical inputs for TP groups
    torch.manual_seed(i + dp_rank)
    inp = torch.rand(batch_size, _mlp_dim, device="cuda")

    output = sharded_model(inp)
    output.sum().backward()
    optimizer.step()
    rank_log(_rank, logger, f"2D iter {i} complete")

rank_log(_rank, logger, "2D training successfully completed!")
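Why seeding with dp_rank gives identical inputs within each TP group: ranks in the same TP group share a dp_rank, so they call torch.manual_seed with the same value every iteration, while ranks at different dp positions draw different data. A minimal editorial sketch (not part of the commit), assuming the hypothetical world_size=8, tp_size=2 layout illustrated earlier:

tp_size = 2
for rank in range(8):
    dp_rank = rank // tp_size  # position within the FSDP ("dp") dimension
    print(f"global rank {rank} -> dp_rank {dp_rank} -> seed base {dp_rank}")
# ranks (0,1), (2,3), ... share a dp_rank, so each TP pair sees the same
# synthetic batch per iteration; the four dp positions see different data,
# mimicking a per-DP-rank dataloader.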
22 changes: 22 additions & 0 deletions distributed/tensor_parallelism/log_utils.py
@@ -0,0 +1,22 @@
import logging
import torch

logging.basicConfig(
    format="%(asctime)s %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p", level=logging.INFO
)

def get_logger():
    return logging.getLogger(__name__)


def rank_log(_rank, logger, msg):
    """helper function to log only on global rank 0"""
    if _rank == 0:
        logger.info(f" {msg}")


def verify_min_gpu_count(min_gpus: int = 2) -> bool:
    """verification that we have at least 2 gpus to run dist examples"""
    has_cuda = torch.cuda.is_available()
    gpu_count = torch.cuda.device_count()
    return has_cuda and gpu_count >= min_gpus
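The log_utils helpers above are used the same way in every example script; a minimal editorial sketch (not part of the commit) of the intended call pattern, assuming it runs inside one of the torchrun-launched examples:

import os
from log_utils import rank_log, get_logger, verify_min_gpu_count

assert verify_min_gpu_count(min_gpus=2)  # True only if CUDA is present with >= 2 GPUs
logger = get_logger()
_rank = int(os.environ["RANK"])          # RANK is set by torchrun
rank_log(_rank, logger, "printed once, by global rank 0 only")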
6 changes: 3 additions & 3 deletions distributed/tensor_parallelism/requirements.txt
@@ -1,6 +1,6 @@
# Python dependencies required for running the example

--pre
---extra-index-url https://download.pytorch.org/whl/nightly/cu113
---extra-index-url https://download.pytorch.org/whl/nightly/cu116
-torch >= 1.14.0.dev0; sys_platform == "linux"
+--extra-index-url https://download.pytorch.org/whl/nightly/cu118
+--extra-index-url https://download.pytorch.org/whl/nightly/cu121
+torch >= 2.2.0.dev0; sys_platform == "linux"
13 changes: 13 additions & 0 deletions distributed/tensor_parallelism/run_example.sh
@@ -0,0 +1,13 @@

# To run samples:
# bash run_example.sh {file_to_run.py} {num_gpus}
# where file_to_run = example to launch. Default = 'fsdp_tp_example.py'
# num_gpus = num local gpus to use (must be at least 2). Default = 4

# samples to run include:
# sequence_parallel_example.py
# tensor_parallel_example.py
# fsdp_tp_example.py

echo "Launching ${1:-fsdp_tp_example.py} with ${2:-4} gpus"
torchrun --nnodes=1 --nproc_per_node=${2:-4} --rdzv_id=101 --rdzv_endpoint="localhost:5972" ${1:-fsdp_tp_example.py}
148 changes: 87 additions & 61 deletions distributed/tensor_parallelism/sequence_parallel_example.py
@@ -1,19 +1,30 @@
-import argparse

import os
+import sys
import torch
-import torch.multiprocessing as mp
+import torch.nn as nn

+from torch.distributed._tensor import Shard

+from torch.distributed.tensor.parallel import (
+    parallelize_module,
+    ColwiseParallel,
+    RowwiseParallel,
+)

+from log_utils import rank_log, get_logger, verify_min_gpu_count


+# ---- GPU check ------------
+_min_gpu_count = 2

+if not verify_min_gpu_count(min_gpus=_min_gpu_count):
+    print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.")
+    sys.exit()
+# ---------------------------

-from torch.distributed._tensor import DeviceMesh
-from torch.distributed.tensor.parallel import parallelize_module
-from utils import cleanup, setup, ToyModel

-try:
-    from torch.distributed.tensor.parallel import (
-        SequenceParallel
-    )
-    SP_AVAILABLE = True
-except BaseException as e:
-    pass
+from torch.distributed._tensor.device_mesh import init_device_mesh



"""
@@ -33,51 +44,66 @@
"""


-def demo_sp(rank, args):
-    """
-    Main body of the demo of a basic version of sequence parallel by using
-    PyTorch native APIs.
-    """
-    print(f"Running SP example on rank {rank}.")
-    setup(rank, args.world_size)

-    # create a sharding plan based on the given world_size.
-    device_mesh = DeviceMesh("cuda", torch.arange(0, args.world_size))

-    # create model and move it to GPU with id rank
-    model = ToyModel().cuda(rank)
-    # Create a optimizer for the parallelized module.
-    LR = 0.25
-    optimizer = torch.optim.SGD(model.parameters(), lr=LR)
-    # Parallelize the module based on the given Parallel Style.
-    model = parallelize_module(model, device_mesh, SequenceParallel())

-    # Perform a num of iterations of forward/backward
-    # and optimizations for the sharded module.
-    for _ in range(args.iter_nums):
-        # For SP, input can be different across all ranks.
-        inp = torch.rand(20, 10).cuda(rank)
-        output = model(inp)
-        output.sum().backward()
-        optimizer.step()

-    cleanup()


-if __name__ == "__main__":
-    n_gpus = torch.cuda.device_count()
-    parser = argparse.ArgumentParser()
-    # This is passed in via cmd
-    parser.add_argument("--world_size", type=int, default=n_gpus)
-    parser.add_argument("--iter_nums", type=int, default=10)
-    args = parser.parse_args()
-    # The main entry point is called directly without using subprocess
-    if n_gpus < 2:
-        print("Requires at least 2 GPUs to run.")
-    elif not SP_AVAILABLE:
-        print(
-            "PyTorch doesn't have Sequence Parallelism available,"
-            " need nightly build."
-        )
-    else:
-        mp.spawn(demo_sp, args=(args,), nprocs=args.world_size, join=True)
+class ToyModel(nn.Module):
+    """MLP based model"""

+    def __init__(self):
+        super().__init__()
+        self.in_proj = nn.Linear(10, 32)
+        self.relu = nn.ReLU()
+        self.out_proj = nn.Linear(32, 5)

+    def forward(self, x):
+        return self.out_proj(self.relu(self.in_proj(x)))


+"""
+Main body of the demo of a basic version of sequence parallel by using
+PyTorch native APIs.
+"""
+logger = get_logger()

+# create a device mesh based on the given world_size.
+device_mesh = init_device_mesh(
+    device_type="cuda", mesh_shape=(int(os.environ["WORLD_SIZE"]),)
+)

+_rank = device_mesh.get_rank()

+print(f"Starting PyTorch Sequence Parallel example on rank {_rank}.")

+rank_log(_rank, logger, f"Device Mesh created: {device_mesh=}")

+# create model and move it to GPU. init_device_mesh has already assigned gpu ids...
+model = ToyModel().to("cuda")

+# Custom parallelization plan for the model
+sp_model = parallelize_module(
+    module=model,
+    device_mesh=device_mesh,
+    parallelize_plan={
+        "in_proj": ColwiseParallel(input_layouts=Shard(0)),
+        "out_proj": RowwiseParallel(output_layouts=Shard(0)),
+    },
+)


+# Create an optimizer for the parallelized module.
+lr = 0.25
+optimizer = torch.optim.AdamW(sp_model.parameters(), lr=lr, foreach=True)


+# Perform a num of iterations of forward/backward
+# and optimizations for the sharded module.
+num_iters = 10
+rank_log(_rank, logger, "Sequence Parallel training starting...")

+for i in range(num_iters):
+    # For SP, input can be different across all ranks.
+    inp = torch.rand(20, 10, device="cuda")
+    output = sp_model(inp)
+    output.sum().backward()
+    optimizer.step()
+    rank_log(_rank, logger, f"Sequence Parallel iter {i} completed")

+rank_log(_rank, logger, "Sequence Parallel training completed!")
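For reference, a shape-bookkeeping sketch (editorial, not part of the commit) of what the Shard(0) layouts above mean for the per-rank (20, 10) inputs, assuming 2 ranks:

ranks = 2
local_in = (20, 10)                               # each rank feeds its own batch
logical_in = (local_in[0] * ranks, local_in[1])   # input_layouts=Shard(0): logically a (40, 10) input
logical_out = (logical_in[0], 5)                  # after in_proj -> relu -> out_proj
local_out = (logical_out[0] // ranks, 5)          # output_layouts=Shard(0): a (20, 5) slice back per rank
print(logical_in, logical_out, local_out)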