diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index 9c06567ed100..a5c0c3340019 100644
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -546,15 +546,10 @@ def _setup_for_real_optimizer(self):
             self.grad_partitions_flat_buffer = get_accelerator().pin_memory(self.grad_partitions_flat_buffer)
 
         offset = 0
-        max_partition_numel = 0
         for param in all_params:
             self.__param_id_to_grad_partition[param.ds_id] = self.grad_partitions_flat_buffer.narrow(
                 0, offset, param.partition_numel())
             offset += param.partition_numel()
-            max_partition_numel = max(max_partition_numel, param.partition_numel())
-        if self.offload_optimizer:
-            self.pinned_grad_buffer: Tensor = get_accelerator().pin_memory(
-                torch.empty(max_partition_numel, device=self.device))
 
     def _link_all_hp_params(self):
         for p in self.module.parameters():
@@ -1510,13 +1505,9 @@ def partition_grads(self, params_to_release: List[Parameter], grad_partitions: L
                         offload_fp32_gradients[i].append(grad_buffer.float())
                         offload_fp32_offsets[i].append(dest_offset)
                     else:
-                        buffer_numel = grad_buffer.numel()
                         fp32_grad_tensor = self.fp32_partitioned_groups_flat[i].grad.narrow(
-                            0, dest_offset, buffer_numel)
-                        self.pinned_grad_buffer[:buffer_numel].copy_(
-                            grad_buffer.to(dtype=torch.float32, non_blocking=True))
-                        get_accelerator().synchronize()
-                        fp32_grad_tensor.copy_(self.pinned_grad_buffer[:buffer_numel], non_blocking=True)
+                            0, dest_offset, grad_buffer.numel())
+                        fp32_grad_tensor.copy_(grad_buffer.float())
 
             # free the gradient
             if not get_accelerator().is_synchronized_device():
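
For reference, a minimal standalone PyTorch sketch (not the DeepSpeed code itself) contrasting the two gradient-offload copy paths touched by this diff: the removed staged copy through a pre-allocated pinned host buffer with an explicit device synchronize, and the restored direct `copy_(grad_buffer.float())` into the CPU-resident fp32 gradient slice. The helper names, tensor sizes, and the use of `torch.cuda.synchronize()` in place of `get_accelerator().synchronize()` are illustrative assumptions, not part of the diff.

```python
import torch


def copy_grad_direct(fp32_grad_slice: torch.Tensor, grad_buffer: torch.Tensor) -> None:
    # Restored path: materialize an fp32 copy of the low-precision gradient
    # partition on the device and copy it straight into the CPU fp32 slice.
    fp32_grad_slice.copy_(grad_buffer.float())


def copy_grad_via_pinned_staging(fp32_grad_slice: torch.Tensor, grad_buffer: torch.Tensor,
                                 pinned_buffer: torch.Tensor) -> None:
    # Removed path: stage the fp32 gradient in a pre-allocated pinned host
    # buffer, synchronize the device, then copy into the fp32 slice.
    n = grad_buffer.numel()
    pinned_buffer[:n].copy_(grad_buffer.to(dtype=torch.float32, non_blocking=True))
    torch.cuda.synchronize()  # stands in for get_accelerator().synchronize()
    fp32_grad_slice.copy_(pinned_buffer[:n], non_blocking=True)


if __name__ == "__main__" and torch.cuda.is_available():
    grad = torch.randn(1024, dtype=torch.float16, device="cuda")   # device gradient partition
    fp32_slice = torch.zeros(1024, dtype=torch.float32)            # CPU-offloaded fp32 grad slice
    staging = torch.empty(1024, dtype=torch.float32).pin_memory()  # pinned host staging buffer
    copy_grad_via_pinned_staging(fp32_slice, grad, staging)
    copy_grad_direct(fp32_slice, grad)
    print(torch.allclose(fp32_slice, grad.float().cpu()))
```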