diff --git a/blogs/comm-opt/README.md b/blogs/comm-opt/README.md new file mode 100644 index 000000000000..4767c4342816 --- /dev/null +++ b/blogs/comm-opt/README.md @@ -0,0 +1,82 @@ +
+ +# Communication Optimizations for Large-Scale Training + +
+ + +## Table of Contents +1. [Introduction](#introduction) +2. [Gradient AllReduce Optimization for ZeRO stages 1 and 2](#ar-opt) +3. [Optimizing Parameter All-Gather for ZeRO2 Training](#ag-opt) +4. [Optimizing AlltoAll for Sequence-Parallel Training](#sp-opt) + + +## 1. Introduction +Training LLMs on large datasets can be extremely costly both in terms of hardware resources and time. An important step to minimize such costs is to carefully combine an appropriate number of resources together with a scalable library that guarantees training completion within a time limit. In this post, we discuss a key aspect of the scalability features of DeepSpeed, the communication optimization. Communication collectives (e.g., all-reduce, all-gather, etc.) are critical pieces of many popular DeepSpeed technologies (e.g., ZeRO, MoE, AutoTP, etc.), and in the following sections we discuss our new optimizations of some of these collectives. These optimizations are available in DeepSpeed versions >= 0.x.x. + +## 2. Gradient AllReduce Optimization for ZeRO stages 1 and 2 + +Before diving into this optimization, let's take a step back and show some of the case studies that demonstrate the need. + +AllReduce operation is an important part of the training process. In ZeRO, we handle this in buckets, which can be configured to get good communication throughput. As the number of GPUs increases, we encounter smaller-partition AllReduces. In this case, the current bucketing scheme cannot help with the communication overhead. This mostly becomes an issue when training smaller-scale models (like Llama-7B) with large number of GPUs. + +For instance, when training a dense-7B architecture with Zero stages 1 or 2, we encounter a 1 and 2 second increase for the AllReduce time by increasing from 256 to 512 and 1024 A100 GPUs. This issue mostly arises from the fact that, the gradient-averaging happens with smaller partitions (#parameters / #GPUs) per-GPU rank. This issue gets more serious when training MoE architectures (3 - 12 second) for which the expert's parameters can be farther away due to the current parallelism layout of data and expert parallelism. + +In this section, we introduce two main optimization techniques for alleviating these communication bottleneck. + +First, Multi-rank bucketing for the same process group: for this optimization, we simply pack all data that requires to be reduced from different ranks into one big flattened tensor and call AllReduce instead of reduce operations. After the reduction, we scatter the right portion of data to the corresponding ranks. + +Second, add new layout for the expert-data parallelism: the default parallelism layout for MoE architecture (as shown in Fig 1) is planned in a way that the experts are placed first on E parallel GPUs and replicated D times (data-parallel). With this layout, we encounter slower AllReduce as data-parallel ranks are placed farther away especially when we have cross-rank communication. We call this layout E + D. + +
+
+ + *Fig 1: Different MoE parallel layout. left) E + D, which places the GPUs in EP dimension first before adding DP, right) D + E, that replicates each expert by DP size, before constructing EP. We get faster AllReduce for the second layout while increasing the AlltoAll time. It potentially results in faster e2e training time, as the communication volume for AllReduce (total parameter size) is normally much more than AlltoAll (MLP activation memory).*
+
+By changing this layout from E + D to D + E (shown in Fig 1), where we first replicate each expert by D times and then add them across expert-parallel dimension, we can reduce the AllReduce time substantially. On an A100-DGX cluster, where each node has 8 GPUs, we see about 8x reduction in cross-node infiniband communication-volume for the parameter update process, which are now processed faster using the intra-node NVLinks. Note that by adding this optimization, we increase the cost of AlltoAll happening for the MoE part of the model, however, we have seen that the performance benefit of AllReduce overweighs this cost. + +Table 1 summarizes the saving observed for training a 7B dense and a MoE architecture by using the optimized AllReduce scheme. After applying the multi-rank bucketing technique, we reduce the AllReduce time by 4x for dense architecture and 5x - 8x for the MoE one. In addition, we obtain an extra 3x saving using the new D + E layout for the MoE architecture. Therefore, we see higher performance gain on MoE architectures when using large number of GPUs. For instance, when training a 7B-base MoE architecture, we reduce iteration-time from 13 sec to 9.5 sec on 512 GPUs (37%) and from 16.1 sec to 5.1 sec on 1k-GPU setup (3.2x). +
+ +| | GPUs | AllReduce time | Iteration time | +|----------|:------:|:------:|:------:| +baseline (dense) | 1024| 1.2 | 5.4 +optimized (dense) | 1024| 0.36 | 4.5 +baseline (MoE) | 1024 | 11.5 | 16.1 +optimized (MoE) | 1024 | 0.45 | 5.1 + +Table 1. AllReduce saving observed for both dense and MoE architectures. + +
+ +## 3. Optimizing Parameter All-Gather for ZeRO2 Training + +The same as with AllReduce, all-gather takes longer as we have more partitions. As the parameters are stored in a flattened buffer for ZeRO stage-2, we can simply have a one call to all-gather the parameters into this tensor. + +When all-gathering the updated parameters at Zero-Stage2, the bucketing scheme uses several narrow operations and creates a list of tensors with the bucket size from each partition. We needed this scheme to align with the `all_gather` operation from PyTorch. +However, by adding the support for the `all_gather_into_tensor`, operation that has been added to the newer versions of PyTorch, we can simply have a kernel call to do the full-parameter all-gather. With this optimization, we see about 2x reduction in the step time for large-scale training. + +## 4. Optimizing AlltoAll for Sequence-Parallel Training + +For this part of the optimization, we add some fusion for the communication that is required for the DeepSpeed-Ulysses to provide a more scalable approach for when we increase the SP from 2 to 8 (for this study, we consider A100-DGX hardware, which has 8 GPUs per-node and by increasing the parallelism more than 8, we encounter performance-hit by the cross-node communication). + +These fusions are done at two levels: +1. Fuse the sequence AlltoAll for q,k, and v: we Scatter the heads using the mixed tensor rather than splitting them beforehand. For this part, we need to get some more information from the modeling side (such as the number of q and kv heads), to split the heads before calling AlltoAll. We have added some new changes on the Megatron-DeepSpeed repo that incorporate these changes for the sequence-parallelism. +2. Fuse the AlltoAll tensors and call the PyTorch's AlltoAll-sinlge API: we reshape the tensors for the scatter dimension and use a single tensor for AlltoAll which alleviates the overhead of using a list of tensors which requires a contiguous call for each element of the list. + +By adding these optimizations, we see about 10 to 15% speedup compared to the previous design, and obtain good scalability across different SP-degree and context-lengths. In the following table, we show the improvement achieved by using SP, when doubling the GPU-count and increasing the SP-degree. We obtain over 80% of efficiency when increasing from 256 to 512 GPUs using SP-2. Furthermore, by increasing the sequence-length and SP, while keeping the processed tokens similar, we achieve over 75% of efficiency for 2x more resources. On the other hand, if we can double the number of tokens (shown on the last row of table 2), we can improve the performance to 1.81x. + +
+ +| GPUs | bsz | seq | Tokens (M) | SP | Sample (4K)-per-second | Speedup (x) | +|----------|:------:|:------:|:------:|:------:|:------:|:------:| +256 | 256| 8192 |2|1 | 60.71 |1 +512 | 256| 8192 |2|2 | 111.18 | 1.83 +512 | 128| 16384 |2|4 | 108.81 | 1.79 +512 | 64 |32768 |2|8 | 106.54 | 1.75 +512 | 64 |65536 |4|8 | 110.05 | 1.81 + +Table 2. Sequence-Parallelism scalability using DeepSpeed-Ulysses. + +
diff --git a/blogs/comm-opt/assets/images/e+d.png b/blogs/comm-opt/assets/images/e+d.png new file mode 100644 index 000000000000..72ad0f583857 Binary files /dev/null and b/blogs/comm-opt/assets/images/e+d.png differ diff --git a/blogs/comm-opt/assets/images/sp+fp.png b/blogs/comm-opt/assets/images/sp+fp.png new file mode 100644 index 000000000000..0b2940418f7a Binary files /dev/null and b/blogs/comm-opt/assets/images/sp+fp.png differ diff --git a/blogs/comm-opt/assets/images/sp-conv.png b/blogs/comm-opt/assets/images/sp-conv.png new file mode 100644 index 000000000000..e1e36b4436a0 Binary files /dev/null and b/blogs/comm-opt/assets/images/sp-conv.png differ diff --git a/deepspeed/moe/layer.py b/deepspeed/moe/layer.py index 89fe2bb46c3c..7dd0c6bcb67d 100644 --- a/deepspeed/moe/layer.py +++ b/deepspeed/moe/layer.py @@ -79,20 +79,22 @@ def __init__(self, # coefficient is used for weighted sum of the output of expert and mlp self.coefficient = torch.nn.Linear(hidden_size, 2) - def set_deepspeed_parallelism(self): - self._create_process_groups() + def set_deepspeed_parallelism(self, use_data_before_expert_parallel_=False): + self._create_process_groups(use_data_before_expert_parallel_=use_data_before_expert_parallel_) - def _create_process_groups(self): + def _create_process_groups(self, use_data_before_expert_parallel_=False): # Create process group for a layer if needed if self.expert_group_name not in groups._get_expert_parallel_group_dict(): print(f"No existing process group found, creating a new group named: {self.expert_group_name}") if (groups.mpu is None) or (not self.enable_expert_tensor_parallelism): # Condition 1 - no groups.mpu means no tensor parallelism # Condition 2 - disabling expert tensor parallelism on purpose - groups._create_expert_and_data_parallel(self.ep_size) + groups._create_expert_and_data_parallel( + self.ep_size, use_data_before_expert_parallel_=use_data_before_expert_parallel_) else: # expert tensor parallelism is enabled - groups._create_expert_data_and_model_parallel(self.ep_size, mpu=groups.mpu) + groups._create_expert_data_and_model_parallel( + self.ep_size, mpu=groups.mpu, use_data_before_expert_parallel_=use_data_before_expert_parallel_) # Set the group handle for the MOELayer (deepspeed_moe) object self.deepspeed_moe._set_ep_group(groups._get_expert_parallel_group(self.expert_group_name)) diff --git a/deepspeed/profiling/flops_profiler/profiler.py b/deepspeed/profiling/flops_profiler/profiler.py index ddcabf0d29e5..79e682a73b90 100644 --- a/deepspeed/profiling/flops_profiler/profiler.py +++ b/deepspeed/profiling/flops_profiler/profiler.py @@ -1139,7 +1139,7 @@ def get_module_duration(module): duration = module.__duration__ if duration == 0: # e.g. ModuleList for m in module.children(): - duration += m.__duration__ + duration += get_module_duration(m) return duration diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index a68b56f7208a..b49469b94f11 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -544,6 +544,10 @@ def get_hybrid_engine_config(param_dict): return hybrid_engine_config +def get_expert_data_topo_config(param_dict): + return get_scalar_param(param_dict, USE_DATA_BEFORE_EXPERT_PARALLEL, USE_DATA_BEFORE_EXPERT_PARALLEL_DEFAULT) + + def get_eigenvalue_config(param_dict): if get_quantize_enabled(param_dict): param_dict = param_dict[QUANTIZE_TRAINING] @@ -850,6 +854,7 @@ def _initialize_params(self, param_dict): self.eigenvalue_layer_num, ) = get_eigenvalue_config(param_dict) + self.use_data_before_expert_parallel_ = get_expert_data_topo_config(param_dict) self.hybrid_engine = get_hybrid_engine_config(param_dict) self.sparse_attention = get_sparse_attention(param_dict) diff --git a/deepspeed/runtime/constants.py b/deepspeed/runtime/constants.py index cc493ee007c5..96f2a38bd05c 100755 --- a/deepspeed/runtime/constants.py +++ b/deepspeed/runtime/constants.py @@ -433,3 +433,9 @@ class ValidationMode: ######################################### DATA_PARALLEL_GROUP = "data_parallel_group" GLOBAL_RANK = "global_rank" + +######################################### +# EXPERT-DATA PARALLELISM TOPO Config +######################################### +USE_DATA_BEFORE_EXPERT_PARALLEL = "use_data_before_expert_parallelism" +USE_DATA_BEFORE_EXPERT_PARALLEL_DEFAULT = False diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 10f2ec7b159c..b824212de32b 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -715,6 +715,9 @@ def mics_shard_size(self): def zero_reduce_bucket_size(self): return self._config.zero_config.reduce_bucket_size + def zero_multi_rank_bucket_allreduce(self): + return self._config.zero_config.use_multi_rank_bucket_allreduce + def zero_allgather_bucket_size(self): return self._config.zero_config.allgather_bucket_size @@ -1112,7 +1115,7 @@ def _configure_distributed_model(self, model): # Set deepspeed parallelism spec. for the model including expert parallelism for _, module in self.module.named_modules(): if hasattr(module, 'set_deepspeed_parallelism'): - module.set_deepspeed_parallelism() + module.set_deepspeed_parallelism(self._config.use_data_before_expert_parallel_) # Query the groups module to get information about various parallel groups self.local_all_to_all_group = None @@ -1495,6 +1498,7 @@ def _configure_zero_optimizer(self, optimizer): clip_grad=self.gradient_clipping(), contiguous_gradients=contiguous_gradients, reduce_bucket_size=self.zero_reduce_bucket_size(), + use_multi_rank_bucket_allreduce=self.zero_multi_rank_bucket_allreduce(), allgather_bucket_size=self.zero_allgather_bucket_size(), dp_process_group=self.seq_data_parallel_group, expert_parallel_group=self.expert_parallel_group if self.has_moe_layers else None, diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index b066090265c8..108c7775530b 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -937,6 +937,15 @@ def align_dense_tensors(tensor_list, alignment): return padded_tensor_list +def all_gather_all_partitions(global_flatten_group, partitioned_param_groups, dp_process_group): + for group_id, partitioned_params in enumerate(partitioned_param_groups): + # Sequential AllGather Best of both worlds + partition_id = dist.get_rank(group=dp_process_group[group_id]) + dp_world_size = dist.get_world_size(group=dp_process_group[group_id]) + dist.all_gather_into_tensor(global_flatten_group[group_id], partitioned_params[partition_id], + dp_process_group[group_id]) + + def all_gather_dp_groups(partitioned_param_groups, dp_process_group, start_alignment_factor, allgather_bucket_size): for group_id, partitioned_params in enumerate(partitioned_param_groups): # Sequential AllGather Best of both worlds diff --git a/deepspeed/runtime/zero/config.py b/deepspeed/runtime/zero/config.py index f16dfd7ac4c0..76583c129cb9 100644 --- a/deepspeed/runtime/zero/config.py +++ b/deepspeed/runtime/zero/config.py @@ -21,6 +21,7 @@ "stage3_max_live_parameters" : 1000000000, "stage3_max_reuse_distance" : 1000000000, "allgather_partitions": [true|false], + "use_multi_rank_bucket_allreduce": [true|false], "allgather_bucket_size": 500000000, "reduce_scatter": [true|false], "contiguous_gradients" : [true|false] @@ -107,6 +108,13 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel): for the allgather for large model sizes """ + use_multi_rank_bucket_allreduce: bool = True + """ + Combine the reduce buckets of the different ranks and do an All-Reduce instead of multiple Reduce ops. + This feature is useful when the model is small and we want to scale it on too many GPUs which therefore + reduces the message sizes of each packet. + """ + allgather_partitions: bool = True """ Chooses between allgather collective or a series of broadcast collectives diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 6fd1ac14ec51..4b92fe319bfa 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -13,7 +13,8 @@ from deepspeed.runtime import ZeROOptimizer from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler from deepspeed.runtime.utils import (bwc_tensor_model_parallel_rank, get_global_norm, empty_cache, see_memory_usage, - inf, is_model_parallel_parameter, align_dense_tensors, all_gather_dp_groups) + inf, is_model_parallel_parameter, align_dense_tensors, all_gather_dp_groups, + all_gather_all_partitions) from deepspeed.runtime.zero.config import ZeroStageEnum from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum @@ -115,6 +116,7 @@ def __init__(self, verbose=True, contiguous_gradients=True, reduce_bucket_size=500000000, + use_multi_rank_bucket_allreduce=True, allgather_bucket_size=5000000000, dp_process_group=None, expert_parallel_group=None, @@ -391,6 +393,7 @@ def __init__(self, self.first_offset.append(first_offset) self.reduce_bucket_size = int(reduce_bucket_size) + self.use_multi_rank_bucket_allreduce = use_multi_rank_bucket_allreduce self.allgather_bucket_size = int(allgather_bucket_size) self.reduction_stream = None if get_accelerator().is_synchronized_device() else get_accelerator().Stream() @@ -953,6 +956,46 @@ def gradient_reduction_w_predivide(self, tensor): return tensor + def allreduce_and_copy_with_multiple_ranks(self, + small_bucket, + log=None, + divide=True, + process_group=None, + bucket_ranks=None): + process_group = self.dp_process_group if process_group is None else process_group + allreduced = self.allreduce_bucket(small_bucket, log=log, divide=divide, process_group=process_group) + for buf, synced, bucket_rank in zip(small_bucket, self.unflatten(allreduced, small_bucket), bucket_ranks): + if dist.get_rank(group=process_group) == bucket_rank: + buf.copy_(synced) + + def allreduce_and_scatter(self, bucket, numel_per_bucket=500000000, log=None, divide=True, process_group=None): + small_bucket = [] + small_bucket_ranks = [] + numel = 0 + allreduce_sizes = [] + + for i, bucket_elem in enumerate(bucket): + rank, tensor = bucket_elem + small_bucket.append(tensor) + small_bucket_ranks.append(rank) + numel = numel + tensor.numel() + if numel > numel_per_bucket: + self.allreduce_and_copy_with_multiple_ranks(small_bucket, + log=None, + divide=divide, + process_group=process_group, + bucket_ranks=small_bucket_ranks) + small_bucket = [] + small_bucket_ranks = [] + numel = 0 + + if len(small_bucket) > 0: + self.allreduce_and_copy_with_multiple_ranks(small_bucket, + log=None, + divide=divide, + process_group=process_group, + bucket_ranks=small_bucket_ranks) + def average_tensor(self, tensor): if self.overlap_comm: stream = self.reduction_stream @@ -1029,26 +1072,31 @@ def average_tensor(self, tensor): if not self.ipg_bucket_has_moe_params: tensor.div_(dist.get_world_size(group=self.dp_process_group) / float(self.sequence_parallel_size)) - tensor_to_reduce = tensor - if self.communication_data_type != tensor.dtype: - tensor_to_reduce = tensor.to(self.communication_data_type) - - async_handles = [] + buckets = {} for i, (dst, bucket_offset, numel) in enumerate(rank_and_offsets): - grad_slice = tensor_to_reduce.narrow(0, int(bucket_offset), int(numel)) - # if dist.get_rank() == 0: - # print(f"Rank {dist.get_rank()} rank offset id {i} real dp size {dist.get_world_size(group=real_dp_process_group[i])} and dst: {dst}") - # dist.barrier() - #dist.barrier() - dst_rank = dist.get_global_rank(real_dp_process_group[i], dst) - async_handle = dist.reduce(grad_slice, dst=dst_rank, group=real_dp_process_group[i], async_op=True) - async_handles.append(async_handle) - - for handle in async_handles: - handle.wait() - - if self.communication_data_type != tensor.dtype: - tensor.copy_(tensor_to_reduce) + grad_slice = tensor.narrow(0, int(bucket_offset), int(numel)) + bucket_key = real_dp_process_group[i] if self.use_multi_rank_bucket_allreduce else ( + dst, real_dp_process_group[i]) + if bucket_key not in buckets: + buckets[bucket_key] = [] + if self.use_multi_rank_bucket_allreduce: + buckets[bucket_key].append((dst, grad_slice)) + else: + buckets[bucket_key].append(grad_slice) + + for bucket_key in buckets: + if self.use_multi_rank_bucket_allreduce: + self.allreduce_and_scatter(buckets[bucket_key], + numel_per_bucket=self.reduce_bucket_size, + divide=self.ipg_bucket_has_moe_params, + process_group=bucket_key) + else: + dst, process_group = bucket_key + self.allreduce_no_retain(buckets[bucket_key], + numel_per_bucket=self.reduce_bucket_size, + rank=dst, + divide=self.ipg_bucket_has_moe_params, + process_group=process_group) ############################################################################## ############################# CPU Offload Methods############################# @@ -1391,10 +1439,12 @@ def set_none_gradients_to_zero(self, i, partition_id): param.grad = torch.zero_like(param) ######################Reduction Related Methods############################## - def allreduce_bucket(self, bucket, rank=None, log=None): + def allreduce_bucket(self, bucket, rank=None, log=None, divide=True, process_group=None): rank = None tensor = self.flatten(bucket) + process_group = self.dp_process_group if process_group is None else process_group + tensor_to_allreduce = tensor if pg_correctness_test or self.sequence_parallel_size > 1: @@ -1405,17 +1455,18 @@ def allreduce_bucket(self, bucket, rank=None, log=None): if communication_data_type != tensor.dtype: tensor_to_allreduce = tensor.to(communication_data_type) - tensor_to_allreduce.div_(dist.get_world_size(group=self.dp_process_group) / float(self.sequence_parallel_size)) + if divide: + tensor_to_allreduce.div_(dist.get_world_size(group=process_group) / float(self.sequence_parallel_size)) if rank is None: # "All Reducing" - dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group) + dist.all_reduce(tensor_to_allreduce, group=process_group) else: - global_rank = dist.get_global_rank(self.dp_process_group, rank) - dist.reduce(tensor_to_allreduce, global_rank, group=self.dp_process_group) + global_rank = dist.get_global_rank(process_group, rank) + dist.reduce(tensor_to_allreduce, global_rank, group=process_group) if communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce: - if rank is None or rank == dist.get_rank(group=self.dp_process_group): + if rank is None or rank == dist.get_rank(group=process_group): tensor.copy_(tensor_to_allreduce) return tensor @@ -1427,7 +1478,8 @@ def _clear_previous_reduced_grads(self): self.previous_reduced_grads = None # if rank is specified do a reduction instead of an allreduce - def allreduce_and_copy(self, small_bucket, rank=None, log=None): + def allreduce_and_copy(self, small_bucket, rank=None, log=None, divide=True, process_group=None): + process_group = self.dp_process_group if process_group is None else process_group if self.overlap_comm: get_accelerator().synchronize() # It is safe to clear the previously reduced grads of other partitions @@ -1437,23 +1489,38 @@ def allreduce_and_copy(self, small_bucket, rank=None, log=None): stream = get_accelerator().current_stream() with get_accelerator().stream(stream): - allreduced = self.allreduce_bucket(small_bucket, rank=rank, log=log) + allreduced = self.allreduce_bucket( + small_bucket, + rank=rank, + log=log, + divide=divide, + process_group=process_group, + ) if rank is None or rank == dist.get_rank(group=self.dp_process_group): for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)): buf.copy_(synced) - def allreduce_no_retain(self, bucket, numel_per_bucket=500000000, rank=None, log=None): + def allreduce_no_retain( + self, + bucket, + numel_per_bucket=500000000, + rank=None, + log=None, + divide=True, + process_group=None, + ): small_bucket = [] numel = 0 for tensor in bucket: small_bucket.append(tensor) numel = numel + tensor.numel() if numel > numel_per_bucket: - self.allreduce_and_copy(small_bucket, rank=rank, log=None) + self.allreduce_and_copy(small_bucket, rank=rank, log=None, divide=divide, process_group=process_group) small_bucket = [] + numel = 0 if len(small_bucket) > 0: - self.allreduce_and_copy(small_bucket, rank=rank, log=log) + self.allreduce_and_copy(small_bucket, rank=rank, log=log, divide=divide, process_group=process_group) # allows using reduction of gradients instead of using all_reduce @@ -1799,10 +1866,15 @@ def step(self, closure=None): self.timers(OPTIMIZER_ALLGATHER_TIMER).start() # Gather the updated weights from everyone. # Then all partitions of the model parameters are updated and ready for next round forward. - all_gather_dp_groups(partitioned_param_groups=self.parallel_partitioned_bit16_groups, - dp_process_group=self.real_dp_process_group, - start_alignment_factor=self.nccl_start_alignment_factor, - allgather_bucket_size=self.allgather_bucket_size) + if dist.has_all_gather_into_tensor(): + all_gather_all_partitions(global_flatten_group=self.bit16_groups_flat, + partitioned_param_groups=self.parallel_partitioned_bit16_groups, + dp_process_group=self.real_dp_process_group) + else: + all_gather_dp_groups(partitioned_param_groups=self.parallel_partitioned_bit16_groups, + dp_process_group=self.real_dp_process_group, + start_alignment_factor=self.nccl_start_alignment_factor, + allgather_bucket_size=self.allgather_bucket_size) self.timers(OPTIMIZER_ALLGATHER_TIMER).stop() @@ -1825,10 +1897,15 @@ def update_lp_params(self): # if i == 0: # print_rank_0(f'{fp32_partition[:10]=}', force=True) - all_gather_dp_groups(partitioned_param_groups=self.parallel_partitioned_bit16_groups, - dp_process_group=self.real_dp_process_group, - start_alignment_factor=self.nccl_start_alignment_factor, - allgather_bucket_size=self.allgather_bucket_size) + if dist.has_all_gather_into_tensor(): + all_gather_all_partitions(global_flatten_group=self.bit16_groups_flat, + partitioned_param_groups=self.parallel_partitioned_bit16_groups, + dp_process_group=self.real_dp_process_group) + else: + all_gather_dp_groups(partitioned_param_groups=self.parallel_partitioned_bit16_groups, + dp_process_group=self.real_dp_process_group, + start_alignment_factor=self.nccl_start_alignment_factor, + allgather_bucket_size=self.allgather_bucket_size) def _average_expert_grad_norms(self, norm_groups): for i, norm in enumerate(norm_groups): diff --git a/deepspeed/sequence/layer.py b/deepspeed/sequence/layer.py index d4953287d034..e1dbff87f4ec 100644 --- a/deepspeed/sequence/layer.py +++ b/deepspeed/sequence/layer.py @@ -12,6 +12,35 @@ import deepspeed.comm as dist +def single_all_to_all(input, scatter_idx, gather_idx, group): + seq_world_size = dist.get_world_size(group) + inp_shape = list(input.shape) + inp_shape[scatter_idx] = inp_shape[scatter_idx] // seq_world_size + if scatter_idx < 2: + input_t = input.reshape( + [seq_world_size, inp_shape[scatter_idx]] + \ + inp_shape[scatter_idx + 1:] + ).contiguous() + else: + # transpose groups of heads with the seq-len parallel dimension, so that we can scatter them! + input_t = input.reshape( + [-1, seq_world_size, inp_shape[scatter_idx]] + \ + inp_shape[scatter_idx + 1:] + ).transpose(0, 1).contiguous() + + output = torch.empty_like(input_t) + dist.all_to_all_single(output, input_t, group=group) + + # if scattering the seq-dim, transpose the heads back to the original dimension + if scatter_idx < 2: + output = output.transpose(0, 1).contiguous() + + return output.reshape( + inp_shape[: gather_idx] + \ + [inp_shape[gather_idx] * seq_world_size,] + \ + inp_shape[gather_idx + 1:]).contiguous() + + class _SeqAllToAll(torch.autograd.Function): @staticmethod @@ -21,13 +50,7 @@ def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, scatter_idx: int, ctx.scatter_idx = scatter_idx ctx.gather_idx = gather_idx - seq_world_size = dist.get_world_size(group) - - input_list = [t.contiguous() for t in torch.tensor_split(input, seq_world_size, scatter_idx)] - output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)] - # TODO Use all_to_all_single instead - dist.all_to_all(output_list, input_list, group=group) - return torch.cat(output_list, dim=gather_idx).contiguous() + return single_all_to_all(input, scatter_idx, gather_idx, group) @staticmethod def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]: @@ -71,6 +94,7 @@ def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any) -> Tens * output (Tensor): context output """ # TODO Merge three alltoall calls into one + # TODO (Reza): change the api on the megatron-deepspeed side so that we only receive all data (q,k, and v) together! #in shape : e.g., [s/p:h:] query_layer = _SeqAllToAll.apply(self.spg, query, self.scatter_idx, self.gather_idx) key_layer = _SeqAllToAll.apply(self.spg, key, self.scatter_idx, self.gather_idx) diff --git a/deepspeed/utils/groups.py b/deepspeed/utils/groups.py index f20fd6bc319d..3ebe8cd75522 100644 --- a/deepspeed/utils/groups.py +++ b/deepspeed/utils/groups.py @@ -110,7 +110,7 @@ def _create_model_parallel(model_parallel_size_): return _DATA_PARALLEL_GROUP, _MODEL_PARALLEL_GROUP -def _create_expert_and_data_parallel(expert_parallel_size_): +def _create_expert_and_data_parallel(expert_parallel_size_, use_data_before_expert_parallel_=False): """ Create expert and data parallel groups. @@ -122,6 +122,7 @@ def _create_expert_and_data_parallel(expert_parallel_size_): expert_data_parallel_group = [0,2,4,6,8,10,12,14], [1,3,5,7,9,11,13,15] - all reduce is only on MoE params expert_parallel_group = [0, 1], [2,3], [4,5], [6,7], [8,9] - no all reduce, but all to all data_parallel_group = [0,1,...,15] - all reduce is only on non-MoE + use_data_before_expert_parallel_ (bool): Use the D + E instead of E + D topology """ assert dist.is_initialized() @@ -136,29 +137,49 @@ def _create_expert_and_data_parallel(expert_parallel_size_): # Build the expert data parallel groups. global _EXPERT_DATA_PARALLEL_GROUP + ep_stride = world_size // expert_parallel_size_ + # Only create group if it does not already exist if group_name not in _EXPERT_DATA_PARALLEL_GROUP: for i in range(expert_parallel_size_): - ranks = range(i, world_size, expert_parallel_size_) + if use_data_before_expert_parallel_: + ranks = range(i * ep_stride, (i + 1) * ep_stride) + else: + ranks = range(i, world_size, expert_parallel_size_) group = dist.new_group(ranks) log_dist(f'Creating expert data parallel process group named {group_name} with ranks: {list(ranks)}', [0]) - if i == (rank % expert_parallel_size_): - _EXPERT_DATA_PARALLEL_GROUP[group_name] = group + if use_data_before_expert_parallel_: + if i == (rank // ep_stride): + _EXPERT_DATA_PARALLEL_GROUP[group_name] = group + else: + if i == (rank % expert_parallel_size_): + _EXPERT_DATA_PARALLEL_GROUP[group_name] = group # Build the expert parallel groups. global _EXPERT_PARALLEL_GROUP # Only create group if it does not already exist if group_name not in _EXPERT_PARALLEL_GROUP: - for i in range(world_size // expert_parallel_size_): - ranks = range(i * expert_parallel_size_, (i + 1) * expert_parallel_size_) - group = dist.new_group(ranks) - log_dist(f'creating expert parallel process group named {group_name} with ranks: {list(ranks)}', [0]) - if i == (rank // expert_parallel_size_): - _EXPERT_PARALLEL_GROUP[group_name] = group - - -def _get_expert_parallel_ranks(world_size, model_parallel_size_, expert_parallel_size_): + if use_data_before_expert_parallel_: + for i in range(ep_stride): + ranks = range(i, world_size, ep_stride) + group = dist.new_group(ranks) + log_dist(f'creating expert parallel process group named {group_name} with ranks: {list(ranks)}', [0]) + if i == (rank % ep_stride): + _EXPERT_PARALLEL_GROUP[group_name] = group + else: + for i in range(world_size // expert_parallel_size_): + ranks = range(i * expert_parallel_size_, (i + 1) * expert_parallel_size_) + group = dist.new_group(ranks) + log_dist(f'creating expert parallel process group named {group_name} with ranks: {list(ranks)}', [0]) + if i == (rank // expert_parallel_size_): + _EXPERT_PARALLEL_GROUP[group_name] = group + + +def _get_expert_parallel_ranks(world_size, + model_parallel_size_, + expert_parallel_size_, + use_data_before_expert_parallel_=False): """Generate expert parallel and expert data parallel group ranks list. Example - E + M + D parallel @@ -174,7 +195,7 @@ def _get_expert_parallel_ranks(world_size, model_parallel_size_, expert_parallel world_size (int): Distributed world size. model_parallel_size_ (int): Model parallel group size. expert_parallel_size_ (int): Expert parallel group size. - + use_data_before_expert_parallel_ (bool): Use the D + E instead of E + D topology Returns: Expert parallel group ranks and Expert data parallel group ranks list. """ @@ -185,8 +206,19 @@ def _get_expert_parallel_ranks(world_size, model_parallel_size_, expert_parallel # Generate data parallel groups data_parallel_groups = [] dp_group_size = model_parallel_size_ - for i in range(dp_group_size): - data_parallel_groups.append(list(range(i, world_size, dp_group_size))) + + if use_data_before_expert_parallel_: + dp_stride = world_size // expert_parallel_size_ // model_parallel_size_ + for i in range(dp_group_size): + data_parallel_groups.append(list()) + for ds in range(dp_stride): + # [0, 4, 8, 12, 16, 20, 24, 28, 2, 6, 10, 14, 18, 22, 26, 30] + # [1, 5, 9, 13, 17, 21, 25, 29, 3, 7, 11, 15, 19, 23, 27, 31] + data_parallel_groups[-1].extend( + list(range(i + ds * model_parallel_size_, world_size, dp_stride * model_parallel_size_))) + else: + for i in range(dp_group_size): + data_parallel_groups.append(list(range(i, world_size, dp_group_size))) expert_parallel_groups = [] expert_data_parallel_groups = [] @@ -204,7 +236,7 @@ def _get_expert_parallel_ranks(world_size, model_parallel_size_, expert_parallel return expert_parallel_groups, expert_data_parallel_groups -def _create_expert_data_and_model_parallel(expert_parallel_size_, mpu): +def _create_expert_data_and_model_parallel(expert_parallel_size_, mpu, use_data_before_expert_parallel_=False): """ Create expert and data parallel groups based on MPU (model parallel) group. @@ -249,7 +281,7 @@ def _create_expert_data_and_model_parallel(expert_parallel_size_, mpu): # Need to check conditions outside the group creation loop because of the way torch.dist group creation works if group_name not in _EXPERT_DATA_PARALLEL_GROUP and group_name not in _EXPERT_PARALLEL_GROUP: expert_parallel_groups, expert_data_parallel_groups = _get_expert_parallel_ranks( - world_size, model_parallel_size_, expert_parallel_size_) + world_size, model_parallel_size_, expert_parallel_size_, use_data_before_expert_parallel_) for ranks in expert_parallel_groups: group = dist.new_group(ranks) if rank in list(ranks):