
Commit

cast dtype for allgather
tohtana committed Feb 28, 2025
1 parent feecd71 commit 22e5ba9
Showing 2 changed files with 51 additions and 16 deletions.
1 change: 1 addition & 0 deletions deepspeed/runtime/torch_autocast.py
@@ -21,6 +21,7 @@ def _validate_auto_cast_settings(engine):

assert not engine.fp16_enabled(), "Cannot enable both torch autocast and fp16"
assert not engine.bfloat16_enabled(), "Cannot enable both torch autocast and bfloat16"
assert not engine.zero_quantized_weights(), "Cannot enable both torch autocast and zero quantized weights"

assert all(p.dtype == torch.float32
for p in engine.parameters()), "All parameters must be float32 for torch autocast"
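
The new assertion extends the rule that torch autocast only combines with fp32 master parameters. As a reminder of why that assumption holds, here is a minimal torch-only illustration (not DeepSpeed code): under `torch.autocast`, parameters keep their fp32 master copy while compute runs in the autocast dtype, which is what makes gathering ZeRO-3 shards directly in that dtype safe.

```python
import torch

# Minimal illustration: master weights stay fp32, matmuls run in the autocast dtype.
linear = torch.nn.Linear(8, 8)   # fp32 parameters
x = torch.randn(4, 8)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = linear(x)

print(linear.weight.dtype)  # torch.float32 -- the fp32 master copy is untouched
print(y.dtype)              # torch.bfloat16 -- compute happens in the autocast dtype
```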
66 changes: 50 additions & 16 deletions deepspeed/runtime/zero/partition_parameters.py
@@ -631,17 +631,26 @@ def restore_init_context():

class AllGatherHandle:

def __init__(self, handle, param: Parameter, quantization=None) -> None:
def __init__(self, handle, param: Parameter, param_buffer=None, quantization=None) -> None:
if param.ds_status != ZeroParamStatus.INFLIGHT:
raise RuntimeError(f"expected param {param.ds_summary()} to be available")

# Only one of param_buffer or quantization is provided
assert (param_buffer is None) != (quantization is None)

self.__handle = handle
self.__param = param
self.__param_buffer = param_buffer
self.__quantization = quantization

def wait(self, handle_dependency=True) -> None:
instrument_w_nvtx(self.__handle.wait)()
if self.__quantization:

if self.__param_buffer is not None:
param = self.__param
param.data = self.__param_buffer.narrow(0, 0, param.ds_numel).view(param.ds_shape).to(param.device).to(
param.dtype)
elif self.__quantization:
instrument_w_nvtx(self.__quantization.quant_handle.wait)()
self.__param.data = self.__quantization.backend.dequantize(
self.__quantization.quantized_param, self.__quantization.scale_buffer).to(self.__param.device)
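
A minimal single-process sketch of what the new `param_buffer` branch of `wait()` does: the all-gather output buffer may be in a lower-precision gather dtype, and the parameter data is recovered by narrowing, reshaping, and casting back to the parameter's own dtype. Tensor names and sizes here are illustrative, and the collective is simulated locally.

```python
import torch

# Simulate a ZeRO-3 parameter of shape (4, 6) sharded across world_size=2 ranks,
# gathered in bfloat16 (the autocast dtype) and cast back to the fp32 param dtype.
world_size = 2
ds_shape = (4, 6)
ds_numel = 4 * 6
full_fp32 = torch.randn(ds_numel)
shards = full_fp32.chunk(world_size)

# The gather buffer is allocated in the lower-precision gather dtype.
param_buffer = torch.cat([s.to(torch.bfloat16) for s in shards])

# Equivalent of: param.data = param_buffer.narrow(0, 0, param.ds_numel)
#                                          .view(param.ds_shape).to(param.device).to(param.dtype)
param_data = param_buffer.narrow(0, 0, ds_numel).view(ds_shape).to("cpu").to(torch.float32)
print(param_data.shape, param_data.dtype)  # torch.Size([4, 6]) torch.float32
```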
@@ -704,7 +713,8 @@ def wait(self, handle_dependency=True) -> None:
part_to_copy = self.partitions[rank].narrow(0, param_offset,
min(param.ds_numel - param_start, ds_tensor_numel))
partitions.append(part_to_copy)
param.data = instrument_w_nvtx(torch.cat)(partitions).view(param.ds_shape)
# Note that dtypes of param and partitions can be different (currently for torch.autocast support)
param.data = instrument_w_nvtx(torch.cat)(partitions).view(param.ds_shape).to(param.dtype)
param.ds_status = ZeroParamStatus.AVAILABLE
if not get_accelerator().is_synchronized_device() and handle_dependency:
for part_to_copy in partitions:
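
The other `wait()` path reassembles the parameter from per-rank partition slices; the added `.to(param.dtype)` covers the case where those slices were gathered in a lower precision. A tiny sketch under that assumption:

```python
import torch

# Per-rank partition slices gathered in bfloat16, reassembled and cast back to fp32.
partitions = [torch.randn(6, dtype=torch.bfloat16) for _ in range(2)]
param_data = torch.cat(partitions).view(3, 4).to(torch.float32)
print(param_data.dtype)  # torch.float32
```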
@@ -1161,7 +1171,14 @@ def all_gather(param_list=None, async_op=False, hierarchy=0):
param_list = [cls]
return self._all_gather(param_list, async_op=async_op, hierarchy=hierarchy)

def _all_gather_dtype(dtype, params, world_size, rank_in_group, ds_process_group):
def _all_gather_dtype(params, world_size, rank_in_group, ds_process_group, allgather_dtype=None):
# make sure all params have the same dtype
dtype = params[0].dtype # we assume len(params) > 0
assert all(p.dtype == dtype for p in params), "all params must have the same dtype"

if allgather_dtype is None:
allgather_dtype = dtype

partition_sz = sum(p.ds_tensor.ds_numel for p in params)

use_secondary_tensor = params[0].ds_secondary_tensor is not None
@@ -1170,7 +1187,7 @@ def _all_gather_dtype(dtype, params, world_size, rank_in_group, ds_process_group
partition_sz = sum(p.ds_tensor.ds_numel * p.ds_secondary_tensor_num_of_groups for p in params)

flat_tensor = torch.empty(partition_sz * world_size,
dtype=dtype,
dtype=allgather_dtype,
device=get_accelerator().current_device_name(),
requires_grad=False)

@@ -1179,12 +1196,15 @@ def _all_gather_dtype(dtype, params, world_size, rank_in_group, ds_process_group
partitions.append(flat_tensor.narrow(0, partition_sz * i, partition_sz))

if use_secondary_tensor:
instrument_w_nvtx(
torch.cat)([p.ds_secondary_tensor.to(get_accelerator().current_device_name()) for p in params],
out=partitions[rank_in_group])
else:
instrument_w_nvtx(torch.cat)([p.ds_tensor.to(get_accelerator().current_device_name()) for p in params],
instrument_w_nvtx(torch.cat)([
p.ds_secondary_tensor.to(get_accelerator().current_device_name()).to(allgather_dtype)
for p in params
],
out=partitions[rank_in_group])
else:
instrument_w_nvtx(torch.cat)(
[p.ds_tensor.to(get_accelerator().current_device_name()).to(allgather_dtype) for p in params],
out=partitions[rank_in_group])
handle = _dist_allgather_fn(partitions[rank_in_group], flat_tensor, ds_process_group)
#Fix get_partition_dp_group(params[0]))
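
The core of the `_all_gather_dtype` change is that local shards are cast to `allgather_dtype` before the collective, so the flat buffer (and the communication) use that dtype rather than the fp32 storage dtype. A single-process sketch, with the collective faked by copying the local partition into every rank's slot (the helper name is ours, not DeepSpeed's):

```python
import torch

def simulated_all_gather_dtype(ds_tensors, world_size, allgather_dtype):
    """Sketch: cast shards to allgather_dtype, then gather into a flat buffer of that dtype."""
    partition_sz = sum(t.numel() for t in ds_tensors)
    flat_tensor = torch.empty(partition_sz * world_size, dtype=allgather_dtype, requires_grad=False)
    partitions = [flat_tensor.narrow(0, partition_sz * i, partition_sz) for i in range(world_size)]

    local = torch.cat([t.to(allgather_dtype) for t in ds_tensors])  # cast happens before the gather
    for p in partitions:  # stand-in for the real all-gather collective
        p.copy_(local)
    return flat_tensor

shards = [torch.randn(5), torch.randn(3)]  # fp32 ds_tensor shards
out = simulated_all_gather_dtype(shards, world_size=2, allgather_dtype=torch.bfloat16)
print(out.dtype, out.numel())              # torch.bfloat16 16
```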

@@ -1251,20 +1271,28 @@ def all_gather_coalesced(params: Iterable[Parameter],
buffer_size = param.ds_secondary_tensor.shape[0] * world_size #make sure out is appropriately sized

param_ds_tensor = param.ds_secondary_tensor if use_secondary_tensor else param.ds_tensor

if quantize:
allgather_dtype = torch.int8
elif hasattr(param, "autocast_dtype"):
allgather_dtype = param.autocast_dtype
else:
allgather_dtype = param_ds_tensor.dtype

param_buffer = torch.empty(
buffer_size,
dtype=param_ds_tensor.dtype if not quantize else torch.int8,
dtype=allgather_dtype,
device=get_accelerator().current_device_name(),
requires_grad=False,
)
if not quantize:
# This allgather is async
handles = _dist_allgather_fn(
param_ds_tensor.to(get_accelerator().current_device_name()),
param_ds_tensor.to(get_accelerator().current_device_name()).to(allgather_dtype),
param_buffer,
ds_process_group,
)
param.data = param_buffer.narrow(0, 0, param.ds_numel).view(param.ds_shape).to(param.device)
return AllGatherHandle(handles, param)
return AllGatherHandle(handles, param, param_buffer=param_buffer)
else:
if hasattr(param_ds_tensor, "ds_quant_scale"):
scales = param_ds_tensor.ds_quant_scale
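
The dtype-selection precedence introduced in the single-parameter path can be summarized as a small helper. This is a sketch, not DeepSpeed API: `autocast_dtype` is the attribute the diff checks for, while the function name and the `_P` stand-in class are illustrative.

```python
import torch

def pick_allgather_dtype(param_ds_tensor, param, quantize: bool) -> torch.dtype:
    """int8 for quantized gathers, else the parameter's recorded autocast dtype,
    else the shard's own storage dtype."""
    if quantize:
        return torch.int8
    if hasattr(param, "autocast_dtype"):
        return param.autocast_dtype
    return param_ds_tensor.dtype

class _P:  # stand-in for a DeepSpeed parameter tagged for bf16 autocast
    autocast_dtype = torch.bfloat16

print(pick_allgather_dtype(torch.randn(4), _P(), quantize=False))  # torch.bfloat16
```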
@@ -1292,6 +1320,7 @@ def all_gather_coalesced(params: Iterable[Parameter],

else:
if self.use_all_reduce_for_fetch_params and not quantize and not use_secondary_tensor:

# Use all_reduce instead of all_gather to fetch the module params
flat_buffer_size = sum(p.ds_numel_aligned for p in params)
flat_tensor = torch.zeros(flat_buffer_size,
@@ -1313,11 +1342,16 @@ def all_gather_coalesced(params: Iterable[Parameter],
if not quantize:
dtype_params = defaultdict(list)
for p in params:
dtype_params[p.ds_tensor.dtype].append(p)
allgather_dtype = p.autocast_dtype if hasattr(p, "autocast_dtype") else p.ds_tensor.dtype
dtype_params[allgather_dtype].append(p)
handles = []
for dtype, params in dtype_params.items():
handles.append(
_all_gather_dtype(dtype, params, world_size, rank_in_group, ds_process_group))
_all_gather_dtype(params,
world_size,
rank_in_group,
ds_process_group,
allgather_dtype=dtype))

return MultipleAllGatherHandles(handles)
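
In the coalesced path, parameters are now bucketed by the dtype they will be gathered in (their `autocast_dtype` if present, otherwise the shard dtype), with one gather launched per bucket. A sketch of the grouping step, using an illustrative `_P` stand-in for a DeepSpeed parameter:

```python
import torch
from collections import defaultdict

class _P:  # stand-in: fp32 shard, optionally tagged with an autocast dtype
    def __init__(self, numel, autocast_dtype=None):
        self.ds_tensor = torch.randn(numel)
        if autocast_dtype is not None:
            self.autocast_dtype = autocast_dtype

params = [_P(4, torch.bfloat16), _P(2), _P(3, torch.bfloat16)]

dtype_params = defaultdict(list)
for p in params:
    gather_dtype = p.autocast_dtype if hasattr(p, "autocast_dtype") else p.ds_tensor.dtype
    dtype_params[gather_dtype].append(p)

# One all-gather would be launched per dtype bucket; here we just show the buckets.
for dtype, group in dtype_params.items():
    print(dtype, [p.ds_tensor.numel() for p in group])
# torch.bfloat16 [4, 3]
# torch.float32 [2]
```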

