From 29c6bd0c95b7e82a9d4c97ed1a3d393cd0b378d4 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Fri, 31 Jan 2025 21:41:16 +0000 Subject: [PATCH 1/5] add autocast support and ds_config item --- deepspeed/runtime/config.py | 20 ++++++++++++++++++++ deepspeed/runtime/constants.py | 17 +++++++++++++++++ deepspeed/runtime/engine.py | 16 +++++++++++++++- 3 files changed, 52 insertions(+), 1 deletion(-) diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index fb786f29722d..7549e1bd8c87 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -156,6 +156,24 @@ def get_amp_params(param_dict): return False +def get_torch_autocast_enabled(param_dict): + if TORCH_AUTOCAST in param_dict.keys(): + return get_scalar_param(param_dict[TORCH_AUTOCAST], TORCH_AUTOCAST_ENABLED, TORCH_AUTOCAST_ENABLED_DEFAULT) + else: + return False + + +def get_torch_autocast_dtype(param_dict): + if TORCH_AUTOCAST in param_dict: + if TORCH_AUTOCAST_DTYPE in param_dict[TORCH_AUTOCAST]: + try: + return DtypeEnum(param_dict[TORCH_AUTOCAST][TORCH_AUTOCAST_DTYPE]).value + except KeyError: + raise ValueError( + f"Invalid dtype for torch autocast: {param_dict[TORCH_AUTOCAST][TORCH_AUTOCAST_DTYPE]}") + return None + + def get_fp16_enabled(param_dict): if FP16 in param_dict.keys(): return get_scalar_param(param_dict[FP16], FP16_ENABLED, FP16_ENABLED_DEFAULT) @@ -835,6 +853,8 @@ def _initialize_params(self, param_dict): self.fp16_master_weights_and_gradients = get_fp16_master_weights_and_grads_enabled(param_dict) self.amp_enabled = get_amp_enabled(param_dict) self.amp_params = get_amp_params(param_dict) + self.torch_autocast_enabled = get_torch_autocast_enabled(param_dict) + self.torch_autocast_dtype = get_torch_autocast_dtype(param_dict) self.loss_scale = get_loss_scale(param_dict) self.initial_dynamic_scale = get_initial_dynamic_scale(param_dict) self.dynamic_loss_scale_args = get_dynamic_loss_scale_args(param_dict) diff --git a/deepspeed/runtime/constants.py b/deepspeed/runtime/constants.py index 55cfa8f59c91..811aa4f912bc 100755 --- a/deepspeed/runtime/constants.py +++ b/deepspeed/runtime/constants.py @@ -202,6 +202,23 @@ AMP_ENABLED = "enabled" AMP_ENABLED_DEFAULT = False +######################################### +# Torch AMP support +######################################### +TORCH_AUTOCAST_FORMAT = ''' +PyTorch autocast config should be of the format: +"torch_autocast": { + "enabled": true, + "dtype": "bfloat16", +} +''' +TORCH_AUTOCAST = "torch_autocast" + +TORCH_AUTOCAST_ENABLED = "enabled" +TORCH_AUTOCAST_ENABLED_DEFAULT = False +TORCH_AUTOCAST_DTYPE = "dtype" +TORCH_AUTOCAST_DTYPE_DEFAULT = None + ######################################### # Gradient clipping ######################################### diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 97d2afb8b723..900f74cb6909 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -374,6 +374,11 @@ def __init__(self, self._is_compiled = False + # Verify autocast setting + if self.torch_autocast_enabled(): + assert not self.fp16_enabled(), "Cannot enable both torch autocast and fp16" + assert not self.bfloat16_enabled(), "Cannot enable both torch autocast and bfloat16" + def _optimized_linear_offload_setup(self): self.optimized_linear_base_weight_sharding = False self.optimized_linear_lora_enabled = False @@ -850,6 +855,12 @@ def amp_enabled(self): def amp_params(self): return self._config.amp_params + def torch_autocast_enabled(self): + return self._config.torch_autocast_enabled + + 
def torch_autocast_dtype(self): + return self._config.torch_autocast_dtype + def fp16_auto_cast(self): return self._config.fp16_auto_cast @@ -1909,7 +1920,10 @@ def forward(self, *inputs, **kwargs): if self.fp16_auto_cast(): inputs = self._cast_inputs_half(inputs) - loss = self.module(*inputs, **kwargs) + with torch.autocast(device_type=get_accelerator().device_name(), + dtype=self.torch_autocast_dtype(), + enabled=self.torch_autocast_enabled()): + loss = self.module(*inputs, **kwargs) if self.zero_optimization_partition_weights(): # Disable automated discovery of external parameters From c5400cbca67a5b19f71b15653542e48f8bc14c41 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Sat, 1 Feb 2025 01:15:21 +0000 Subject: [PATCH 2/5] prepare ipg buckets for multiple dtypes --- deepspeed/runtime/torch_autocast.py | 43 +++++ deepspeed/runtime/zero/stage_1_and_2.py | 203 ++++++++++++++---------- 2 files changed, 158 insertions(+), 88 deletions(-) create mode 100644 deepspeed/runtime/torch_autocast.py diff --git a/deepspeed/runtime/torch_autocast.py b/deepspeed/runtime/torch_autocast.py new file mode 100644 index 000000000000..720980fcd32d --- /dev/null +++ b/deepspeed/runtime/torch_autocast.py @@ -0,0 +1,43 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Iterable, Set + +import torch + + +LOWER_PRECISION_SAFE_MODULES = [ + "torch.nn.Linear", + "torch.nn.Conv1d", + "torch.nn.Conv2d", + "torch.nn.Conv3d", +] + + +SUPPORTED_DTYPES = { + torch.float16, + torch.bfloat16, + torch.float32 +} + + +def validate_auto_cast_settings(engine): + # Verify autocast setting + if engine.torch_autocast_enabled(): + assert not engine.fp16_enabled(), "Cannot enable both torch autocast and fp16" + assert not engine.bfloat16_enabled(), "Cannot enable both torch autocast and bfloat16" + + assert all(p.dtype == torch.float32 for p in engine.parameters()), "All parameters must be float32 for torch autocast" + + +def init_autocast_params(model: torch.nn.Module, dtype: torch.dtype) -> None: + for module in model.modules(): + if module.__class__.__name__ in LOWER_PRECISION_SAFE_MODULES: + for p in module.parameters(recurse=False): + p.autocast_dtype = dtype + + +def get_autocast_dtypes(params: Iterable) -> Set[torch.dtype]: + return {p.autocast_dtype if hasattr(p, "autocast_dtype") else p.dtype for p in params} diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 2bece09bffc4..d521463e3a34 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -6,11 +6,15 @@ import torch from deepspeed import comm as dist from packaging import version as pkg_version -from collections import OrderedDict +from collections import OrderedDict, defaultdict +from dataclasses import dataclass, field +from typing import List, Dict + from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from deepspeed.runtime.base_optimizer import ZeROOptimizer from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler +from deepspeed.runtime.torch_autocast import get_autocast_dtypes from deepspeed.runtime.utils import (empty_cache, see_memory_usage, inf, is_model_parallel_parameter, align_dense_tensors, all_gather_dp_groups) from deepspeed.runtime.zero.config import ZeroStageEnum @@ -94,6 +98,24 @@ def _pad_tensor_by_size(src_tensor, pad_size, dtype, device): return padded_tensor +@dataclass +class IPGBucket: + buffer: List[torch.Tensor] = field(default_factory=list) + 
params: List[torch.Tensor] = field(default_factory=list) + grads: List[torch.Tensor] = field(default_factory=list) + elements: int = 0 + index: int = 0 + has_moe_params: bool = False + + def clear(self): + self.buffer.clear() + self.params.clear() + self.grads.clear() + self.elements = 0 + self.index = 0 + self.has_moe_params = False + + class DeepSpeedZeroOptimizer(ZeROOptimizer): """ DeepSpeedZeroOptimizer designed to reduce the memory footprint @@ -229,7 +251,7 @@ def __init__(self, self.ignore_unused_parameters = ignore_unused_parameters self.round_robin_gradients = round_robin_gradients - self.extra_large_param_to_reduce = None + self.extra_large_param_to_reduce: Dict[int, torch.Tensor] = {} self.fp16_master_weights_and_gradients = fp16_master_weights_and_gradients if self.fp16_master_weights_and_gradients: @@ -437,13 +459,14 @@ def __init__(self, # map between param_id and bool to specify if a param is in this partition self.is_param_in_current_partition = {} - self.grads_in_ipg_bucket = [] - self.params_in_ipg_bucket = [] - self.elements_in_ipg_bucket = 0 + self.ipg_buckets: Dict[torch.dtype, IPGBucket] = { + dtype: IPGBucket() + for dtype in get_autocast_dtypes([p for params in self.bit16_groups for p in params]) + } + self.params_already_reduced = [] self._release_ipg_buffers() - self.previous_reduced_grads = None - self.ipg_bucket_has_moe_params = False + self.previous_reduced_grads: Dict[int, List[torch.Tensor]] = defaultdict(list) # simplified param id self.param_id = {} @@ -675,7 +698,9 @@ def _round_robin_reorder(self, tensor_list, num_partitions): def _release_ipg_buffers(self): if self.contiguous_gradients: - self.ipg_buffer = None + for bucket in self.ipg_buckets.values(): + bucket.clear() + self.grads_in_partition = None self.grads_in_partition_offset = 0 @@ -709,12 +734,12 @@ def reduce_gradients(self, pipeline_parallel=False): # with PP we must create ipg buffer, since backward is handled outside zero if pipeline_parallel and self.contiguous_gradients: - self.ipg_buffer = [] - buf_0 = torch.empty(int(self.reduce_bucket_size), - dtype=self.dtype, - device=get_accelerator().current_device_name()) - self.ipg_buffer.append(buf_0) - self.ipg_index = 0 + for dtype, bucket in self.ipg_buckets.items(): + bucket.buffer.append( + torch.empty(int(self.reduce_bucket_size), + dtype=dtype, + device=get_accelerator().current_device_name())) + bucket.index = 0 if not self.overlap_comm: for i, group in enumerate(self.bit16_groups): @@ -923,12 +948,16 @@ def get_param_id(self, param): unique_id = id(param) return self.param_id[unique_id] - def report_ipg_memory_usage(self, tag, param_elems): - elem_count = self.elements_in_ipg_bucket + param_elems - percent_of_bucket_size = (100.0 * elem_count) // self.reduce_bucket_size - see_memory_usage( - f"{tag}: elems in_bucket {self.elements_in_ipg_bucket} param {param_elems} max_percent {percent_of_bucket_size}" - ) + def report_ipg_memory_usage(self, tag, param_elems, dtype=None): + dtypes = self.ipg_buckets.keys() if dtype is None else [dtype] + + for dt in dtypes: + bucket = self.ipg_buckets[dt] + elem_count = bucket.elements + param_elems + percent_of_bucket_size = (100.0 * elem_count) // self.reduce_bucket_size + see_memory_usage( + f"{tag}: elems in_bucket {dt} {bucket.elements} param {param_elems} max_percent {percent_of_bucket_size}" + ) # create a flat tensor aligned at the alignment boundary def flatten_dense_tensors_aligned(self, tensor_list, alignment, use_cpu_data=False): @@ -939,13 +968,14 @@ def flatten_dense_tensors_aligned(self, 
tensor_list, alignment, use_cpu_data=Fal def reduce_independent_p_g_buckets_and_remove_grads(self, param, i): grad_reduc = self.get_gradient_for_reduction(param) - if self.elements_in_ipg_bucket + param.numel() > self.reduce_bucket_size: - self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads", param.numel()) + bucket = self.ipg_buckets[param.dtype] + if bucket.elements + param.numel() > self.reduce_bucket_size: + self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads", param.numel(), param.dtype) self.reduce_ipg_grads() if self.contiguous_gradients and self.overlap_comm: # Swap ipg_index between 0 and 1 - self.ipg_index = 1 - self.ipg_index - self.report_ipg_memory_usage("In ipg_remove_grads after reduce_ipg_grads", param.numel()) + bucket.index = 1 - bucket.index + self.report_ipg_memory_usage("In ipg_remove_grads after reduce_ipg_grads", param.numel(), param.dtype) param_id = self.get_param_id(param) assert self.params_already_reduced[param_id] == False, \ @@ -955,23 +985,23 @@ def reduce_independent_p_g_buckets_and_remove_grads(self, param, i): if self.contiguous_gradients: if param.numel() > self.reduce_bucket_size: - self.extra_large_param_to_reduce = param + self.extra_large_param_to_reduce[param.dtype] = param else: # keeping the gradients contiguous to prevent memory fragmentation, and avoid flattening - new_grad_tensor = self.ipg_buffer[self.ipg_index].narrow(0, self.elements_in_ipg_bucket, param.numel()) + new_grad_tensor = bucket.buffer[bucket.index].narrow(0, bucket.elements, param.numel()) new_grad_tensor.copy_(grad_reduc.view(-1)) grad_reduc.data = new_grad_tensor.data.view_as(grad_reduc) - self.elements_in_ipg_bucket += param.numel() + bucket.elements += param.numel() assert grad_reduc is not None, f"rank {dist.get_rank()} - Invalid to reduce Param {param_id} with None gradient" - self.grads_in_ipg_bucket.append(grad_reduc) - self.params_in_ipg_bucket.append((i, param.param_idx_in_group, param_id)) + bucket.grads.append(grad_reduc) + bucket.params.append((i, param.param_idx_in_group, param_id)) #make sure the average tensor function knows how to average the gradients if is_moe_param(param): - self.ipg_bucket_has_moe_params = True + bucket.has_moe_params = True self.report_ipg_memory_usage("End ipg_remove_grads", 0) @@ -1073,12 +1103,13 @@ def average_tensor(self, tensor): process_group = self.dp_process_group # count = 0 - for i, param_idx_in_group, param_id in self.params_in_ipg_bucket: + bucket = self.ipg_buckets[tensor.dtype] + for i, param_idx_in_group, param_id in bucket.params: param = self.bit16_groups[i][param_idx_in_group] process_group = self.dp_process_group - if self.ipg_bucket_has_moe_params: + if bucket.has_moe_params: process_group = self.expert_dp_process_group[param.group_name] if is_moe_param( param) else self.dp_process_group @@ -1363,21 +1394,21 @@ def copy_grads_in_partition(self, param): self.grads_in_partition_offset += param.numel() def reduce_ipg_grads(self): - if self.contiguous_gradients: - if self.extra_large_param_to_reduce is not None: - assert len(self.params_in_ipg_bucket) == 1, "more than 1 param in ipg bucket, this shouldn't happen" - _, _, param_id = self.params_in_ipg_bucket[0] - assert self.get_param_id(self.extra_large_param_to_reduce - ) == param_id, "param in ipg bucket does not match extra-large param" - extra_large_grad_reduc = self.get_gradient_for_reduction(self.extra_large_param_to_reduce) - self.average_tensor(extra_large_grad_reduc.view(-1)) - self.extra_large_param_to_reduce = None + for 
dtype, bucket in self.ipg_buckets.items(): + + if self.contiguous_gradients: + if dtype in self.extra_large_param_to_reduce: + assert len(bucket.params) == 1, "more than 1 param in ipg bucket, this shouldn't happen" + _, _, param_id = self.params[0] + assert self.get_param_id(self.extra_large_param_to_reduce + ) == param_id, "param in ipg bucket does not match extra-large param" + extra_large_grad_reduc = self.get_gradient_for_reduction(self.extra_large_param_to_reduce) + self.average_tensor(extra_large_grad_reduc.view(-1)) + del self.extra_large_param_to_reduce[dtype] + else: + self.average_tensor(bucket.buffer[bucket.ipg_index].narrow(0, 0, bucket.elements)) else: - self.average_tensor(self.ipg_buffer[self.ipg_index].narrow(0, 0, self.elements_in_ipg_bucket)) - else: - self.buffered_reduce_fallback(None, - self.grads_in_ipg_bucket, - elements_per_buffer=self.elements_in_ipg_bucket) + self.buffered_reduce_fallback(None, bucket.grads, elements_per_buffer=bucket.elements) if self.overlap_comm: stream = self.reduction_stream @@ -1390,35 +1421,30 @@ def reduce_ipg_grads(self): stream = get_accelerator().current_stream() with get_accelerator().stream(stream): - for group_idx, param_idx_in_group, param_id in self.params_in_ipg_bucket: - param = self.bit16_groups[group_idx][param_idx_in_group] - - assert self.params_already_reduced[param_id] == False, \ - f"The parameter {param_id} has already been reduced. \ - Gradient computed twice for this partition. \ - Multiple gradient reduction is currently not supported" - - self.params_already_reduced[param_id] = True - if self.partition_gradients: - if not self.is_param_in_current_partition[param_id]: - if self.overlap_comm and self.contiguous_gradients is False: - # Clear grads of other partitions during the next reduction - # to avoid clearing them before the reduction is complete. - if self.previous_reduced_grads is None: - self.previous_reduced_grads = [] - self.previous_reduced_grads.append(param) - else: - self.clear_grad_attribute(param) - elif self.contiguous_gradients: - self.copy_grads_in_partition(param) - else: # zero stage 1 - partition only optimizer state - if self.contiguous_gradients and self.is_param_in_current_partition[param_id]: - self.copy_grads_in_partition(param) - - self.grads_in_ipg_bucket = [] - self.params_in_ipg_bucket = [] - self.ipg_bucket_has_moe_params = False - self.elements_in_ipg_bucket = 0 + for dtype, bucket in self.ipg_buckets.items(): + for group_idx, param_idx_in_group, param_id in bucket.params: + param = self.bit16_groups[group_idx][param_idx_in_group] + + assert self.params_already_reduced[param_id] == False, \ + f"The parameter {param_id} has already been reduced. \ + Gradient computed twice for this partition. \ + Multiple gradient reduction is currently not supported" + + self.params_already_reduced[param_id] = True + if self.partition_gradients: + if not self.is_param_in_current_partition[param_id]: + if self.overlap_comm and self.contiguous_gradients is False: + # Clear grads of other partitions during the next reduction + # to avoid clearing them before the reduction is complete. 
+ self.previous_reduced_grads[dtype].append(param) + else: + self.clear_grad_attribute(param) + elif self.contiguous_gradients: + self.copy_grads_in_partition(param) + else: # zero stage 1 - partition only optimizer state + if self.contiguous_gradients and self.is_param_in_current_partition[param_id]: + self.copy_grads_in_partition(param) + bucket.clear() ##################################################################### def reduce_ready_partitions_and_remove_grads(self, param, i): @@ -1519,10 +1545,10 @@ def allreduce_bucket(self, bucket, rank=None, log=None, divide=True, process_gro return tensor def _clear_previous_reduced_grads(self): - if self.previous_reduced_grads is not None: - for param in self.previous_reduced_grads: + for dtype in self.previous_reduced_grads: + for param in self.previous_reduced_grads[dtype]: self.clear_grad_attribute(param) - self.previous_reduced_grads = None + self.previous_reduced_grads[dtype].clear() # if rank is specified do a reduction instead of an allreduce def allreduce_and_copy(self, small_bucket, rank=None, log=None, divide=True, process_group=None): @@ -2046,19 +2072,20 @@ def backward(self, loss, retain_graph=False): self.micro_step_id += 1 if self.contiguous_gradients: - self.ipg_buffer = [] - buf_0 = torch.empty(int(self.reduce_bucket_size), - dtype=self.dtype, - device=get_accelerator().current_device_name()) - self.ipg_buffer.append(buf_0) + for dtype, bucket in self.ipg_buckets.items(): + buf_0 = torch.empty(int(self.reduce_bucket_size), + dtype=dtype, + device=get_accelerator().current_device_name()) + bucket.buffer.append(buf_0) + bucket.ipg_index = 0 # Use double buffers to avoid data access conflict when overlap_comm is enabled. if self.overlap_comm: - buf_1 = torch.empty(int(self.reduce_bucket_size), - dtype=self.dtype, - device=get_accelerator().current_device_name()) - self.ipg_buffer.append(buf_1) - self.ipg_index = 0 + for dtype, bucket in self.ipg_buckets.items(): + buf_1 = torch.empty(int(self.reduce_bucket_size), + dtype=dtype, + device=get_accelerator().current_device_name()) + bucket.buffer.append(buf_1) if self.custom_loss_scaler: scaled_loss = self.external_loss_scale * loss From 96573b631bf6c63955d7b769fd2fd4f7d04611a6 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Sun, 2 Feb 2025 02:32:54 +0000 Subject: [PATCH 3/5] switch communication data type --- deepspeed/runtime/engine.py | 9 +- deepspeed/runtime/torch_autocast.py | 50 ++++++---- deepspeed/runtime/zero/stage_1_and_2.py | 117 +++++++++++++++++------- 3 files changed, 116 insertions(+), 60 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 900f74cb6909..a9d159673010 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -90,6 +90,7 @@ from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint +from deepspeed.runtime.torch_autocast import init_autocast_params from .pipe.module import PipelineModule from .utils import get_ma_status @@ -311,6 +312,9 @@ def __init__(self, if not isinstance(model_parameters, list): model_parameters = list(model_parameters) + if self.torch_autocast_enabled(): + init_autocast_params(self, self.torch_autocast_dtype()) + if has_optimizer: self._configure_optimizer(optimizer, model_parameters) self._configure_lr_scheduler() @@ -374,11 +378,6 @@ def __init__(self, self._is_compiled = False - # Verify autocast setting - if self.torch_autocast_enabled(): - assert not 
self.fp16_enabled(), "Cannot enable both torch autocast and fp16" - assert not self.bfloat16_enabled(), "Cannot enable both torch autocast and bfloat16" - def _optimized_linear_offload_setup(self): self.optimized_linear_base_weight_sharding = False self.optimized_linear_lora_enabled = False diff --git a/deepspeed/runtime/torch_autocast.py b/deepspeed/runtime/torch_autocast.py index 720980fcd32d..cd9f8a7dca39 100644 --- a/deepspeed/runtime/torch_autocast.py +++ b/deepspeed/runtime/torch_autocast.py @@ -7,37 +7,47 @@ import torch - LOWER_PRECISION_SAFE_MODULES = [ - "torch.nn.Linear", - "torch.nn.Conv1d", - "torch.nn.Conv2d", - "torch.nn.Conv3d", + torch.nn.Linear, + torch.nn.Conv1d, + torch.nn.Conv2d, + torch.nn.Conv3d, ] +TORCH_AUTOCAST_INITIALIZED = False + -SUPPORTED_DTYPES = { - torch.float16, - torch.bfloat16, - torch.float32 -} +def _validate_auto_cast_settings(engine): + assert not engine.fp16_enabled(), "Cannot enable both torch autocast and fp16" + assert not engine.bfloat16_enabled(), "Cannot enable both torch autocast and bfloat16" -def validate_auto_cast_settings(engine): - # Verify autocast setting - if engine.torch_autocast_enabled(): - assert not engine.fp16_enabled(), "Cannot enable both torch autocast and fp16" - assert not engine.bfloat16_enabled(), "Cannot enable both torch autocast and bfloat16" + assert all(p.dtype == torch.float32 + for p in engine.parameters()), "All parameters must be float32 for torch autocast" + assert engine.communication_data_type == torch.float32, "Communication data type must be float32 for torch autocast" - assert all(p.dtype == torch.float32 for p in engine.parameters()), "All parameters must be float32 for torch autocast" +def init_autocast_params(engine, dtype: torch.dtype) -> None: + + _validate_auto_cast_settings(engine) + model = engine.module -def init_autocast_params(model: torch.nn.Module, dtype: torch.dtype) -> None: for module in model.modules(): - if module.__class__.__name__ in LOWER_PRECISION_SAFE_MODULES: + if module.__class__ in LOWER_PRECISION_SAFE_MODULES: for p in module.parameters(recurse=False): p.autocast_dtype = dtype + global TORCH_AUTOCAST_INITIALIZED + TORCH_AUTOCAST_INITIALIZED = True + + +def is_autocast_initialized() -> bool: + return TORCH_AUTOCAST_INITIALIZED + + +def get_autocast_dtype(param: torch.nn.Parameter) -> torch.dtype: + return param.autocast_dtype if hasattr(param, "autocast_dtype") else param.dtype + -def get_autocast_dtypes(params: Iterable) -> Set[torch.dtype]: - return {p.autocast_dtype if hasattr(p, "autocast_dtype") else p.dtype for p in params} +def get_all_autocast_dtypes(params: Iterable) -> Set[torch.dtype]: + return {get_autocast_dtype(p) for p in params} diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index d521463e3a34..4108e7808b1d 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -14,7 +14,7 @@ from deepspeed.runtime.base_optimizer import ZeROOptimizer from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler -from deepspeed.runtime.torch_autocast import get_autocast_dtypes +from deepspeed.runtime.torch_autocast import get_autocast_dtype, get_all_autocast_dtypes, is_autocast_initialized from deepspeed.runtime.utils import (empty_cache, see_memory_usage, inf, is_model_parallel_parameter, align_dense_tensors, all_gather_dp_groups) from deepspeed.runtime.zero.config import ZeroStageEnum @@ -459,10 +459,9 @@ def __init__(self, # map between param_id and bool to specify if a param is in 
this partition self.is_param_in_current_partition = {} - self.ipg_buckets: Dict[torch.dtype, IPGBucket] = { - dtype: IPGBucket() - for dtype in get_autocast_dtypes([p for params in self.bit16_groups for p in params]) - } + comm_dtypes = get_all_autocast_dtypes([p for params in self.bit16_groups for p in params + ]) if is_autocast_initialized() else {self.communication_data_type} + self.ipg_buckets: Dict[torch.dtype, IPGBucket] = {dtype: IPGBucket() for dtype in comm_dtypes} self.params_already_reduced = [] self._release_ipg_buffers() @@ -968,7 +967,7 @@ def flatten_dense_tensors_aligned(self, tensor_list, alignment, use_cpu_data=Fal def reduce_independent_p_g_buckets_and_remove_grads(self, param, i): grad_reduc = self.get_gradient_for_reduction(param) - bucket = self.ipg_buckets[param.dtype] + bucket = self.ipg_buckets[get_autocast_dtype(param)] if bucket.elements + param.numel() > self.reduce_bucket_size: self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads", param.numel(), param.dtype) self.reduce_ipg_grads() @@ -1009,7 +1008,7 @@ def print_rank_0(self, message): if dist.get_rank() == 0: logger.info(message) - def gradient_reduction_w_predivide(self, tensor): + def gradient_reduction_w_predivide(self, tensor, communication_data_type: torch.dtype): if tensor.size().numel() == 0: return tensor @@ -1017,8 +1016,8 @@ def gradient_reduction_w_predivide(self, tensor): tensor_to_allreduce = tensor - if self.communication_data_type != tensor.dtype: - tensor_to_allreduce = tensor.to(self.communication_data_type) + if communication_data_type != tensor.dtype: + tensor_to_allreduce = tensor.to(communication_data_type) if self.postscale_gradients: if self.gradient_predivide_factor != 1.0: @@ -1033,24 +1032,35 @@ def gradient_reduction_w_predivide(self, tensor): tensor_to_allreduce.div_(dp_world_size / float(self.sequence_parallel_size)) dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group) - if self.communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce: + if communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce: tensor.copy_(tensor_to_allreduce) return tensor def allreduce_and_copy_with_multiple_ranks(self, small_bucket, + communication_data_type: torch.dtype, log=None, divide=True, process_group=None, bucket_ranks=None): process_group = self.dp_process_group if process_group is None else process_group - allreduced = self.allreduce_bucket(small_bucket, log=log, divide=divide, process_group=process_group) + allreduced = self.allreduce_bucket(small_bucket, + communication_data_type, + log=log, + divide=divide, + process_group=process_group) for buf, synced, bucket_rank in zip(small_bucket, self.unflatten(allreduced, small_bucket), bucket_ranks): if dist.get_rank(group=process_group) == bucket_rank: buf.copy_(synced) - def allreduce_and_scatter(self, bucket, numel_per_bucket=500000000, log=None, divide=True, process_group=None): + def allreduce_and_scatter(self, + bucket, + communication_data_type: torch.dtype, + numel_per_bucket=500000000, + log=None, + divide=True, + process_group=None): small_bucket = [] small_bucket_ranks = [] numel = 0 @@ -1063,6 +1073,7 @@ def allreduce_and_scatter(self, bucket, numel_per_bucket=500000000, log=None, di numel = numel + tensor.numel() if numel > numel_per_bucket: self.allreduce_and_copy_with_multiple_ranks(small_bucket, + communication_data_type, log=None, divide=divide, process_group=process_group, @@ -1073,12 +1084,13 @@ def allreduce_and_scatter(self, bucket, 
numel_per_bucket=500000000, log=None, di if len(small_bucket) > 0: self.allreduce_and_copy_with_multiple_ranks(small_bucket, + communication_data_type, log=None, divide=divide, process_group=process_group, bucket_ranks=small_bucket_ranks) - def average_tensor(self, tensor): + def average_tensor(self, tensor: torch.Tensor, communication_data_type: torch.dtype): if self.overlap_comm: stream = self.reduction_stream if not get_accelerator().resolves_data_dependency(): @@ -1089,7 +1101,7 @@ def average_tensor(self, tensor): with get_accelerator().stream(stream): if not self.reduce_scatter: - self.gradient_reduction_w_predivide(tensor) + self.gradient_reduction_w_predivide(tensor, communication_data_type) return # Accumulate destination ranks and bucket offsets for each gradient slice. @@ -1103,7 +1115,7 @@ def average_tensor(self, tensor): process_group = self.dp_process_group # count = 0 - bucket = self.ipg_buckets[tensor.dtype] + bucket = self.ipg_buckets[communication_data_type] for i, param_idx_in_group, param_id in bucket.params: param = self.bit16_groups[i][param_idx_in_group] @@ -1167,12 +1179,14 @@ def average_tensor(self, tensor): for bucket_key in buckets: if self.use_multi_rank_bucket_allreduce: self.allreduce_and_scatter(buckets[bucket_key], + communication_data_type, numel_per_bucket=self.reduce_bucket_size, divide=False, process_group=bucket_key) else: dst, process_group = bucket_key self.allreduce_no_retain(buckets[bucket_key], + communication_data_type, numel_per_bucket=self.reduce_bucket_size, rank=dst, divide=False, @@ -1394,21 +1408,21 @@ def copy_grads_in_partition(self, param): self.grads_in_partition_offset += param.numel() def reduce_ipg_grads(self): - for dtype, bucket in self.ipg_buckets.items(): + for comm_dtype, bucket in self.ipg_buckets.items(): if self.contiguous_gradients: - if dtype in self.extra_large_param_to_reduce: + if comm_dtype in self.extra_large_param_to_reduce: assert len(bucket.params) == 1, "more than 1 param in ipg bucket, this shouldn't happen" _, _, param_id = self.params[0] assert self.get_param_id(self.extra_large_param_to_reduce ) == param_id, "param in ipg bucket does not match extra-large param" extra_large_grad_reduc = self.get_gradient_for_reduction(self.extra_large_param_to_reduce) - self.average_tensor(extra_large_grad_reduc.view(-1)) - del self.extra_large_param_to_reduce[dtype] + self.average_tensor(extra_large_grad_reduc.view(-1), comm_dtype) + del self.extra_large_param_to_reduce[comm_dtype] else: - self.average_tensor(bucket.buffer[bucket.ipg_index].narrow(0, 0, bucket.elements)) + self.average_tensor(bucket.buffer[bucket.ipg_index].narrow(0, 0, bucket.elements), comm_dtype) else: - self.buffered_reduce_fallback(None, bucket.grads, elements_per_buffer=bucket.elements) + self.buffered_reduce_fallback(None, bucket.grads, comm_dtype, elements_per_buffer=bucket.elements) if self.overlap_comm: stream = self.reduction_stream @@ -1421,7 +1435,7 @@ def reduce_ipg_grads(self): stream = get_accelerator().current_stream() with get_accelerator().stream(stream): - for dtype, bucket in self.ipg_buckets.items(): + for comm_dtype, bucket in self.ipg_buckets.items(): for group_idx, param_idx_in_group, param_id in bucket.params: param = self.bit16_groups[group_idx][param_idx_in_group] @@ -1436,7 +1450,7 @@ def reduce_ipg_grads(self): if self.overlap_comm and self.contiguous_gradients is False: # Clear grads of other partitions during the next reduction # to avoid clearing them before the reduction is complete. 
- self.previous_reduced_grads[dtype].append(param) + self.previous_reduced_grads[comm_dtype].append(param) else: self.clear_grad_attribute(param) elif self.contiguous_gradients: @@ -1513,7 +1527,14 @@ def set_none_gradients_to_zero(self, i, partition_id): param.grad = torch.zeros_like(param) ######################Reduction Related Methods############################## - def allreduce_bucket(self, bucket, rank=None, log=None, divide=True, process_group=None): + def allreduce_bucket(self, + bucket, + communication_data_type: torch.dtype, + rank=None, + log=None, + divide=True, + process_group=None): + tensor = self.flatten(bucket) process_group = self.dp_process_group if process_group is None else process_group @@ -1522,8 +1543,6 @@ def allreduce_bucket(self, bucket, rank=None, log=None, divide=True, process_gro if pg_correctness_test or self.sequence_parallel_size > 1: communication_data_type = torch.float32 - else: - communication_data_type = self.communication_data_type if communication_data_type != tensor.dtype: tensor_to_allreduce = tensor.to(communication_data_type) @@ -1551,7 +1570,13 @@ def _clear_previous_reduced_grads(self): self.previous_reduced_grads[dtype].clear() # if rank is specified do a reduction instead of an allreduce - def allreduce_and_copy(self, small_bucket, rank=None, log=None, divide=True, process_group=None): + def allreduce_and_copy(self, + small_bucket, + communication_data_type: torch.dtype, + rank=None, + log=None, + divide=True, + process_group=None): process_group = self.dp_process_group if process_group is None else process_group if self.overlap_comm: if not get_accelerator().resolves_data_dependency(): @@ -1564,6 +1589,7 @@ def allreduce_and_copy(self, small_bucket, rank=None, log=None, divide=True, pro with get_accelerator().stream(stream): allreduced = self.allreduce_bucket( + communication_data_type, small_bucket, rank=rank, log=log, @@ -1577,6 +1603,7 @@ def allreduce_and_copy(self, small_bucket, rank=None, log=None, divide=True, pro def allreduce_no_retain( self, bucket, + communication_data_type: torch.dtype, numel_per_bucket=500000000, rank=None, log=None, @@ -1589,20 +1616,39 @@ def allreduce_no_retain( small_bucket.append(tensor) numel = numel + tensor.numel() if numel > numel_per_bucket: - self.allreduce_and_copy(small_bucket, rank=rank, log=None, divide=divide, process_group=process_group) + self.allreduce_and_copy(small_bucket, + communication_data_type, + rank=rank, + log=None, + divide=divide, + process_group=process_group) small_bucket = [] numel = 0 if len(small_bucket) > 0: - self.allreduce_and_copy(small_bucket, rank=rank, log=log, divide=divide, process_group=process_group) + self.allreduce_and_copy(small_bucket, + communication_data_type, + rank=rank, + log=log, + divide=divide, + process_group=process_group) # allows using reduction of gradients instead of using all_reduce - def buffered_reduce_fallback(self, rank, grads, elements_per_buffer=500000000, log=None): + def buffered_reduce_fallback(self, + rank, + grads, + communication_data_type: torch.dtype, + elements_per_buffer=500000000, + log=None): split_buckets = split_half_float_double(grads) for i, bucket in enumerate(split_buckets): - self.allreduce_no_retain(bucket, numel_per_bucket=elements_per_buffer, rank=rank, log=log) + self.allreduce_no_retain(bucket, + communication_data_type, + numel_per_bucket=elements_per_buffer, + rank=rank, + log=log) ############################################################################# 
############################################################################# @@ -2072,18 +2118,19 @@ def backward(self, loss, retain_graph=False): self.micro_step_id += 1 if self.contiguous_gradients: - for dtype, bucket in self.ipg_buckets.items(): + for _, bucket in self.ipg_buckets.items(): + # Buffer's dtype is the same as the dtype of optimizer, not dtype for autocast buf_0 = torch.empty(int(self.reduce_bucket_size), - dtype=dtype, + dtype=self.dtype, device=get_accelerator().current_device_name()) bucket.buffer.append(buf_0) bucket.ipg_index = 0 # Use double buffers to avoid data access conflict when overlap_comm is enabled. if self.overlap_comm: - for dtype, bucket in self.ipg_buckets.items(): + for _, bucket in self.ipg_buckets.items(): buf_1 = torch.empty(int(self.reduce_bucket_size), - dtype=dtype, + dtype=self.dtype, device=get_accelerator().current_device_name()) bucket.buffer.append(buf_1) From 0e52edb6a3cc940a9d8c53f7343f3473ebd69bbb Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Mon, 3 Feb 2025 07:08:15 +0000 Subject: [PATCH 4/5] add gradscaler --- deepspeed/runtime/zero/stage_1_and_2.py | 17 ++- tests/unit/common.py | 54 ++++++++ tests/unit/runtime/compile/util.py | 56 +-------- tests/unit/runtime/zero/test_zero_autocast.py | 118 ++++++++++++++++++ 4 files changed, 187 insertions(+), 58 deletions(-) create mode 100644 tests/unit/runtime/zero/test_zero_autocast.py diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 4108e7808b1d..00421dc23548 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -459,8 +459,13 @@ def __init__(self, # map between param_id and bool to specify if a param is in this partition self.is_param_in_current_partition = {} - comm_dtypes = get_all_autocast_dtypes([p for params in self.bit16_groups for p in params - ]) if is_autocast_initialized() else {self.communication_data_type} + if is_autocast_initialized(): + comm_dtypes = get_all_autocast_dtypes([p for params in self.bit16_groups for p in params]) + self.torch_autocast_gradscaler = torch.amp.GradScaler(device=get_accelerator().device_name()) + else: + comm_dtypes = {self.communication_data_type} + self.torch_autocast_gradscaler = None + self.ipg_buckets: Dict[torch.dtype, IPGBucket] = {dtype: IPGBucket() for dtype in comm_dtypes} self.params_already_reduced = [] @@ -1889,7 +1894,11 @@ def _optimizer_step(self, group_no): # self.optimizer.step(fp16_param_groups=[self.get_bit16_param_group(group_no)]) #else: # self.optimizer.step() - self.optimizer.step() + if self.torch_autocast_gradscaler: + self.torch_autocast_gradscaler.step(self.optimizer) + self.torch_autocast_gradscaler.update() + else: + self.optimizer.step() self.optimizer.param_groups = original_param_groups # We need to link optimizer state after the first step() call @@ -2137,6 +2146,8 @@ def backward(self, loss, retain_graph=False): if self.custom_loss_scaler: scaled_loss = self.external_loss_scale * loss scaled_loss.backward() + elif self.torch_autocast_gradscaler: + self.torch_autocast_gradscaler.scale(loss).backward(retain_graph=retain_graph) else: self.loss_scaler.backward(loss.float(), retain_graph=retain_graph) diff --git a/tests/unit/common.py b/tests/unit/common.py index 1498b0400ee1..9b3ad2708b59 100644 --- a/tests/unit/common.py +++ b/tests/unit/common.py @@ -11,6 +11,9 @@ import subprocess from abc import ABC, abstractmethod from pathlib import Path +import random +import numpy as np +from typing import Callable, Any import 
torch import torch.multiprocessing as mp @@ -505,3 +508,54 @@ def preferred_dtype(): return torch.bfloat16 else: return torch.float32 + + +class EnableDeterminism: + + def __init__(self, seed: int): + local_rank = int(os.getenv("LOCAL_RANK", "0")) + + self.seed = seed + local_rank + self.saved_random_state = None + self.saved_np_random_state = None + self.saved_cuda_launch_blocking = None + self.saved_cublas_workspace_config = None + self.saved_deterministic_algorithms = None + + def __enter__(self): + self.saved_random_state = random.getstate() + self.saved_np_random_state = np.random.get_state() + self.saved_acc_rng_state = get_accelerator().get_rng_state() + self.saved_cuda_launch_blocking = os.environ.get("CUDA_LAUNCH_BLOCKING", "") + self.saved_cublas_workspace_config = os.environ.get("CUBLAS_WORKSPACE_CONFIG", "") + self.saved_deterministic_algorithms = torch.are_deterministic_algorithms_enabled() + + random.seed(self.seed) + np.random.seed(self.seed) + get_accelerator().manual_seed(self.seed) + get_accelerator().manual_seed_all(self.seed) + + os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" + torch.use_deterministic_algorithms(True) + + def __exit__(self, type, value, traceback): + random.setstate(self.saved_random_state) + np.random.set_state(self.saved_np_random_state) + get_accelerator().set_rng_state(self.saved_acc_rng_state) + os.environ["CUDA_LAUNCH_BLOCKING"] = self.saved_cuda_launch_blocking + os.environ["CUBLAS_WORKSPACE_CONFIG"] = self.saved_cublas_workspace_config + torch.use_deterministic_algorithms(self.saved_deterministic_algorithms) + + +def enable_determinism(seed: int): + + def decorator(func: Callable) -> Callable: + + def wrapper(*args: Any, **kwargs: Any): + with EnableDeterminism(seed): + return func(*args, **kwargs) + + return wrapper + + return decorator diff --git a/tests/unit/runtime/compile/util.py b/tests/unit/runtime/compile/util.py index d53886a81429..48a711b23e53 100644 --- a/tests/unit/runtime/compile/util.py +++ b/tests/unit/runtime/compile/util.py @@ -3,9 +3,6 @@ # DeepSpeed Team -import random -import os -import numpy as np from copy import deepcopy import torch @@ -15,58 +12,7 @@ from deepspeed.runtime.zero import GatheredParameters from unit.simple_model import SimpleModel -from typing import Callable, Any - - -class EnableDeterminism: - - def __init__(self, seed: int): - local_rank = int(os.getenv("LOCAL_RANK", "0")) - - self.seed = seed + local_rank - self.saved_random_state = None - self.saved_np_random_state = None - self.saved_cuda_launch_blocking = None - self.saved_cublas_workspace_config = None - self.saved_deterministic_algorithms = None - - def __enter__(self): - self.saved_random_state = random.getstate() - self.saved_np_random_state = np.random.get_state() - self.saved_acc_rng_state = get_accelerator().get_rng_state() - self.saved_cuda_launch_blocking = os.environ.get("CUDA_LAUNCH_BLOCKING", "") - self.saved_cublas_workspace_config = os.environ.get("CUBLAS_WORKSPACE_CONFIG", "") - self.saved_deterministic_algorithms = torch.are_deterministic_algorithms_enabled() - - random.seed(self.seed) - np.random.seed(self.seed) - get_accelerator().manual_seed(self.seed) - get_accelerator().manual_seed_all(self.seed) - - os.environ["CUDA_LAUNCH_BLOCKING"] = "1" - os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" - torch.use_deterministic_algorithms(True) - - def __exit__(self, type, value, traceback): - random.setstate(self.saved_random_state) - np.random.set_state(self.saved_np_random_state) - 
get_accelerator().set_rng_state(self.saved_acc_rng_state) - os.environ["CUDA_LAUNCH_BLOCKING"] = self.saved_cuda_launch_blocking - os.environ["CUBLAS_WORKSPACE_CONFIG"] = self.saved_cublas_workspace_config - torch.use_deterministic_algorithms(self.saved_deterministic_algorithms) - - -def enable_determinism(seed: int): - - def decorator(func: Callable) -> Callable: - - def wrapper(*args: Any, **kwargs: Any): - with EnableDeterminism(seed): - return func(*args, **kwargs) - - return wrapper - - return decorator +from unit.common import enable_determinism @enable_determinism(123) diff --git a/tests/unit/runtime/zero/test_zero_autocast.py b/tests/unit/runtime/zero/test_zero_autocast.py new file mode 100644 index 000000000000..f08db93333be --- /dev/null +++ b/tests/unit/runtime/zero/test_zero_autocast.py @@ -0,0 +1,118 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import random +import os +import numpy as np +from typing import Callable, Any +from copy import deepcopy + +import pytest + +import torch +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.cuda.amp import autocast, GradScaler + +from unit.common import DistributedTest, preferred_dtype, enable_determinism +from unit.simple_model import SimpleModel, random_dataloader +from unit.util import bf16_required_version_check + +import deepspeed +from deepspeed.accelerator import get_accelerator +from deepspeed.runtime.zero import GatheredParameters +from deepspeed.git_version_info import torch_info +from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum + + +RTOL = 0.01 +ATOL = 0.0 + + +def step_amp(baseline_model, + baseline_optimizer, + target_engine, + dtype, + baseline_scaler, + x, y, + rtol, atol): + # Runs the forward pass with autocasting. 
+ with torch.autocast(device_type="cuda", dtype=dtype): + baseline_optimizer.zero_grad() + baseline_loss = baseline_model(x, y) + + baseline_scaler.scale(baseline_loss).backward() + baseline_scaler.step(baseline_optimizer) + baseline_scaler.update() + + target_loss = target_engine(x, y) + + assert torch.allclose(baseline_loss.float(), target_loss.float(), rtol=rtol, atol=atol) + + target_engine.backward(target_loss) + target_engine.step() + + +@enable_determinism(123) +def compare_loss(zero_stage, dtype): + iteration = 5 + hidden_dim = 10 + lr = 0.001 + + if dtype == torch.bfloat16 and not bf16_required_version_check(): + raise ValueError("DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly") + + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "zero_optimization": { + "stage": zero_stage, + }, + "torch_autocast": { + "enabled": True, + "dtype": str(dtype) + } + } + + model_cls = SimpleModel + model = model_cls(hidden_dim) + + deepspeed.init_distributed(dist_backend='nccl') + + i = get_accelerator().current_device() + device = get_accelerator().current_device_name() + baseline_model = DDP(deepcopy(model).to(device=device, dtype=torch.float32), device_ids=[i], output_device=i) + baseline_optimizer = torch.optim.AdamW(baseline_model.parameters(), lr=lr, weight_decay=0.0) + baseline_scaler = torch.amp.GradScaler() + + stage_3_enabled = config_dict["zero_optimization"]["stage"] == 3 + if stage_3_enabled: + with deepspeed.zero.Init(config_dict_or_path=config_dict): + target_model = model_cls(hidden_dim) + with GatheredParameters(target_model.parameters(), modifier_rank=0): + for p1, p2 in zip(target_model.parameters(), model.parameters()): + p1.data.copy_(p2.data) + else: + target_model = deepcopy(model) + + ds_optimizer = torch.optim.Adam(target_model.parameters(), lr=lr) + target_engine, _, _, _ = deepspeed.initialize(config=config_dict, + model=target_model, + optimizer=ds_optimizer) + train_batch_size = config_dict["train_micro_batch_size_per_gpu"] + + xs = [torch.randn(train_batch_size, hidden_dim, device=device, dtype=torch.float32) for _ in range(iteration)] + ys = [torch.randn_like(x) for x in xs] + + for i, (x, y) in enumerate(zip(xs, ys)): + step_amp(baseline_model, baseline_optimizer, target_engine, dtype, baseline_scaler, x, y, RTOL, ATOL) + + +@pytest.mark.parametrize("zero_stage", [1, 2]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +class TestZeroAutoCast(DistributedTest): + world_size = 2 + + def test(self, zero_stage, dtype): + compare_loss(zero_stage, dtype) \ No newline at end of file From b294eae3948edadfaa637f99b15ce6cf1d33a2a9 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Mon, 3 Feb 2025 07:31:27 +0000 Subject: [PATCH 5/5] fix import and formatting --- tests/unit/runtime/zero/test_zero_autocast.py | 30 +++++-------------- 1 file changed, 8 insertions(+), 22 deletions(-) diff --git a/tests/unit/runtime/zero/test_zero_autocast.py b/tests/unit/runtime/zero/test_zero_autocast.py index f08db93333be..31aba0247cd8 100644 --- a/tests/unit/runtime/zero/test_zero_autocast.py +++ b/tests/unit/runtime/zero/test_zero_autocast.py @@ -3,40 +3,26 @@ # DeepSpeed Team -import random -import os -import numpy as np -from typing import Callable, Any from copy import deepcopy import pytest import torch from torch.nn.parallel import DistributedDataParallel as DDP -from torch.cuda.amp import autocast, GradScaler -from unit.common import DistributedTest, preferred_dtype, 
enable_determinism -from unit.simple_model import SimpleModel, random_dataloader +from unit.common import DistributedTest, enable_determinism +from unit.simple_model import SimpleModel from unit.util import bf16_required_version_check import deepspeed from deepspeed.accelerator import get_accelerator from deepspeed.runtime.zero import GatheredParameters -from deepspeed.git_version_info import torch_info -from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum - RTOL = 0.01 ATOL = 0.0 -def step_amp(baseline_model, - baseline_optimizer, - target_engine, - dtype, - baseline_scaler, - x, y, - rtol, atol): +def step_amp(baseline_model, baseline_optimizer, target_engine, dtype, baseline_scaler, x, y, rtol, atol): # Runs the forward pass with autocasting. with torch.autocast(device_type="cuda", dtype=dtype): baseline_optimizer.zero_grad() @@ -61,7 +47,9 @@ def compare_loss(zero_stage, dtype): lr = 0.001 if dtype == torch.bfloat16 and not bf16_required_version_check(): - raise ValueError("DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly") + raise ValueError( + "DeepSpeed BFloat16 tests need torch >= 1.10, NCCL >= 2.10.3, CUDA > =11.0 and HW support for BFloat16 to run correctly" + ) config_dict = { "train_micro_batch_size_per_gpu": 1, @@ -97,9 +85,7 @@ def compare_loss(zero_stage, dtype): target_model = deepcopy(model) ds_optimizer = torch.optim.Adam(target_model.parameters(), lr=lr) - target_engine, _, _, _ = deepspeed.initialize(config=config_dict, - model=target_model, - optimizer=ds_optimizer) + target_engine, _, _, _ = deepspeed.initialize(config=config_dict, model=target_model, optimizer=ds_optimizer) train_batch_size = config_dict["train_micro_batch_size_per_gpu"] xs = [torch.randn(train_batch_size, hidden_dim, device=device, dtype=torch.float32) for _ in range(iteration)] @@ -115,4 +101,4 @@ class TestZeroAutoCast(DistributedTest): world_size = 2 def test(self, zero_stage, dtype): - compare_loss(zero_stage, dtype) \ No newline at end of file + compare_loss(zero_stage, dtype)
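
End-to-end, the series wires the new "torch_autocast" config section (patch 1) through the engine forward pass (patch 1), per-dtype IPG buckets and communication dtypes in ZeRO stage 1/2 (patches 2-3), and an internal GradScaler plus unit tests (patches 4-5). A rough user-side sketch of the intended flow follows; the model, optimizer, and batch shapes are illustrative only (they are not taken from the patches), and a normal distributed launch (e.g. the deepspeed launcher) is assumed. The config shape mirrors TORCH_AUTOCAST_FORMAT and the new unit test.

    import torch
    import deepspeed

    # Illustrative model/optimizer; parameters must stay float32 when torch_autocast is enabled.
    model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU(), torch.nn.Linear(16, 1))
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # fp16/bf16 modes must stay off; autocast is configured through its own section.
    ds_config = {
        "train_micro_batch_size_per_gpu": 1,
        "zero_optimization": {"stage": 2},
        "torch_autocast": {"enabled": True, "dtype": "bfloat16"},
    }

    engine, _, _, _ = deepspeed.initialize(config=ds_config, model=model, optimizer=optimizer)

    x = torch.randn(1, 16, device=engine.device, dtype=torch.float32)
    y = torch.randn(1, 1, device=engine.device, dtype=torch.float32)

    # The module forward runs inside torch.autocast (engine.forward wraps it), so Linear/Conv
    # layers compute in bfloat16 while other ops stay in float32; backward()/step() drive the
    # internal torch.amp.GradScaler added for ZeRO stage 1/2 in patch 4.
    loss = torch.nn.functional.mse_loss(engine(x).float(), y)
    engine.backward(loss)
    engine.step()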
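The per-dtype bucketing in patches 2-3 hinges on tagging parameters of autocast-safe modules with an autocast_dtype attribute and then collecting the distinct dtypes to decide how many IPG buckets (and which communication dtypes) ZeRO stage 1/2 needs. A standalone toy illustration of just that helper behaviour, based on the patch-3 version of torch_autocast.py (no engine involved; in the real flow init_autocast_params() does the tagging):

    import torch
    from deepspeed.runtime.torch_autocast import LOWER_PRECISION_SAFE_MODULES, get_all_autocast_dtypes

    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))

    # Mirrors init_autocast_params(): only modules listed in LOWER_PRECISION_SAFE_MODULES
    # (Linear, Conv1d/2d/3d) get their parameters tagged with the autocast dtype.
    for module in model.modules():
        if module.__class__ in LOWER_PRECISION_SAFE_MODULES:
            for p in module.parameters(recurse=False):
                p.autocast_dtype = torch.bfloat16

    # Linear parameters report bfloat16, LayerNorm parameters fall back to their own dtype
    # (float32), so the ZeRO 1/2 optimizer creates one IPG bucket per entry of this set.
    print(get_all_autocast_dtypes(model.parameters()))  # {torch.bfloat16, torch.float32}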