From c022984227f4854704b553d3e2dc561933b0d499 Mon Sep 17 00:00:00 2001
From: SungMinCho
Date: Fri, 27 May 2022 17:54:06 +0900
Subject: [PATCH] Make sharding plan explicit in DLRM example

Previously, DistributedModelParallel was responsible for creating the
sharding plan for DLRM. It relied on hard-coded constants to build the
planner topology (e.g. batch_size=512, HBM_CAP=32GB, and so on), which
did not reflect the true system topology. Testing different constraints
for sharding types and compute kernels was also inconvenient.

This change creates the sharding plan explicitly before the model is
wrapped in DistributedModelParallel, and adds simple command-line
options (--sharding_type, --compute_kernel, --intra_host_bw,
--inter_host_bw, --hbm_cap, --ddr_cap) so the constraints and topology
can be configured per run.
---
 torchrec_dlrm/dlrm_main.py | 91 +++++++++++++++++++++++++++++++++++---
 1 file changed, 84 insertions(+), 7 deletions(-)

diff --git a/torchrec_dlrm/dlrm_main.py b/torchrec_dlrm/dlrm_main.py
index 7025b450..ea188204 100644
--- a/torchrec_dlrm/dlrm_main.py
+++ b/torchrec_dlrm/dlrm_main.py
@@ -24,10 +24,27 @@
     TOTAL_TRAINING_SAMPLES,
 )
 from torchrec.datasets.utils import Batch
+from torchrec.distributed.comm import get_local_size
 from torchrec.distributed import TrainPipelineSparseDist
 from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
+from torchrec.distributed.embedding_types import EmbeddingComputeKernel
 from torchrec.distributed.model_parallel import DistributedModelParallel
-from torchrec.distributed.types import ModuleSharder
+from torchrec.distributed.planner import (
+    EmbeddingShardingPlanner,
+    ParameterConstraints,
+    Topology,
+)
+from torchrec.distributed.planner.constants import (
+    INTRA_NODE_BANDWIDTH,
+    CROSS_NODE_BANDWIDTH,
+    HBM_CAP,
+    DDR_CAP,
+)
+from torchrec.distributed.types import (
+    ModuleSharder,
+    ShardingEnv,
+    ShardingType,
+)
 from torchrec.modules.embedding_configs import EmbeddingBagConfig
 from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper
 from tqdm import tqdm
@@ -197,18 +214,50 @@ def parse_args(argv: List[str]) -> argparse.Namespace:
         default=0.20,
         help="Learning rate after change point in first epoch.",
     )
-    parser.set_defaults(
-        pin_memory=None,
-        mmap_mode=None,
-        shuffle_batches=None,
-        change_lr=None,
-    )
     parser.add_argument(
         "--adagrad",
         dest="adagrad",
         action="store_true",
         help="Flag to determine if adagrad optimizer should be used.",
     )
+    parser.add_argument(
+        "--sharding_type",
+        type=str,
+        choices=[st.value for st in ShardingType],
+        help="ShardingType constraint for all embedding tables"
+    )
+    parser.add_argument(
+        "--compute_kernel",
+        type=str,
+        choices=[ck.value for ck in EmbeddingComputeKernel],
+        help="ComputeKernel constraint for all embedding tables"
+    )
+    parser.add_argument(
+        "--intra_host_bw",
+        type=float,
+        default=INTRA_NODE_BANDWIDTH,
+    )
+    parser.add_argument(
+        "--inter_host_bw",
+        type=float,
+        default=CROSS_NODE_BANDWIDTH,
+    )
+    parser.add_argument(
+        "--hbm_cap",
+        type=int,
+        default=HBM_CAP,
+    )
+    parser.add_argument(
+        "--ddr_cap",
+        type=int,
+        default=DDR_CAP,
+    )
+    parser.set_defaults(
+        pin_memory=None,
+        mmap_mode=None,
+        shuffle_batches=None,
+        change_lr=None,
+    )
     return parser.parse_args(argv)
 
 
@@ -534,10 +583,38 @@ def main(argv: List[str]) -> None:
         EmbeddingBagCollectionSharder(fused_params=fused_params),
     ]
 
+    pg = dist.GroupMember.WORLD
+    assert pg is not None, "Process group is not initialized"
+    env = ShardingEnv.from_process_group(pg)
+    if any(a is not None for a in [args.sharding_type, args.compute_kernel]):
+        sharding_types = [args.sharding_type] if args.sharding_type else None
+        compute_kernels = [args.compute_kernel] if args.compute_kernel else None
+        constraints = {
+            f"t_{feature_name}": ParameterConstraints(sharding_types=sharding_types, compute_kernels=compute_kernels)
+            for feature_name in DEFAULT_CAT_NAMES
+        }
+    else:
+        constraints = None
+    planner = EmbeddingShardingPlanner(
+        topology=Topology(
+            world_size=env.world_size,
+            local_world_size=get_local_size(env.world_size),
+            compute_device=device.type,
+            hbm_cap=args.hbm_cap,
+            ddr_cap=args.ddr_cap,
+            intra_host_bw=args.intra_host_bw,
+            inter_host_bw=args.inter_host_bw,
+            batch_size=args.batch_size,
+        ),
+        constraints=constraints,
+    )
+    plan = planner.collective_plan(train_model, sharders, pg)
+
     model = DistributedModelParallel(
         module=train_model,
         device=device,
         sharders=cast(List[ModuleSharder[nn.Module]], sharders),
+        plan=plan,
     )
 
     def optimizer_with_params():
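
Example usage (illustrative, not part of the patch): assuming the script is
launched with torchrun on a single 8-GPU host and that --hbm_cap is given in
bytes (16 GiB below), the new options combine with the existing training
flags as usual:

    torchrun --nproc_per_node 8 torchrec_dlrm/dlrm_main.py \
        --sharding_type table_wise \
        --hbm_cap 17179869184 \
        --batch_size 2048

When both --sharding_type and --compute_kernel are omitted, no per-table
ParameterConstraints are passed to the planner (constraints stays None) and
only the topology options take effect.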