Skip to content

Commit

Permalink
fixes in _partition_param_sec function (#5613)
Browse files Browse the repository at this point in the history
There are few fixes:
- When param.ds_secondary_tensor is not None and the param has not been
updated we don't need to update the param.ds_secondary_tensor.
- In HPU the 2nd tensor partition will always be completed before the
all-gather, so we don't need to add synchronize().
  • Loading branch information
mmhab authored Jun 11, 2024
1 parent a41729f commit b6e24ad
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion deepspeed/runtime/zero/partition_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -1664,6 +1664,8 @@ def _partition_param_sec(self, param, buffer=None, has_been_updated=False):
##support for NVME secondary param offload
#print_rank_0(f"SEC Param id {param.ds_id} status is {param.ds_status}", force=True)
if param.ds_status is ZeroParamStatus.AVAILABLE:
if param.ds_secondary_tensor is not None and not has_been_updated: ##param already partitioned
return
#check padding
tensor_size = self._aligned_size(param)
partition_size = tensor_size // self.dp_world_size
Expand Down Expand Up @@ -1702,7 +1704,8 @@ def _partition_param_sec(self, param, buffer=None, has_been_updated=False):
sec_numel).copy_(one_dim_param.narrow(0, secondary_start, sec_numel))

# TODO: This is a temporary fix to avoid the issue that 2nd tensor all-gather happens before 2nd tensor partition is done
get_accelerator().current_stream().synchronize()
if not get_accelerator().resolves_data_dependency():
get_accelerator().current_stream().synchronize()

print_rank_0(f"{param.ds_id} partitioned type {param.dtype} dev {param.device} shape {param.shape}",
force=False)
Expand Down

0 comments on commit b6e24ad

Please sign in to comment.