From c4dc481e686b5cc7d1aad9e3977b91d8ba83bddd Mon Sep 17 00:00:00 2001 From: Less Wright Date: Wed, 22 Nov 2023 15:29:22 -0800 Subject: [PATCH] [T170073014] Rewrite distributed examples for Tensor Parallel, Sequence Parallel, 2D (FSDP + TP) (#1201) * update requirements.txt * add torchrun support, move to init_device_mesh * update twod fully working * ensure proper dp group seeding for synth data * swiglu model added * sequential running of custom, auto, seq parallel models * streamline to 2D TP only for two_d_parallel example * sequence parallel working...needs init_device_mesh update * seq parallel now using init_device_mesh * tp and sp examples all working and updated * updates from code review * remove utils.py. Sample models created in example files * remove originals.py, leftover imports, various updates from code review feedback. * code linting via ruff * code formatting via ruff * move rank_log to utils.py, update example files * move logging imports and config to log_utils, update examples with new import * add gpu verification, update run_python_examples.sh * update min gpu = 4 for fsdp+tp * move gpu check to top of examples, but before import init_device_mesh to clear CI --- .../tensor_parallelism/fsdp_tp_example.py | 170 ++++++++++++++++++ distributed/tensor_parallelism/log_utils.py | 22 +++ .../tensor_parallelism/requirements.txt | 6 +- distributed/tensor_parallelism/run_example.sh | 13 ++ .../sequence_parallel_example.py | 148 ++++++++------- .../tensor_parallel_example.py | 137 ++++++++------ .../two_d_parallel_example.py | 127 ------------- distributed/tensor_parallelism/utils.py | 31 ---- run_python_examples.sh | 9 +- 9 files changed, 385 insertions(+), 278 deletions(-) create mode 100644 distributed/tensor_parallelism/fsdp_tp_example.py create mode 100644 distributed/tensor_parallelism/log_utils.py create mode 100644 distributed/tensor_parallelism/run_example.sh delete mode 100644 distributed/tensor_parallelism/two_d_parallel_example.py delete mode 100644 distributed/tensor_parallelism/utils.py diff --git a/distributed/tensor_parallelism/fsdp_tp_example.py b/distributed/tensor_parallelism/fsdp_tp_example.py new file mode 100644 index 0000000000..bccd811d82 --- /dev/null +++ b/distributed/tensor_parallelism/fsdp_tp_example.py @@ -0,0 +1,170 @@ +import sys +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F + +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.tensor.parallel import ( + parallelize_module, + ColwiseParallel, + RowwiseParallel, +) + +import os +from log_utils import rank_log, get_logger, verify_min_gpu_count + + +# ---- GPU check ------------ +_min_gpu_count = 4 + +if not verify_min_gpu_count(min_gpus=_min_gpu_count): + print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.") + sys.exit() +# --------------------------- + +from torch.distributed._tensor.device_mesh import init_device_mesh + + +""" +This is the script to test 2D Parallel which combines Tensor/Sequence +parallel with Fully Sharded Data Parallel (TP/SP + FSDP) on a toy model +in the SPMD style. We show an E2E working flow from forward, backward +and optimization. 
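+
+To try this example locally, launch it with torchrun via the provided
+run_example.sh, e.g.:  bash run_example.sh fsdp_tp_example.py 4
+(this example needs at least 4 GPUs: 2 FSDP shards x 2 TP ranks).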
+ +We enabled Fully Sharded Data Parallel + Tensor Parallel in +separate parallel dimensions: + Data Parallel ("dp") across hosts + Tensor Parallel ("tp") within each host + + We use a simple diagram to illustrate below: + +====================================================================== +------------ ------------ ------------ ------------ +| Host 1 | | Host 2 | | | | Host N | +| 8 GPUs | | 8 GPUs | | | | 8 GPUs | +| | | | | ... | | | +| (TP) | | (TP) | | | | (TP) | +|[0,1,..,7]| |[8,9..,15]| | | |[8N-8,8N-7| +| | | | | | | .., 8N-1]| +| | | | | | | | +------------ ------------ ------------ ------------ +FSDP: +[0, 8, ..., 8N-8], [1, 9, ..., 8N-7], ..., [7, 15, ..., 8N-1] +====================================================================== + +More details can be seen in the slide: +https://docs.google.com/presentation/d/17g6WqrO00rP3MsxbRENsPpjrlSkwiA_QB4r93_eB5is/ +""" + + +def find_multiple(n: int, k: int) -> int: + """function to find resizing multiple for SwiGLU MLP""" + if n % k == 0: + return n + return n + k - (n % k) + + +class MLP_swiglu(nn.Module): + """SwiGLU to showcase a Llama style MLP model""" + + def __init__(self, mlp_dim: int = 1024) -> None: + super().__init__() + hidden_dim = 4 * mlp_dim + scaled_hidden = int(2 * hidden_dim / 3) + rounded_hidden = find_multiple(scaled_hidden, 256) + + self.in_proj = nn.Linear(mlp_dim, rounded_hidden, bias=False) + self.gate_proj = nn.Linear(mlp_dim, rounded_hidden, bias=False) + self.out_proj = nn.Linear(rounded_hidden, mlp_dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = F.silu(self.in_proj(x)) * self.gate_proj(x) + x = self.out_proj(x) + return x + + +""" +Main body of the demo of a basic version of tensor parallel by using +PyTorch native APIs. +""" +tp_size = 2 +logger = get_logger() + +# understand world topology +_rank = int(os.environ["RANK"]) +_world_size = int(os.environ["WORLD_SIZE"]) + + +print(f"Starting PyTorch 2D (FSDP + TP) example on rank {_rank}.") +assert ( + _world_size % tp_size == 0 +), f"World size {_world_size} needs to be divisible by TP size {tp_size}" + + +# create a sharding plan based on the given world_size. +dp_size = _world_size // tp_size + +# Create a device mesh with 2 dimensions. +# First dim is the data parallel dimension +# Second dim is the tensor parallel dimension. +device_mesh = init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp")) + +rank_log(_rank, logger, f"Device Mesh created: {device_mesh=}") +tp_mesh = device_mesh["tp"] +dp_mesh = device_mesh["dp"] + +# To support identical inputs for TP groups, we need the dp process group +dp_pg = device_mesh.get_dim_groups()[0] + +# For TP, input needs to be same across all TP ranks. +# while for SP, input can be different across all ranks. +# We will use dp_rank for setting the random seed +# to mimic the behavior of the dataloader. 
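+# For illustration, with the default tp_size=2 on 4 GPUs the mesh is [[0, 1], [2, 3]]:
+# ranks 0/1 form one TP pair and both get dp_rank 0, while ranks 2/3 get dp_rank 1, so
+# each TP pair seeds identically below and the two data-parallel shards see different data.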
+dp_rank = dist.get_rank(dp_pg) + + +# create model and move it to GPU with id rank +_mlp_dim = 1024 +base_model_swiglu = MLP_swiglu(mlp_dim=_mlp_dim).to("cuda") + + +# Custom parallelization plan for the swiglu MLP model +custom_tp_model = parallelize_module( + module=base_model_swiglu, + device_mesh=tp_mesh, + parallelize_plan={ + "in_proj": ColwiseParallel(), + "gate_proj": ColwiseParallel(), + "out_proj": RowwiseParallel(), + }, +) + +rank_log(_rank, logger, f"Model after parallelization {custom_tp_model=}\n") + +# Init FSDP using the dp device mesh +sharded_model = FSDP(custom_tp_model, device_mesh=dp_mesh, use_orig_params=True) + +# Create an optimizer for the parallelized and sharded model. +lr = 3e-3 +rank_log(_rank, logger, f"Creating AdamW optimizer with learning rate {lr}") +optimizer = torch.optim.AdamW(sharded_model.parameters(), lr=lr, foreach=True) + +# Training loop: +# Perform a num of iterations of forward/backward +# and optimizations for the sharded module. +rank_log(_rank, logger, "\nStarting 2D training...") +num_iterations = 10 +batch_size = 2 + +for i in range(num_iterations): + # seeding with dp_rank to ensure identical inputs for TP groups + torch.manual_seed(i + dp_rank) + inp = torch.rand(batch_size, _mlp_dim, device="cuda") + + output = sharded_model(inp) + output.sum().backward() + optimizer.step() + rank_log(_rank, logger, f"2D iter {i} complete") + +rank_log(_rank, logger, "2D training successfully completed!") diff --git a/distributed/tensor_parallelism/log_utils.py b/distributed/tensor_parallelism/log_utils.py new file mode 100644 index 0000000000..f16d46526d --- /dev/null +++ b/distributed/tensor_parallelism/log_utils.py @@ -0,0 +1,22 @@ +import logging +import torch + +logging.basicConfig( + format="%(asctime)s %(message)s", datefmt="%m/%d/%Y %I:%M:%S %p", level=logging.INFO +) + +def get_logger(): + return logging.getLogger(__name__) + + +def rank_log(_rank, logger, msg): + """helper function to log only on global rank 0""" + if _rank == 0: + logger.info(f" {msg}") + + +def verify_min_gpu_count(min_gpus: int = 2) -> bool: + """ verification that we have at least 2 gpus to run dist examples """ + has_cuda = torch.cuda.is_available() + gpu_count = torch.cuda.device_count() + return has_cuda and gpu_count >= min_gpus diff --git a/distributed/tensor_parallelism/requirements.txt b/distributed/tensor_parallelism/requirements.txt index f7b8148247..c6b283a441 100644 --- a/distributed/tensor_parallelism/requirements.txt +++ b/distributed/tensor_parallelism/requirements.txt @@ -1,6 +1,6 @@ # Python dependencies required for running the example --pre ---extra-index-url https://download.pytorch.org/whl/nightly/cu113 ---extra-index-url https://download.pytorch.org/whl/nightly/cu116 -torch >= 1.14.0.dev0; sys_platform == "linux" \ No newline at end of file +--extra-index-url https://download.pytorch.org/whl/nightly/cu118 +--extra-index-url https://download.pytorch.org/whl/nightly/cu121 +torch >= 2.2.0.dev0; sys_platform == "linux" diff --git a/distributed/tensor_parallelism/run_example.sh b/distributed/tensor_parallelism/run_example.sh new file mode 100644 index 0000000000..c8d431505b --- /dev/null +++ b/distributed/tensor_parallelism/run_example.sh @@ -0,0 +1,13 @@ + +# To run samples: +# bash run_example.sh {file_to_run.py} {num_gpus} +# where file_to_run = example to launch. Default = 'fsdp_tp_example.py' +# num_gpus = num local gpus to use (must be at least 2). 
Default = 4 + +# samples to run include: +# sequence_parallel_example.py +# tensor_parallel_example.py +# fsdp_tp_example.py + +echo "Launching ${1:-fsdp_tp_example.py} with ${2:-4} gpus" +torchrun --nnodes=1 --nproc_per_node=${2:-4} --rdzv_id=101 --rdzv_endpoint="localhost:5972" ${1:-fsdp_tp_example.py} diff --git a/distributed/tensor_parallelism/sequence_parallel_example.py b/distributed/tensor_parallelism/sequence_parallel_example.py index 666713295f..3324d28d4a 100644 --- a/distributed/tensor_parallelism/sequence_parallel_example.py +++ b/distributed/tensor_parallelism/sequence_parallel_example.py @@ -1,19 +1,30 @@ -import argparse - +import os +import sys import torch -import torch.multiprocessing as mp +import torch.nn as nn + +from torch.distributed._tensor import Shard + +from torch.distributed.tensor.parallel import ( + parallelize_module, + ColwiseParallel, + RowwiseParallel, +) + +from log_utils import rank_log, get_logger, verify_min_gpu_count + + +# ---- GPU check ------------ +_min_gpu_count = 2 + +if not verify_min_gpu_count(min_gpus=_min_gpu_count): + print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.") + sys.exit() +# --------------------------- -from torch.distributed._tensor import DeviceMesh -from torch.distributed.tensor.parallel import parallelize_module -from utils import cleanup, setup, ToyModel -try: - from torch.distributed.tensor.parallel import ( - SequenceParallel - ) - SP_AVAILABLE = True -except BaseException as e: - pass +from torch.distributed._tensor.device_mesh import init_device_mesh + """ @@ -33,51 +44,66 @@ """ -def demo_sp(rank, args): - """ - Main body of the demo of a basic version of sequence parallel by using - PyTorch native APIs. - """ - print(f"Running SP example on rank {rank}.") - setup(rank, args.world_size) - - # create a sharding plan based on the given world_size. - device_mesh = DeviceMesh("cuda", torch.arange(0, args.world_size)) - - # create model and move it to GPU with id rank - model = ToyModel().cuda(rank) - # Create a optimizer for the parallelized module. - LR = 0.25 - optimizer = torch.optim.SGD(model.parameters(), lr=LR) - # Parallelize the module based on the given Parallel Style. - model = parallelize_module(model, device_mesh, SequenceParallel()) - - # Perform a num of iterations of forward/backward - # and optimizations for the sharded module. - for _ in range(args.iter_nums): - # For SP, input can be different across all ranks. - inp = torch.rand(20, 10).cuda(rank) - output = model(inp) - output.sum().backward() - optimizer.step() - - cleanup() - - -if __name__ == "__main__": - n_gpus = torch.cuda.device_count() - parser = argparse.ArgumentParser() - # This is passed in via cmd - parser.add_argument("--world_size", type=int, default=n_gpus) - parser.add_argument("--iter_nums", type=int, default=10) - args = parser.parse_args() - # The main entry point is called directly without using subprocess - if n_gpus < 2: - print("Requires at least 2 GPUs to run.") - elif not SP_AVAILABLE: - print( - "PyTorch doesn't have Sequence Parallelism available," - " need nightly build." 
- ) - else: - mp.spawn(demo_sp, args=(args,), nprocs=args.world_size, join=True) +class ToyModel(nn.Module): + """MLP based model""" + + def __init__(self): + super().__init__() + self.in_proj = nn.Linear(10, 32) + self.relu = nn.ReLU() + self.out_proj = nn.Linear(32, 5) + + def forward(self, x): + return self.out_proj(self.relu(self.in_proj(x))) + + +""" +Main body of the demo of a basic version of sequence parallel by using +PyTorch native APIs. +""" +logger = get_logger() + +# create a device mesh based on the given world_size. +device_mesh = init_device_mesh( + device_type="cuda", mesh_shape=(int(os.environ["WORLD_SIZE"]),) +) + +_rank = device_mesh.get_rank() + +print(f"Starting PyTorch Sequence Parallel example on rank {_rank}.") + +rank_log(_rank, logger, f"Device Mesh created: {device_mesh=}") + +# create model and move it to GPU. Init_device_mesh has already assigned gpu ids... +model = ToyModel().to("cuda") + +# Custom parallelization plan for the model +sp_model = parallelize_module( + module=model, + device_mesh=device_mesh, + parallelize_plan={ + "in_proj": ColwiseParallel(input_layouts=Shard(0)), + "out_proj": RowwiseParallel(output_layouts=Shard(0)), + }, +) + + +# Create a optimizer for the parallelized module. +lr = 0.25 +optimizer = torch.optim.AdamW(sp_model.parameters(), lr=lr, foreach=True) + + +# Perform a num of iterations of forward/backward +# and optimizations for the sharded module. +num_iters = 10 +rank_log(_rank, logger, "Sequence Parallel training starting...") + +for i in range(num_iters): + # For SP, input can be different across all ranks. + inp = torch.rand(20, 10, device="cuda") + output = sp_model(inp) + output.sum().backward() + optimizer.step() + rank_log(_rank, logger, f"Sequence Parallel iter {i} completed") + +rank_log(_rank, logger, "Sequence Parallel training completed!") diff --git a/distributed/tensor_parallelism/tensor_parallel_example.py b/distributed/tensor_parallelism/tensor_parallel_example.py index 18133d8eea..2731e8046b 100755 --- a/distributed/tensor_parallelism/tensor_parallel_example.py +++ b/distributed/tensor_parallelism/tensor_parallel_example.py @@ -1,11 +1,27 @@ -import argparse - +import os +import sys import torch -import torch.multiprocessing as mp +import torch.nn as nn + +from torch.distributed.tensor.parallel import ( + parallelize_module, + ColwiseParallel, + RowwiseParallel, +) + +from log_utils import rank_log, get_logger, verify_min_gpu_count + +# ---- GPU check ------------ +_min_gpu_count = 2 + +if not verify_min_gpu_count(min_gpus=_min_gpu_count): + print(f"Unable to locate sufficient {_min_gpu_count} gpus to run this example. Exiting.") + sys.exit() +# --------------------------- + +from torch.distributed._tensor.device_mesh import init_device_mesh + -from torch.distributed._tensor import DeviceMesh -from torch.distributed.tensor.parallel import PairwiseParallel, parallelize_module -from utils import cleanup, setup, ToyModel """ @@ -39,49 +55,68 @@ Parallelism APIs in this example to show users how to use them. """ +class ToyModel(nn.Module): + """MLP based model""" + + def __init__(self): + super(ToyModel, self).__init__() + self.in_proj = nn.Linear(10, 32) + self.relu = nn.ReLU() + self.out_proj = nn.Linear(32, 5) + + def forward(self, x): + return self.out_proj(self.relu(self.in_proj(x))) -def demo_tp(rank, args): - """ - Main body of the demo of a basic version of tensor parallel by using - PyTorch native APIs. 
- """ - print(f"Running basic Megatron style TP example on rank {rank}.") - setup(rank, args.world_size) - - # create a sharding plan based on the given world_size. - device_mesh = DeviceMesh("cuda", torch.arange(0, args.world_size)) - - # create model and move it to GPU with id rank - model = ToyModel().cuda(rank) - # Create a optimizer for the parallelized module. - LR = 0.25 - optimizer = torch.optim.SGD(model.parameters(), lr=LR) - # Parallelize the module based on the given Parallel Style. - model = parallelize_module(model, device_mesh, PairwiseParallel()) - - # Perform a num of iterations of forward/backward - # and optimizations for the sharded module. - for i in range(args.iter_nums): - # For TP, input needs to be same across all TP ranks. - # Setting the random seed is to mimic the behavior of dataloader. - torch.manual_seed(i) - inp = torch.rand(20, 10).cuda(rank) - output = model(inp) - output.sum().backward() - optimizer.step() - - cleanup() - - -if __name__ == "__main__": - n_gpus = torch.cuda.device_count() - parser = argparse.ArgumentParser() - # This is passed in via cmd - parser.add_argument("--world_size", type=int, default=n_gpus) - parser.add_argument("--iter_nums", type=int, default=10) - args = parser.parse_args() - # The main entry point is called directly without using subprocess - if n_gpus < 2: - print("Requires at least 2 GPUs to run.") - else: - mp.spawn(demo_tp, args=(args,), nprocs=args.world_size, join=True) + +""" +Main body of the demo of a basic version of tensor parallel by using +PyTorch native APIs. +""" +logger = get_logger() + +# create a device mesh based on the given world_size. +_world_size = int(os.environ["WORLD_SIZE"]) + +device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,)) +_rank = device_mesh.get_rank() + + +print(f"Starting PyTorch TP example on rank {_rank}.") +assert ( + _world_size % 2 == 0 +), f"TP examples require even number of GPUs, but got {_world_size} gpus" + +rank_log(_rank, logger, f"Device Mesh created: {device_mesh=}") + +# create model and move it to GPU - init"cuda"_mesh has already mapped GPU ids. +tp_model = ToyModel().to("cuda") + +# Create an optimizer for the parallelized module. +lr = 0.25 +optimizer = torch.optim.AdamW(tp_model.parameters(), lr=lr, foreach=True) + +# Custom parallelization plan for the model +tp_model = parallelize_module( + module=tp_model, + device_mesh=device_mesh, + parallelize_plan={ + "in_proj": ColwiseParallel(), + "out_proj": RowwiseParallel(), + }, +) +# Perform a num of iterations of forward/backward +# and optimizations for the sharded module. +num_iters = 10 +rank_log(_rank, logger, "Tensor Parallel training starting...") + +for i in range(num_iters): + # For TP, input needs to be same across all TP ranks. + # Setting the random seed is to mimic the behavior of dataloader. 
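+        # Since every rank seeds with the same value i, torch.rand below produces an
+        # identical batch on all ranks (a stand-in for a real, replicated dataloader).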
+ torch.manual_seed(i) + inp = torch.rand(20, 10, device="cuda") + output = tp_model(inp) + output.sum().backward() + optimizer.step() + rank_log(_rank, logger, f"Tensor Parallel iter {i} completed") + +rank_log(_rank, logger, "Tensor Parallel training completed!") diff --git a/distributed/tensor_parallelism/two_d_parallel_example.py b/distributed/tensor_parallelism/two_d_parallel_example.py deleted file mode 100644 index 5c28db5adf..0000000000 --- a/distributed/tensor_parallelism/two_d_parallel_example.py +++ /dev/null @@ -1,127 +0,0 @@ -import argparse - -import torch -import torch.distributed as dist -import torch.multiprocessing as mp - -from torch.distributed._tensor import DeviceMesh -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.tensor.parallel import ( - PairwiseParallel, - parallelize_module, -) -from torch.distributed.tensor.parallel.fsdp import enable_2d_with_fsdp - -from utils import cleanup, setup, ToyModel -try: - from torch.distributed.tensor.parallel import ( - SequenceParallel - ) - SP_AVAILABLE = True -except BaseException as e: - pass - - -""" -This is the script to test 2D Parallel which combines Tensor/Sequence -parallel with Fully Sharded Data Parallel (TP/SP + FSDP) on a toy model -in the SPMD style. We show an E2E working flow from forward, backward -and optimization. - -We enabled Fully Sharded Data Parallel + Tensor Parallel in -separate parallel dimensions: - Data Parallel across hosts - Tensor Parallel within each host - - We use a simple diagram to illustrate below: - -====================================================================== ------------- ------------ ------------ ------------ -| Host 1 | | Host 2 | | | | Host N | -| 8 GPUs | | 8 GPUs | | | | 8 GPUs | -| | | | | ... | | | -| (TP) | | (TP) | | | | (TP) | -|[0,1,..,7]| |[8,9..,15]| | | |[8N-8,8N-7| -| | | | | | | .., 8N-1]| -| | | | | | | | ------------- ------------ ------------ ------------ -FSDP: -[0, 8, ..., 8N-8], [1, 9, ..., 8N-7], ..., [7, 15, ..., 8N-1] -====================================================================== - -More details can be seen in the slide: -https://docs.google.com/presentation/d/17g6WqrO00rP3MsxbRENsPpjrlSkwiA_QB4r93_eB5is/ -""" - - -def demo_2d(rank, args): - """ - Main body of the demo of a basic version of tensor parallel by using - PyTorch native APIs. - """ - print(f"Running basic Megatron style TP example on rank {rank}.") - setup(rank, args.world_size) - assert ( - args.world_size % args.tp_size == 0 - ), "World size needs to be divisible by TP size" - - # create a sharding plan based on the given world_size. - device_mesh = DeviceMesh( - "cuda", torch.arange(0, args.world_size).view(-1, args.tp_size) - ) - - # create model and move it to GPU with id rank - model = ToyModel().cuda(rank) - # Create a optimizer for the parallelized module. - LR = 0.25 - optimizer = torch.optim.SGD(model.parameters(), lr=LR) - # Parallelize the module based on the given Parallel Style. - parallel_style = SequenceParallel() if args.run_seq_parallel else PairwiseParallel() - model = parallelize_module(model, device_mesh, parallel_style, tp_mesh_dim=1) - - # We need to register hooks for TP + FSDP integration. - assert ( - enable_2d_with_fsdp() - ), "FSDP 2D hook is not registered. Please use PyTorch with version >= 2.0" - dp_pg = device_mesh.get_dim_groups()[0] - model = FSDP(model, process_group=dp_pg) - - # Perform a num of iterations of forward/backward - # and optimizations for the sharded module. 
- for i in range(args.iter_nums): - # For TP, input needs to be same across all TP ranks. - # while for SP, input can be different across all ranks. - # Setting the random seed is to mimic the behavior of dataloader. - dp_rank = ( - rank - if args.run_seq_parallel - else dist.get_rank(dp_pg) - ) - torch.manual_seed(i + dp_rank) - inp = torch.rand(20, 10).cuda(rank) - output = model(inp) - output.sum().backward() - optimizer.step() - - cleanup() - - -if __name__ == "__main__": - n_gpus = torch.cuda.device_count() - parser = argparse.ArgumentParser() - # This is passed in via cmd - parser.add_argument("--world_size", type=int, default=n_gpus) - parser.add_argument("--iter_nums", type=int, default=10) - parser.add_argument("--run_seq_parallel", type=bool, default=False) - parser.add_argument("--tp_size", type=int, default=2) - args = parser.parse_args() - # The main entry point is called directly without using subprocess - if n_gpus < 4: - print("Requires at least 4 GPUs to run.") - elif not SP_AVAILABLE: - print( - "PyTorch doesn't have Sequence Parallelism available," - " need nightly build." - ) - else: - mp.spawn(demo_2d, args=(args,), nprocs=args.world_size, join=True) diff --git a/distributed/tensor_parallelism/utils.py b/distributed/tensor_parallelism/utils.py deleted file mode 100644 index a55f85c026..0000000000 --- a/distributed/tensor_parallelism/utils.py +++ /dev/null @@ -1,31 +0,0 @@ -import argparse -import os - -import torch -import torch.distributed as dist -import torch.multiprocessing as mp -import torch.nn as nn - - -def setup(rank, world_size): - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = "12355" - - # initialize the process group - dist.init_process_group("nccl", rank=rank, world_size=world_size) - torch.cuda.set_device(rank) - - -def cleanup(): - dist.destroy_process_group() - - -class ToyModel(nn.Module): - def __init__(self): - super(ToyModel, self).__init__() - self.net1 = nn.Linear(10, 32) - self.relu = nn.ReLU() - self.net2 = nn.Linear(32, 5) - - def forward(self, x): - return self.net2(self.relu(self.net1(x))) diff --git a/run_python_examples.sh b/run_python_examples.sh index 1b45a281cf..a9ff393e80 100755 --- a/run_python_examples.sh +++ b/run_python_examples.sh @@ -63,8 +63,8 @@ function distributed() { start python tensor_parallelism/tensor_parallel_example.py || error "tensor parallel example failed" python tensor_parallelism/sequence_parallel_example.py || error "sequence parallel example failed" - python tensor_parallelism/two_d_parallel_example.py || error "2D parallel example failed" - python ddp/main.py || error "ddp example failed" + python tensor_parallelism/fsdp_tp_example.py || error "2D parallel example failed" + python ddp/main.py || error "ddp example failed" } function fast_neural_style() { @@ -96,7 +96,7 @@ function mnist() { python main.py --epochs 1 --dry-run || error "mnist example failed" } function mnist_forward_forward() { - start + start python main.py --epochs 1 --no_mps --no_cuda || error "mnist forward forward failed" } @@ -212,9 +212,8 @@ function clean() { function run_all() { # cpp dcgan - # distributed - fast_neural_style distributed + fast_neural_style imagenet mnist mnist_forward_forward