diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index be153b4b4948..08ab05d79b6a 100644
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -2407,18 +2407,22 @@ def _reduce_non_expert_gradients(self, grads, elements_per_buffer):
         split_sparse_tensor_buckets, split_dense_tensor_buckets = split_half_float_double_sparse(grads)
         if self.pipeline_parallelism:
             dp_group = self.mpu.get_data_parallel_group()
+            dp_world_size = dist.get_world_size(dp_group)
         else:
             dp_group = groups._get_sequence_data_parallel_group()
-
+            dp_world_size = dist.get_world_size(dp_group) / float(self.sequence_parallel_size)
         for _, sparse_bucket_tuple in enumerate(split_sparse_tensor_buckets):
             if sparse_bucket_tuple:
                 bucket_type, sparse_bucket = sparse_bucket_tuple
-                self.sparse_allreduce_no_retain(sparse_bucket, dp_group=dp_group)
+                self.sparse_allreduce_no_retain(sparse_bucket, dp_group=dp_group, dp_world_size=dp_world_size)
 
         for _, dense_bucket_tuple in enumerate(split_dense_tensor_buckets):
             if dense_bucket_tuple:
                 bucket_type, dense_bucket = dense_bucket_tuple
-                self.allreduce_no_retain(dense_bucket, dp_group=dp_group, numel_per_bucket=elements_per_buffer)
+                self.allreduce_no_retain(dense_bucket,
+                                         dp_group=dp_group,
+                                         numel_per_bucket=elements_per_buffer,
+                                         dp_world_size=dp_world_size)
 
     def _reduce_expert_gradients(self, expert_grads, elements_per_buffer):
         # to maintain the gradients value unaffected by ep_size setting,
@@ -2490,9 +2494,9 @@ def sparse_allreduce(self, sparse, dp_group, dp_world_size=None):
             dp_world_size = dist.get_world_size(group=dp_group)
         if self.postscale_gradients():
             if self.gradient_average:
-                values.mul_(self.gradient_predivide_factor() / (dp_world_size / float(self.sequence_parallel_size)))
+                values.mul_(self.gradient_predivide_factor() / (dp_world_size))
             else:
-                values.mul_(1. / (dp_world_size / float(self.sequence_parallel_size)))
+                values.mul_(1. / (dp_world_size))
 
         indices_device_list = self.sparse_all_gather(indices, dp_group)
         values_device_list = self.sparse_all_gather(values, dp_group)
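For reference, a minimal sketch of the scaling rule this patch centralizes. It uses a hypothetical free-standing helper (`effective_dp_world_size` is not a DeepSpeed API), assuming the semantics visible in the diff: under sequence parallelism the all-reduce runs over the combined sequence-data-parallel group, and the sequence-parallel ranks hold genuinely additive partial gradients, so only the remaining data-parallel factor should be averaged out.

```python
# Hypothetical helper, for illustration only; not part of DeepSpeed.
def effective_dp_world_size(group_size: int, sequence_parallel_size: int) -> float:
    """Divisor applied to summed gradients after the all-reduce.

    group_size: number of ranks in the (sequence-)data-parallel process group.
    sequence_parallel_size: sequence-parallel degree (1 when disabled).
    """
    # Sequence-parallel ranks each hold a partial gradient for the same
    # samples; their sum is the true gradient, so we divide only by the
    # data-parallel replication factor.
    return group_size / float(sequence_parallel_size)


# Example: 8 ranks in the combined group with sequence-parallel degree 2.
# The all-reduce sums 8 partial gradients, but the 2 sequence chunks per
# replica are additive contributions, so the average is taken over 8 / 2 = 4.
assert effective_dp_world_size(8, 2) == 4.0
assert effective_dp_world_size(8, 1) == 8.0  # plain data parallelism
```

Computing this divisor once in `_reduce_non_expert_gradients` and threading it through the existing `dp_world_size` argument keeps `sparse_allreduce` agnostic of sequence parallelism, which is why the `/ float(self.sequence_parallel_size)` terms can be dropped from its scaling expressions.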