diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index c3d76b29e9db6..d3a0559b1b69b 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -577,14 +577,14 @@ def _create_param_mapping(self): return param_mapping def _link_all_hp_params(self): - dp_world_size = dist.get_world_size(group=self.dp_process_group) if self.cpu_offload: self._get_offload_gradient_dict() for i, _ in enumerate(self.optimizer.param_groups): # Link bit16 and fp32 params in partition partition_id = dist.get_rank(group=self.real_dp_process_group[i]) - partition_size = self.bit16_groups_flat[i].numel() // dp_world_size + partition_size = self.bit16_groups_flat[i].numel() // dist.get_world_size( + group=self.real_dp_process_group[i]) flat_hp_partition = self.single_partition_of_fp32_groups[i] link_hp_params(lp_param_list=self.bit16_groups[i], flat_hp_partition=flat_hp_partition,