diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py index a8d15cd5332b..fd1f421b8954 100755 --- a/deepspeed/__init__.py +++ b/deepspeed/__init__.py @@ -37,7 +37,7 @@ from .runtime.config import DeepSpeedConfig, DeepSpeedConfigError from .runtime.activation_checkpointing import checkpointing from .ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig -from .module_inject import replace_transformer_layer, revert_transformer_layer +from .module_inject import replace_transformer_layer, revert_transformer_layer, set_autotp_mode from .utils import log_dist, OnDevice, logger from .comm.comm import init_distributed @@ -364,3 +364,34 @@ def init_inference(model, config=None, **kwargs): engine = InferenceEngine(model, config=ds_inference_config) return engine + + +def tp_model_init(model, tp_size, dtype): + """ + Initialize the model for tensor parallelism. + + Args: + model (torch.nn.Module): The model to be initialized. + tp_size (int): The tensor parallelism size. + dtype (torch.dtype): The data type to be used for the model. + + Returns: + torch.nn.Module: The initialized model with tensor parallelism. + """ + # avoid re-entry + assert not hasattr( + model, 'ds_autotp_parsed'), "ds_autotp_parsed' attribute already exists in the model, re-entry is not allowed." + + set_autotp_mode(training=True) + + from deepspeed.runtime.tensor_parallel import TpTrainingManager + # The expected usage here is for it to be invoked by transformers package. + + #TODO: We should provide a custom TP mapping solution without using autoTP + #as modifying the autoTP logic may be more difficult for users compared to configuring it + + model = TpTrainingManager(model=model, tp_size=tp_size, dtype=dtype).module + + setattr(model, 'ds_autotp_parsed', True) + + return model diff --git a/deepspeed/comm/comm.py b/deepspeed/comm/comm.py index df8e8022081d..1aa9b135115b 100755 --- a/deepspeed/comm/comm.py +++ b/deepspeed/comm/comm.py @@ -224,6 +224,12 @@ def broadcast(tensor, src, group=None, async_op=False, prof=False, log_name='bro return cdb.broadcast(tensor=tensor, src=src, group=group, async_op=async_op) +@timed_op +def broadcast_object_list(object_list, src, group=None, device=None): + global cdb + return cdb.broadcast_object_list(object_list=object_list, src=src, group=group, device=device) + + @timed_op def all_gather(tensor_list, tensor, diff --git a/deepspeed/comm/torch.py b/deepspeed/comm/torch.py index 5461ae18d1f0..efa0640fb87b 100755 --- a/deepspeed/comm/torch.py +++ b/deepspeed/comm/torch.py @@ -205,6 +205,10 @@ def broadcast(self, tensor, src, group=None, async_op=False): else: return torch.distributed.broadcast(tensor=tensor, src=src, group=group, async_op=async_op) + @disable_compiler_collective + def broadcast_object_list(self, object_list, src, group=None, device=None): + return torch.distributed.broadcast_object_list(object_list=object_list, src=src, group=group, device=device) + @disable_compiler_collective def all_gather(self, tensor_list, tensor, group=None, async_op=False): if DS_COMM_ALL_GATHER_OFF: diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index 3089d0c557a4..a529d9343228 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -15,7 +15,6 @@ from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine from deepspeed.utils.timer import SynchronizedWallClockTimer from deepspeed.runtime.compiler import is_compile_supported - from ..runtime.state_dict_factory import SDLoaderFactory from 
..runtime.weight_quantizer import WeightQuantization from ..module_inject import replace_transformer_layer, generic_injection diff --git a/deepspeed/module_inject/__init__.py b/deepspeed/module_inject/__init__.py index 4bdabf383b26..9fc2f979a04b 100755 --- a/deepspeed/module_inject/__init__.py +++ b/deepspeed/module_inject/__init__.py @@ -6,5 +6,5 @@ from .replace_module import replace_transformer_layer, revert_transformer_layer, ReplaceWithTensorSlicing, GroupQuantizer, generic_injection from .module_quantize import quantize_transformer_layer from .replace_policy import HFBertLayerPolicy -from .layers import LinearAllreduce, LinearLayer, EmbeddingLayer, Normalize +from .layers import LinearAllreduce, LinearLayer, EmbeddingLayer, Normalize, set_autotp_mode from .policy import DSPolicy diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py index 8bdcf6faa053..d148c26968b3 100755 --- a/deepspeed/module_inject/auto_tp.py +++ b/deepspeed/module_inject/auto_tp.py @@ -11,10 +11,12 @@ from typing import Optional import torch from deepspeed import comm as dist -from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce +from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce, Yuan_LinearAllreduce, Yuan_LinearLayer, GateUpPack_LinearLayer, Conv_LinearALlreduce, fused_LinearLayer, conv_LinearLayer from deepspeed.accelerator import get_accelerator -from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw, shard_value_with_share_qk, shard_chunk_mlp +from .fusedqkv_utils import require_tp_fused_qkvw from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list +from deepspeed.utils import groups +from deepspeed.module_inject.layers import is_autotp_training_mode def move(tensor, device, copy=True): @@ -333,10 +335,18 @@ def tp_parser(model): return policy_list def set_tensor_parallel_config(self, mp_size, mp_group): + + if is_autotp_training_mode(): + self.mp_group = groups.get_tensor_model_parallel_group() + self.mp_size = groups.get_tensor_model_parallel_world_size() + return + self.mp_size = mp_size self.mp_group = mp_group def _replace(self, child, name, conv_linear_layer): + # This function should clearly define the routing rules for specific layers + # and avoid any complex shard-related logic. if getattr(child, "replaced", False) == True: return device_name = 'cpu' if self.keep_module_on_host else get_accelerator().current_device_name() @@ -352,14 +362,15 @@ def _replace(self, child, name, conv_linear_layer): # For Yuan model if 'Yuan' in str(self.module): if 'v_proj' in name: - weight, bias = shard_value_with_share_qk(child.weight.data, child.bias, dist.get_rank(), - dist.get_world_size(), True) - return LinearLayer(weight=weight, bias=bias) + return Yuan_LinearLayer(child, self.mp_group) + elif 'o_proj' in name: - weight, bias = shard_value_with_share_qk(child.weight.data, child.bias, dist.get_rank(), - dist.get_world_size(), False) - return LinearAllreduce(weight, bias, self.mp_group) - # For Arctic model, bypass to all_reduce replacement for w2 weights + return Yuan_LinearAllreduce(child, self.mp_group) + + # For MLP including chunk layer. 
+ if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in str(self.module)): + return GateUpPack_LinearLayer(child, self.mp_group) + # For Arctic model, bypass to all_reduce replacement for w2 weights arctic_w2_all_reduce_linear = False if 'Arctic' in str(self.module) and 'w2' in name: arctic_w2_all_reduce_linear = True @@ -367,65 +378,25 @@ def _replace(self, child, name, conv_linear_layer): down_proj = False if 'down_proj' in name: down_proj = True - # For MLP including chunk layer. - if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in str(self.module)): - weight, bias = shard_chunk_mlp(child.weight.data, child.bias, dist.get_rank(), dist.get_world_size()) - return LinearLayer(weight=weight, bias=bias) if name in self.all_reduce_linears or arctic_w2_all_reduce_linear or down_proj: - # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size] - # else [weight_shape[0], weight_shape[1] // mp_size] + setattr(child, "replaced", True) if self.conv_linear_layer: - child.weight.data = child.weight.data.transpose(-1, -2).contiguous() - data = child.weight.data.split(get_shard_size_list( - weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size, name), - dim=1) - data_dc = move(data[mp_replace.gpu_index], device_name, return_new_copy).detach() - del data + return Conv_LinearALlreduce(child, self.mp_group, name=name) + elif name == "lm_head" or name == 'embed_out': + return LmHeadLinearAllreduce(child, self.mp_group) - setattr(child, "replaced", True) - if name == "lm_head" or name == 'embed_out': - return LmHeadLinearAllreduce( - torch.nn.parameter.Parameter(data_dc, requires_grad=False), dist.get_rank(), dist.get_world_size(), - child.bias if child.bias is None else torch.nn.parameter.Parameter( - move(child.bias, device_name, return_new_copy)), self.mp_group) - return LinearAllreduce(torch.nn.parameter.Parameter(data_dc, requires_grad=False), child.bias if child.bias is None else \ - torch.nn.parameter.Parameter(move(child.bias, device_name, return_new_copy)), self.mp_group) + return LinearAllreduce(child, self.mp_group, name=name) else: - # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size] - # else [weight_shape[0] // mp_size, weight_shape[1]] + setattr(child, "replaced", True) if self.conv_linear_layer: - child.weight.data = child.weight.data.transpose(-1, -2).contiguous() - - if require_tp_fused_qkvw(name, self.mp_size): + conv_LinearLayer(child, self.mp_group) + elif require_tp_fused_qkvw(name, self.mp_size): #Check and handle fused qkv for TP - #The copy is a regular copy, The shape of dst and src is the same - data_dc = move( - prepare_tp_fused_qkvw(self.module, child.weight.data, self.mp_size, mp_replace.gpu_index), - device_name, return_new_copy) - - bias_data_dc = None if child.bias is None else move( - prepare_tp_fused_qkvw(self.module, child.bias.data, self.mp_size, mp_replace.gpu_index), - device_name, return_new_copy) - else: - data = child.weight.data.split(get_shard_size_list(weight_shape[0], self.mp_size, name), - dim=1 if self.conv_linear_layer else 0) - data_dc = move(data[mp_replace.gpu_index], device_name, return_new_copy).detach() - del data - - if child.bias is not None: - bias_data = child.bias.data.split(get_shard_size_list( - weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size, name), - dim=0) - bias_data = move(bias_data[mp_replace.gpu_index], device_name, return_new_copy) - bias_data_dc = torch.nn.parameter.Parameter(bias_data, requires_grad=False) - del bias_data - else: 
- bias_data_dc = None + return fused_LinearLayer(child, self.mp_group, fused_module=self.module) - setattr(child, "replaced", True) - return LinearLayer(weight=torch.nn.parameter.Parameter(data_dc, requires_grad=False), bias=bias_data_dc) + return LinearLayer(child, self.mp_group, name=name) def _slice_embedding(self, child, name, conv_linear_layer): if getattr(child, "replaced", False) == True: diff --git a/deepspeed/module_inject/layers.py b/deepspeed/module_inject/layers.py index 2f884ba4fb09..c410bf900c31 100644 --- a/deepspeed/module_inject/layers.py +++ b/deepspeed/module_inject/layers.py @@ -7,10 +7,578 @@ from deepspeed import comm as dist from torch import nn from torch.nn import functional as F - from torch.nn.parameter import Parameter from deepspeed.accelerator import get_accelerator from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list +from abc import ABC, abstractmethod +from typing import Iterable, Any, Optional, List, Tuple +from .fusedqkv_utils import shard_value_with_share_qk, shard_chunk_mlp, prepare_tp_fused_qkvw +from deepspeed.runtime.tensor_parallel import AUTOTP_MODE +from copy import deepcopy +from typing import Union + +DEEPSPEED_AUTOTP_MODE = AUTOTP_MODE.INFERENCE +DS_IS_REPLACED_MODULE = 'ds_is_replaced_module' +DS_TENSOR_MODEL_PARALLEL = 'tensor_model_parallel' + + +def get_auto_tp_mode(): + global DEEPSPEED_AUTOTP_MODE + return DEEPSPEED_AUTOTP_MODE + + +def is_autotp_training_mode(): + global DEEPSPEED_AUTOTP_MODE + return DEEPSPEED_AUTOTP_MODE == AUTOTP_MODE.TRAINING + + +def set_autotp_mode(training=False): + """ + Set the DEEPSPEED_AUTOTP_MODE based on the training flag + """ + global DEEPSPEED_AUTOTP_MODE + if training: + DEEPSPEED_AUTOTP_MODE = AUTOTP_MODE.TRAINING + else: + DEEPSPEED_AUTOTP_MODE = AUTOTP_MODE.INFERENCE + + +def move(tensor, device): + # TODO: consider the timing of deletion + # to save host resources when DP > 1。 + + if tensor.is_meta: + return torch.empty_like(tensor, device=device) + else: + # Using new tensors help in freeing memory (after split for example) was done before by calling clone(). + # Using copy=True instead of clone() will help in case of cpu --> cpu. + # Otherwise to() will not create a new copy for the view of the full tensor, and it will not be de-referenced. + cloned_tensor = tensor.to(device, copy=True) + + # free the memory of the original tensor to reduce memory peak + # Equivalent to directly deleting the tensor reference outside the function. + # see https://github.com/microsoft/DeepSpeed/pull/4353 + tensor.data = torch.empty(0, device=tensor.device) + return cloned_tensor + + +class RowParallel(torch.autograd.Function): + """ + A custom autograd function for performing row-wise parallelism. + """ + + @staticmethod + def symbolic(graph, input): + """Symbolic function for tracing.""" + return input + + @staticmethod + def forward(ctx: Any, group: dist.ProcessGroup, input: torch.Tensor, is_inference_mode: bool) -> torch.Tensor: + """ + Forward pass. + """ + ctx.group = group + if group == None: + return input + if is_inference_mode: + dist.inference_all_reduce(input, group=group) + else: + dist.all_reduce(input.contiguous(), group=group) + return input + + @staticmethod + def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[None, torch.Tensor, None]: + """ + Backward pass. + """ + return None, grad_output, None + + +class ColumnParallel(torch.autograd.Function): + """ + Custom autograd function for column-wise parallelism. 
+ """ + + @staticmethod + def symbolic(graph, input): + """Symbolic function for tracing.""" + return dist.all_reduce(input.contiguous(), dist.get_tensor_model_parallel_group()) + + @staticmethod + def forward(ctx: Any, group: dist.ProcessGroup, input: torch.Tensor) -> torch.Tensor: + """ + Forward pass. + """ + ctx.group = group + return input + + @staticmethod + def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[None, torch.Tensor]: + """ + Backward pass. + """ + if ctx.group == None: + return None, grad_output + + dist.all_reduce(grad_output.contiguous(), group=ctx.group) + return None, grad_output + + +class Replaced_Layer(nn.Module, ABC): + """ + A base class for model layers with tensor parallelism support. + This class is designed to be extended by specific layers that require distributed + operations and parameter gather/partitioning during inference or training. + + Attributes: + mode (str): The mode of operation[INFERENCE or TRAINING], default is "INFERENCE". + mp_group (Optional[dist.ProcessGroup]): The process group used for model parallelism. + tp_world_size (int): The world size of tensor parallelism, i.e., the number of parallel workers. + tp_index (int): The rank (ID) of the current worker in tensor parallelism. + support_training (bool): Flag indicating whether the layer supports training (default: False). + name (Optional[str]): The name of the layer, if provided. + """ + + def __init__(self, mp_group: Optional[dist.ProcessGroup], **kwargs: Any): + """ + Initializes the Replaced_Layer with optional model parallelism group and layer name. + + Args: + mp_group (Optional[dist.ProcessGroup]): The process group for model parallelism. + If None, no model parallelism is set. + """ + super().__init__() + self.support_training: bool = False + if mp_group is not None: + self.mp_group = mp_group + self.tp_world_size: int = dist.get_world_size(self.mp_group) + self.tp_index: int = dist.get_rank(mp_group) + + # backward compatibility + self.world_size = self.tp_world_size + self.rank = self.tp_index + + self.name = getattr(self, 'name', None) + if kwargs.get('name') is not None: + self.name = kwargs.get('name') # Set the layer name if provided. + + @abstractmethod + def forward(self, input): + """ + Forward pass method. Must be implemented by subclasses to define layer-specific operations. + """ + pass + + @abstractmethod + def gather_params(self, params_list): + """ + Gathers parameters across devices for distributed training. Must be implemented by subclasses in "TRAINING" mode. + """ + pass + + @abstractmethod + def partition(self, params_list: List[torch.Tensor]): + """ + Partitions the parameters for tensor parallelism. + It is necessary to ensure that this function only involves the logic of params partitioning. + """ + pass + + def config_tp_params(self, weight): + """ + Configures the weight tensor for training with tensor parallelism. This includes enabling gradients + and associating necessary methods for parameter gathering and partitioning. + + Args: + weight (Optional[torch.Tensor]): The weight tensor to configure for tensor parallelism. + If None, no action is taken. + """ + # # The RNG states have already been synchronized in init_inference. + if self.is_training_mode(): + assert self.support_training, "No implementation of backward." 
+ if weight is not None: + if self.is_training_mode(): + if weight.requires_grad is None: + weight.requires_grad = True + else: + weight.requires_grad = False + setattr(weight, DS_TENSOR_MODEL_PARALLEL, True) + setattr(weight, DS_IS_REPLACED_MODULE, True) + weight.gather_params = self.gather_params + weight.partition = self.partition + + def is_training_mode(self): + global DEEPSPEED_AUTOTP_MODE + return DEEPSPEED_AUTOTP_MODE == AUTOTP_MODE.TRAINING + + def __deepcopy__(self, memo): + # This function is designed for + # 'mp_group' (a 'ProcessGroup') cannot be pickled during deepcopy in some usage. + cls = self.__class__ + new_obj = cls.__new__(cls) + + for key, value in vars(self).items(): + if key == 'mp_group': + new_obj.mp_group = self.mp_group + else: + setattr(new_obj, key, deepcopy(value, memo)) + + memo[id(self)] = new_obj + return new_obj + + def extra_repr(self): + if self.weight is not None: + out_features, in_features = self.weight.shape if self.weight is not None else (None, None) + dtype = self.weight.dtype if self.weight is not None else None + extra_repr_str = "in_features={}, out_features={}, bias={}, dtype={}".format( + in_features, out_features, self.bias is not None, dtype) + return extra_repr_str + + +class GatherReplacedLayerParams: + """ + A context manager for gathering parameters of a replaced layer, enabling partitioning and gathering functionality + based on the configuration of the model. + """ + + def __init__(self, + params: Union[Iterable[torch.Tensor], torch.Tensor], + module: torch.nn.Module, + enabled: bool = True): + """ + Initialize the context manager to handle parameter gathering and partitioning for a replaced layer. + + Args: + params (Iterable or torch.Tensor): A collection or single parameter to manage. + module (torch.nn.Module): The module that these parameters belong to. + enabled (bool): Flag indicating whether the parameter management is enabled (default: True). + """ + self.enabled = enabled + self.module = module + if not enabled: + return + + # Ensure params is a list, whether it's a single param or iterable (e.g., model.parameters()) + if isinstance(params, Iterable) and not isinstance(params, torch.Tensor): + self.params: List[torch.Tensor] = list(params) # Convert generators to a list for multiple iterations + else: + self.params: List[torch.Tensor] = [params] # Wrap single parameter in a list for uniform processing + + # Check if the parameters belong to a replaced layer (indicated by a specific attribute) + if not any(self._is_replaced_module_weight(p) for p in params): + self.enabled = False + return + + def _is_replaced_module_weight(self, param: torch.Tensor) -> bool: + """ + Helper function to determine if a parameter belongs to a replaced module. + + Args: + param (torch.Tensor): The parameter to check. + + Returns: + bool: True if the parameter belongs to a replaced module, False otherwise. + """ + return getattr(param, DS_IS_REPLACED_MODULE, False) + + def __enter__(self) -> None: + """ + Enter the context manager. If enabled, gather parameters for the replaced module. + """ + if self.enabled: + self.params[0].gather_params(self.params) + + def __exit__(self, exc_type, exc_value, traceback) -> None: + """ + Exit the context manager. If enabled, partition the parameters for the replaced module. + """ + #TODO : Check whether there are any missing attributes. 
+ if self.enabled: + self.params[0].partition(self.params) + + +class LinearAllreduce(Replaced_Layer): + + def __init__(self, module, mp_group, **kwargs): + super(LinearAllreduce, self).__init__(mp_group, **kwargs) + self.weight = module.weight + self.bias = module.bias + + self.partition([self.weight, self.bias]) + self.support_training = True + self.config_tp_params(self.weight) + if self.bias is not None: + self.config_tp_params(self.bias) + + def forward(self, input): + output = torch.matmul(input, self.weight.transpose(-1, -2)) + output = RowParallel.apply(self.mp_group, output, not self.is_training_mode()) + if self.bias is not None: + output += self.bias + return output + + @torch.no_grad() + def gather_params(self, params_list): + + for idx, param in enumerate(params_list): + if param is None or idx > 0: + # don't gather bias + return + params_list[idx].data_partition = param.data + param = param.transpose(0, 1).contiguous() + output_param = torch.empty(self.tp_world_size * param.shape[0], + param.shape[1], + dtype=param.dtype, + device=param.device) + dist.all_gather_into_tensor(output_param, param, group=self.mp_group) + params_list[idx].data = output_param.transpose(0, 1).contiguous() + return + + @torch.no_grad() + def partition(self, params_list): + + if not self.is_training_mode(): + self.uneven_partition(params_list) + return + + else: + for idx, param in enumerate(params_list): + if param is None or idx > 0: + # don't slipt bias + return + _partition = torch.chunk(param, self.tp_world_size, dim=-1)[self.tp_index] + + _partition = move(_partition, get_accelerator().current_device_name()).detach() + + params_list[idx].data = _partition + + def uneven_partition(self, params_list): + for idx, param in enumerate(params_list): + if param is None or idx > 0: + # don't slipt bias + return + assert self.name is not None, "The module name must be provided in the initialization." + _partition = params_list[idx].split(get_shard_size_list(params_list[idx].shape[1], self.tp_world_size, + self.name), + dim=1)[self.tp_index] + + _partition = move(_partition, get_accelerator().current_device_name()).detach() + params_list[idx].data = _partition + + +#remove kwargs from partition. +class LinearLayer(Replaced_Layer): + + def __init__(self, module, mp_group, skip_partition=False, **kwargs): + super(LinearLayer, self).__init__(mp_group, **kwargs) + self.weight = module.weight + self.bias = module.bias + if not skip_partition: + self.partition([self.weight, self.bias]) + self.support_training = True + self.config_tp_params(self.weight) + if self.bias is not None: + self.config_tp_params(self.bias) + + def forward(self, input): + input = ColumnParallel.apply(self.mp_group, input) + output = torch.matmul(input, self.weight.transpose(-1, -2)) + if self.bias is not None: + output += self.bias + return output + + @torch.no_grad() + def gather_params(self, params_list): + # Does not support uneven shard. 
+ for idx, param in enumerate(params_list): + + params_list[idx].data_partition = param.data + output_param = torch.empty(self.tp_world_size * param.shape[0], + param.shape[1], + dtype=param.dtype, + device=param.device) + dist.all_gather_into_tensor(output_param, param, group=self.mp_group) + params_list[idx].data = output_param.contiguous() + + @torch.no_grad() + def partition(self, params_list): + + if not self.is_training_mode(): + self.uneven_partition(params_list) + return + for idx, param in enumerate(params_list): + if param is None: + return + #split bias if provide + _partition = torch.chunk(param, self.tp_world_size, dim=0)[self.tp_index] + + _partition = move(_partition, get_accelerator().current_device_name()).detach() + + params_list[idx].data = _partition + + def uneven_partition(self, params_list): + + for idx, param in enumerate(params_list): + if param is None: + #split bias if provide + return + assert self.name is not None, "The module name must be provided in the initialization." + _partition = params_list[idx].split(get_shard_size_list(params_list[idx].shape[0], self.tp_world_size, + self.name), + dim=0)[self.tp_index] + + _partition = move(_partition, get_accelerator().current_device_name()).detach() + + params_list[idx].data = _partition + + # for bwc + @classmethod + def from_weights(cls, weight_shape=None, dtype=torch.half, weight=None, bias=None): + if weight is not None: + in_features = weight.shape[1] + out_features = weight.shape[0] + linear = nn.Linear(in_features, out_features, bias=(bias is not None)) + linear.weight.data = weight + if bias is not None: + linear.bias.data = bias + else: + in_features = weight_shape[1] + out_features = weight_shape[0] + linear = nn.Linear(in_features, out_features, bias=(bias is not None)) + return cls(linear, skip_partition=True) + + +class FusedModuleWrapper: + + def __init__(self, fused_module: nn.Module): + self.fused_module = fused_module + + def __getattr__(self, module): + return self.fused_module + + +class fused_LinearLayer(LinearLayer): + + def __init__(self, module, mp_group, skip_partition=False, **kwargs): + assert kwargs.get('fused_module') is not None, "'fused_module' is required but not provided" + # Use the warp class to avoid module circular references. 
+ self.fused_module = FusedModuleWrapper(kwargs.get('fused_module')) + super().__init__(module, mp_group, skip_partition, **kwargs) + + @torch.no_grad() + def partition(self, params_list): + for idx, param in enumerate(params_list): + if param is None: + return + + _partition = prepare_tp_fused_qkvw(self.fused_module.module, param, self.tp_world_size, self.tp_index) + + _partition = move(_partition, get_accelerator().current_device_name()).detach() + + params_list[idx].data = _partition + + +class conv_LinearLayer(LinearLayer): + + @torch.no_grad() + def partition(self, params_list): + weight = None + bias = None + if len(params_list) == 1: + weight = params_list[0] + elif len(params_list) == 2: + weight, bias = params_list[0], params_list[1] + _partition = weight.data.split(get_shard_size_list(weight.shape[0], self.tp_world_size, self.name), + dim=1)[self.tp_index] + _partition = move(_partition, get_accelerator().current_device_name()).detach() + weight.data = _partition + + if bias is not None: + _partition = bias.data.split(get_shard_size_list(weight.shape[1], self.tp_world_size, self.name), + dim=0)[self.tp_index] + _partition = move(_partition, get_accelerator().current_device_name()).detach() + + bias.data = _partition + + +#override the subclasses related to weight splitting. +class Yuan_LinearAllreduce(LinearAllreduce): + + #Yuan2 + @torch.no_grad() + def partition(self, params_list): + weight, bias = shard_value_with_share_qk(params_list[0].data, params_list[1], self.tp_index, + self.tp_world_size, False) + params_list[0].data = weight + if bias is not None: + params_list[1].data = bias + + +class Yuan_LinearLayer(LinearLayer): + #Yuan2 + @torch.no_grad() + def partition(self, params_list): + weight, bias = shard_value_with_share_qk(params_list[0].data, params_list[1], self.tp_index, + self.tp_world_size, True) + params_list[0].data = move(weight, get_accelerator().current_device_name()).detach() + if bias is not None: + params_list[1].data = move(bias, get_accelerator().current_device_name()).detach() + + +class GateUpPack_LinearLayer(LinearLayer): + # chatGLM2, chatGLM2 + @torch.no_grad() + def partition(self, params_list): + weight, bias = shard_chunk_mlp(params_list[0].data, params_list[1], self.tp_index, self.tp_world_size) + params_list[0].data = move(weight, device=get_accelerator().current_device_name()).detach() + if bias is not None: + params_list[1].data = move(bias, device=get_accelerator().current_device_name()).detach() + + +class Conv_LinearALlreduce(LinearAllreduce): + + @torch.no_grad() + def partition(self, params_list): + for idx, param in enumerate(params_list): + if param is None: + return + param.data = param.data.transpose(-1, -2).contiguous() + + _partition = param.split(get_shard_size_list(param.shape[0], self.tp_world_size, self.name), + dim=1)[self.tp_index] + + _partition = move(_partition, get_accelerator().current_device_name()).detach() + + params_list[idx].data = _partition + + +#override the subclasses related to fwd/bwd. +class LmHeadLinearAllreduce(LinearAllreduce): + + def __init__(self, module, mp_group, **kwargs): + # set the fixed name before partition + self.name = "lm_head" + + # In some tied_embedding cases, only the lm head is sharded, while the word embedding is not. + # Reinitialization is used to decouple them and prevent the word embedding from being sharded. + # This should also be effective for cases where both are sharded in tied_embedding scenarios. 
+ + # TODO: Training scenario-related tests, is it necessary to re-implement the vocab parallel module? + module.weight = nn.Parameter(module.weight.clone().detach()) + if hasattr(module, 'bias') and module.bias is not None: + module.bias = nn.Parameter(module.bias.clone().detach()) + super().__init__(module, mp_group, **kwargs) + + def forward(self, input): + input_shard_size = get_shard_size(input.shape[-1], self.tp_world_size, "lm_head") + input_shard_offset = sum(get_shard_size_list(input.shape[-1], self.tp_world_size, "lm_head")[0:self.tp_index]) + output = torch.matmul(input[:, :, input_shard_offset:input_shard_offset + input_shard_size], + self.weight.transpose(-1, -2)) + if self.mp_group is not None: + dist.inference_all_reduce(output, group=self.mp_group) + if self.bias is not None: + output += self.bias + return output class TensorParallelConv2d(nn.Module): @@ -75,97 +643,6 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: return out -class LinearAllreduce(nn.Module): - - def __init__(self, weight, bias=None, mp_group=None): - super(LinearAllreduce, self).__init__() - self.weight = weight - self.bias = bias - self.mp_group = mp_group - - def forward(self, input): - output = torch.matmul(input, self.weight.transpose(-1, -2)) - if self.mp_group is not None: - dist.inference_all_reduce(output, group=self.mp_group) - if self.bias is not None: - output += self.bias - return output - - def extra_repr(self): - out_features, in_features = self.weight.shape if self.weight is not None else (None, None) - dtype = self.weight.dtype if self.weight is not None else None - extra_repr_str = "in_features={}, out_features={}, bias={}, dtype={}".format( - in_features, out_features, self.bias is not None, dtype) - return extra_repr_str - - -class LmHeadLinearAllreduce(nn.Module): - - def __init__( - self, - weight, - rank, - world_size, - bias=None, - mp_group=None, - ): - super(LmHeadLinearAllreduce, self).__init__() - self.weight = weight - self.bias = bias - self.mp_group = mp_group - self.rank = rank - self.world_size = world_size - - def forward(self, input): - input_shard_size = get_shard_size(input.shape[-1], self.world_size, "lm_head") - input_shard_offset = sum(get_shard_size_list(input.shape[-1], self.world_size, "lm_head")[0:self.rank]) - output = torch.matmul(input[:, :, input_shard_offset:input_shard_offset + input_shard_size], - self.weight.transpose(-1, -2)) - if self.mp_group is not None: - dist.inference_all_reduce(output, group=self.mp_group) - if self.bias is not None: - output += self.bias - return output - - def extra_repr(self): - out_features, in_features = self.weight.shape if self.weight is not None else (None, None) - dtype = self.weight.dtype if self.weight is not None else None - extra_repr_str = "in_features={}, out_features={}, bias={}, dtype={}".format( - in_features, out_features, self.bias is not None, dtype) - return extra_repr_str - - -class LinearLayer(nn.Module): - - def __init__(self, weight_shape=None, dtype=torch.half, weight=None, bias=None): - super(LinearLayer, self).__init__() - if weight is not None: - self.weight = weight - self.bias = bias - else: - self.weight = Parameter( - torch.empty(weight_shape, dtype=dtype, device=get_accelerator().current_device_name())) - - self.bias = Parameter( - torch.empty(weight_shape[0], - dtype=dtype, - device=get_accelerator().current_device_name())) \ - if bias is not None else None - - def forward(self, input): - output = torch.matmul(input, self.weight.transpose(-1, -2)) - if self.bias is not None: - output 
+= self.bias - return output - - def extra_repr(self): - out_features, in_features = self.weight.shape - dtype = self.weight.dtype - extra_repr_str = "in_features={}, out_features={}, bias={}, dtype={}".format( - in_features, out_features, self.bias is not None, dtype) - return extra_repr_str - - class Normalize(nn.Module): def __init__(self, dim=None, dtype=torch.float, eps=1e-5, weight=None, bias=None): diff --git a/deepspeed/module_inject/load_checkpoint.py b/deepspeed/module_inject/load_checkpoint.py index 4d01fdc69869..862628fa7b4b 100644 --- a/deepspeed/module_inject/load_checkpoint.py +++ b/deepspeed/module_inject/load_checkpoint.py @@ -236,7 +236,7 @@ def load_module_recursive(module, prefix='', level=0): child.weight.ds_id in all_ds_ids): prefix1 = all_ds_ids[child.weight.ds_id] if child.__class__ is nn.Linear: - child = LinearLayer(weight=all_ds_ids[child.weight.ds_id]) + child = LinearLayer.from_weights(weight=all_ds_ids[child.weight.ds_id]) setattr(module, name, child) continue child_params = list(child.parameters()) @@ -249,7 +249,9 @@ def load_module_recursive(module, prefix='', level=0): child = Normalize(dim=ds_shape[-1], dtype=child.weight.dtype, eps=child.eps) setattr(module, name, child) elif child.__class__ in [nn.Linear, ColumnParallelLinear, RowParallelLinear]: - child = LinearLayer(weight_shape=child.weight.shape, dtype=child.weight.dtype, bias=child.bias) + child = LinearLayer.from_weights(weight_shape=child.weight.shape, + dtype=child.weight.dtype, + bias=child.bias) setattr(module, name, child) elif child.__class__ is OPTLearnedPositionalEmbedding: child = OPTEmbedding(weight_shape=ds_shape) diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 26d242d33e2f..9510f96b89c6 100644 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -15,7 +15,7 @@ from .replace_policy import replace_policies, generic_policies from .auto_tp import AutoTP, ReplaceWithTensorSlicing, Loading from .layers import TensorParallelOcShardConv2d, TensorParallelIcShardConv2d - +from deepspeed.module_inject.layers import is_autotp_training_mode from deepspeed import comm as dist from deepspeed.module_inject.tp_shard import set_num_kv_heads, set_n_embd, set_num_attention_heads, set_tp_grain_size @@ -323,7 +323,7 @@ def replace_fn(child, _policy, layer_id=0, prefix="", state_dict=None): else: # copy relevant state from child -> new module - if config.replace_with_kernel_inject: + if not is_autotp_training_mode() and config.replace_with_kernel_inject: new_module = replace_with_policy(child, _policy, config.triangular_masking, @@ -475,7 +475,7 @@ def conv2d_parallel_shard_weights(model, rank, world_size): set_lm_head(replaced_module) print(f"checkpoint loading time at rank {rank}: {time.time()-start_time} sec") - if config.save_mp_checkpoint_path is not None: + if not is_autotp_training_mode() and config.save_mp_checkpoint_path is not None: from collections import OrderedDict import json num_partitions = 8 diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index fb786f29722d..b6dabc161e8c 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -62,6 +62,7 @@ from ..compression.constants import * from .swap_tensor.aio_config import get_aio_config +from .tensor_parallel import get_tensor_parallel_config from .data_pipeline.config import get_data_efficiency_enabled, get_data_efficiency_config, get_curriculum_enabled_legacy, get_curriculum_params_legacy from 
.data_pipeline.constants import * @@ -913,6 +914,7 @@ def _initialize_params(self, param_dict): **param_dict['weight_quantization']) if 'weight_quantization' in param_dict else None self.timers_config = get_timers_config(param_dict) + self.tensor_parallel_config = get_tensor_parallel_config(param_dict) def _batch_assertion(self): diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 97d2afb8b723..986b68dc1bb1 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -37,6 +37,7 @@ from deepspeed.runtime.bf16_optimizer import BF16_Optimizer from deepspeed.linear.optimized_linear import LoRAOptimizedLinear +from deepspeed.module_inject.layers import GatherReplacedLayerParams from deepspeed.runtime.config import DEEPSPEED_OPTIMIZERS, \ ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \ @@ -75,7 +76,7 @@ from deepspeed.utils.debug import debug_extract_module_and_param_names, debug_clear_module_and_param_names from deepspeed.monitor.monitor import MonitorMaster from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop -from deepspeed.runtime.utils import clip_grad_norm_ +from deepspeed.runtime.utils import clip_grad_norm_, compare_tensors_in_structures from deepspeed.runtime.eigenvalue import Eigenvalue from deepspeed.runtime.data_pipeline.constants import DATA_SAMPLING, \ DATA_ROUTING, DATA_SAMPLING_ENABLED, CURRICULUM_LEARNING, \ @@ -230,7 +231,6 @@ def __init__(self, self._step_applied = False self._global_grad_norm = None self.use_ds_comm = False # False --> Use torch.dist, True --> Use ds.comm backend. - self.checkpoint_engine = None self._is_gradient_accumulation_boundary = None @@ -247,6 +247,8 @@ def __init__(self, self._do_args_sanity_check(args) self._configure_with_arguments(args, mpu) self._do_sanity_check() + if self.autotp_size() > 1: + self._configure_tensor_parallel_states(model) see_memory_usage(f"DeepSpeed Engine: After args sanity test", force=self.memory_breakdown()) if mpu is not None: if self.elasticity_enabled(): @@ -411,6 +413,71 @@ def _optimized_linear_offload_setup(self): else: p.ds_offload = False + def _configure_tensor_parallel_states(self, model): + """ + Configures the tensor parallel states for the model. + This includes setting up the tensor parallel groups, initializing the TP mesh, + and registering a pre-hook to ensure that the Dataloader inputs are consistent across ranks. 
+ """ + self._set_client_model(model) + + # sanity check + # currently, the compatibility between 'autotp' and 'zero > 1' has not been validated + assert self.zero_optimization_stage( + ) <= 1, "Currently, the compatibility between 'autotp' and 'zero_stage > 1' has not been validated" + + self.mpu = groups + self.mpu._init_tp_mesh_device(tensor_model_parallel_size=self.autotp_size()) + + self.first_dataloader_check = None + + def check_dataloader_inputs_same_across_ranks(module, args, kwargs): + + def broadcast_and_check(args, bcast_rank, bcast_group): + if isinstance(args, tuple): + args = list(args) + if len(args) > 0: + if self.mpu.get_tensor_model_parallel_rank() == 0: + _src_args = [args] + dist.broadcast_object_list(object_list=_src_args, + src=bcast_rank, + group=bcast_group, + device=get_accelerator().current_device()) + # Rank 0 does not need to compare with itself + is_equal = True + else: + _src_args = [None] + dist.broadcast_object_list(object_list=_src_args, + src=bcast_rank, + group=bcast_group, + device=get_accelerator().current_device()) + + is_equal = compare_tensors_in_structures(args, _src_args[0]) + + equal_tensor = torch.tensor(is_equal, + dtype=self.communication_data_type, + device=get_accelerator().current_device()) + dist.all_reduce(equal_tensor, group=bcast_group) + assert torch.equal( + equal_tensor, + torch.tensor(groups.get_tensor_model_parallel_world_size(), + dtype=self.communication_data_type, + device=get_accelerator().current_device()) + ), "Data inconsistency within the TP group. Please check the Dataloader implementation to ensure consistency." + + bcast_rank = self.mpu.get_tensor_model_parallel_src_rank() + bcast_group = self.mpu.get_tensor_model_parallel_group() + + broadcast_and_check(args, bcast_rank, bcast_group) + broadcast_and_check(kwargs, bcast_rank, bcast_group) + + logger.info(f":The Dataloader has passed the TP group consistency check.") + self.first_dataloader_check.remove() + + self.first_dataloader_check = self.module.register_forward_pre_hook(check_dataloader_inputs_same_across_ranks, + prepend=True, + with_kwargs=True) + def destroy(self): if self.optimizer is not None and hasattr(self.optimizer, 'destroy'): self.optimizer.destroy() @@ -832,6 +899,9 @@ def zero_legacy_stage1(self): def zero_ignore_unused_parameters(self): return self._config.zero_config.ignore_unused_parameters + def autotp_size(self): + return self._config.tensor_parallel_config.autotp_size + def graph_harvesting(self): return self._config.graph_harvesting @@ -3569,6 +3639,52 @@ def _save_zero_checkpoint(self, save_path, tag): ckpt_type = 'zero' if self.zero_optimization() else 'bf16_zero' logger.info(f'{ckpt_type} checkpoint saved {zero_checkpoint_name}') + def _replace_module_consolidated_state_dict(self): + """ + Get a full non-partitioned state_dict with fp16 weights on cpu. + Important: this function must be called on all ranks and not just rank 0. + This is similar to nn.Module.state_dict (modelled after _save_to_state_dict) + This method is used for tensor parallel training. + + Returns: + OrderedDict: The consolidated state dictionary if the current process rank is 0, otherwise None. + """ + #TODO: If we use both Zero3 and tensor parallel simultaneously + # we need to consolidate the gather mechanisms of both. 
+ state_dict = OrderedDict() if dist.get_rank() == 0 else None + + def get_layer_state_dict(module, prefix=""): + with GatherReplacedLayerParams(list(module.parameters(recurse=False)), module, enabled=True): + for name, param in module.named_parameters(recurse=False): + if param is None: + continue + key = prefix + name + if (dist.get_rank() == 0): + state_dict[key] = param.detach().cpu() + # print(key,module, param.detach().cpu().shape) + + for name, child in module.named_children(): + if child is not None: + get_layer_state_dict(child, prefix + name + ".") + + get_layer_state_dict(self.module, prefix="") + + # ensure that all GPU communication tasks are completed before the process exits + get_accelerator().synchronize() + return state_dict + + def _consolidated_16bit_state_dict(self, exclude_frozen_parameters=False): + """ + Consolidate the 16-bit state dictionary. + """ + if self.zero_optimization_stage() == ZeroStageEnum.weights: + return self._zero3_consolidated_16bit_state_dict(exclude_frozen_parameters) + elif self.autotp_size() > 1: + return self._replace_module_consolidated_state_dict() + + raise ValueError("consolidated_16bit_state_dict is only applicable to cases where weights are partitioned, " + "including Zero Stage 3 and tensor parallelism.") + def _zero3_consolidated_16bit_state_dict(self, exclude_frozen_parameters=False): """ Get a full non-partitioned state_dict with fp16 weights on cpu. diff --git a/deepspeed/runtime/tensor_parallel/__init__.py b/deepspeed/runtime/tensor_parallel/__init__.py new file mode 100644 index 000000000000..388239345351 --- /dev/null +++ b/deepspeed/runtime/tensor_parallel/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .config import AUTOTP_MODE, get_tensor_parallel_config +from .tp_manager import TpTrainingManager diff --git a/deepspeed/runtime/tensor_parallel/config.py b/deepspeed/runtime/tensor_parallel/config.py new file mode 100644 index 000000000000..1300bf9323cd --- /dev/null +++ b/deepspeed/runtime/tensor_parallel/config.py @@ -0,0 +1,81 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from enum import Enum +from deepspeed.runtime.config_utils import DeepSpeedConfigModel +import torch +from pydantic import Field +from typing import Optional + + +class AUTOTP_MODE(Enum): + TRAINING = "TRAINING" + INFERENCE = "INFERENCE" + + +class TPConfig(DeepSpeedConfigModel): + """ Configure tensor parallelism settings """ + + tp_size: int = 1 + """ Number of devices to split the model across using tensor parallelism. """ + + tp_grain_size: int = 1 + "The variable required by the autoTP parser has not been activated in training yet" + "as it depends on the gather logic that supports uneven partitioning. " + "Desired MLP/lm_head tp size granularity. DNN library favors tensor size in granularity of power of 2, we pick 64 as a default size." + + mpu: object = None + """ + A model parallelism unit object that implements + ``get_{model,data}_parallel_{rank,group,world_size}()``. + """ + + tp_group: object = None + + +class TPTrainingConfig(DeepSpeedConfigModel): + + dtype: torch.dtype = torch.float16 + """ + Desired model data type, will convert model to this type. + """ + + autotp_size: int = 0 + """ + In automatic tensor-parallelism training, 'tensor_parallel_size' + When set to 0, indicates that it is disabled. 
+ """ + tensor_parallel: TPConfig = Field({}, alias="tp") + """ + Configuration for tensor parallelism used to split the model across several + GPUs. Expects a dictionary containing values for :any:`DeepSpeedTPConfig`. + """ + + injection_policy_tuple: Optional[tuple] = None + #The following parameters are required by autoTP parser. + ######################################## + keep_module_on_host: bool = False + """ + When loading checkpoints to model parameters, they are moved to the device. In very large models + this might fill the device and cause OOM. Setting this flag to true, will keep checkpoints on + host and not move them directly to the device (giving an option to quantize checkpoint data before + moving it to the device for example). + """ + + replace_with_kernel_inject: bool = Field(False, alias="kernel_inject") + """ + Set to true to inject inference kernels for models such as, Bert, GPT2, + GPT-Neo and GPT-J. Otherwise, the injection_dict provides the names of two + linear layers as a tuple: + `(attention_output projection, transformer output projection)` + """ + ######################################## + + +def get_tensor_parallel_config(ds_config): + + if 'tensor_parallel' in ds_config: + return TPTrainingConfig(**ds_config['tensor_parallel']) + return TPTrainingConfig() diff --git a/deepspeed/runtime/tensor_parallel/tp_manager.py b/deepspeed/runtime/tensor_parallel/tp_manager.py new file mode 100644 index 000000000000..cf0b5a75c92a --- /dev/null +++ b/deepspeed/runtime/tensor_parallel/tp_manager.py @@ -0,0 +1,66 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +from .config import TPTrainingConfig, TPConfig +from deepspeed.utils import groups +import deepspeed.comm as dist + + +class TpTrainingManager(): + + def __init__(self, model, tp_size, dtype): + self.module = model + self.config = self._initialize_config(dtype) + + from deepspeed.module_inject.auto_tp import AutoTP + from deepspeed import get_accelerator + + # Parse model configuration + parser_dict = AutoTP.tp_parser(model) + print("AutoTP: ", parser_dict) + + # Initialize TP configuration and model + self._initialize_tp_config(tp_size) + self._get_model_config_generate() + + # Synchronize random number generator state across devices + _rng_state = get_accelerator().get_rng_state().to(get_accelerator().current_device_name()) + dist.broadcast(_rng_state, groups.get_tensor_model_parallel_src_rank(), self.tp_config.tp_group) + get_accelerator().set_rng_state(_rng_state.cpu()) + + # Apply injection policies + self._apply_policies(parser_dict) + + def _initialize_config(self, dtype): + """Initialize and return the DeepSpeed TP training configuration.""" + config = TPTrainingConfig() + config.dtype = dtype + return config + + def _apply_policies(self, parser_dict): + """Apply injection policies to the parsed modules.""" + for client_module, injection_policy in parser_dict: + self.config.injection_policy_tuple = injection_policy + self._apply_injection_policy(self.config, client_module) + + def _apply_injection_policy(self, config, client_module=None): + from deepspeed.module_inject import replace_transformer_layer + """Apply the given injection policy to a client module.""" + if isinstance(self.module, torch.nn.Module): + replace_transformer_layer(client_module, self.module, None, self.config, self.model_config) + + def _initialize_tp_config(self, tp_size): + """Perform TP configuration initialization.""" + self.tp_config = TPConfig() + 
self.tp_config.tp_size = tp_size + + groups._init_tp_mesh_device(tp_size) + self.tp_config.tp_group = groups.get_tensor_model_parallel_group() + self.config.tensor_parallel = self.tp_config + + def _get_model_config_generate(self): + """Generate and apply HF model configuration.""" + self.model_config = getattr(self.module, 'config', None) diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index f48adb58c9bf..91fe7cbdcc96 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -22,7 +22,7 @@ from torch._six import inf except ModuleNotFoundError: from torch import inf - +from typing import Union, List, Dict from deepspeed import comm as dist from deepspeed.moe.utils import is_moe_param from deepspeed.utils import groups, logger @@ -1101,3 +1101,46 @@ def move_back_key(state, key): move_back_key(state, "exp_avg") if "exp_avg_sq" in state: move_back_key(state, "exp_avg_sq") + + +def compare_tensors_in_structures(inputs1: Union[List, Dict], inputs2: Union[List, Dict]) -> bool: + """ + Compare two lists or dictionaries for equality, including any tensors they may contain. + + Args: + inputs1: First input, either a list or a dictionary. + inputs2: Second input, either a list or a dictionary. + + Returns: + True if inputs1 and inputs2 are equal; False otherwise. + """ + if type(inputs1) != type(inputs2): # Ensure types match + return False + + if isinstance(inputs1, list) and isinstance(inputs2, list): + if len(inputs1) != len(inputs2): + return False + for val1, val2 in zip(inputs1, inputs2): + if isinstance(val1, torch.Tensor) and isinstance(val2, torch.Tensor): + val1 = val1.to(get_accelerator().current_device()) + val2 = val2.to(get_accelerator().current_device()) + if not torch.equal(val1, val2): + return False + elif val1 != val2: + return False + return True + + elif isinstance(inputs1, dict) and isinstance(inputs2, dict): + if inputs1.keys() != inputs2.keys(): + return False + for key in inputs1: + val1 = inputs1[key].to(get_accelerator().current_device()) + val2 = inputs2[key].to(get_accelerator().current_device()) + if isinstance(val1, torch.Tensor) and isinstance(val2, torch.Tensor): + if not torch.equal(val1, val2): + return False + elif val1 != val2: + return False + return True + + return False diff --git a/deepspeed/utils/groups.py b/deepspeed/utils/groups.py index e9550a0ec25a..6dc750035061 100755 --- a/deepspeed/utils/groups.py +++ b/deepspeed/utils/groups.py @@ -46,8 +46,6 @@ # All to All quantized graident communication groups _ALL_TO_ALL_GROUP = {} -_DATA_PARALLEL_GROUP = None - mesh_device = None @@ -64,6 +62,127 @@ def _ensure_divisibility(numerator, denominator): assert numerator % denominator == 0, '{} is not divisible by {}'.format(numerator, denominator) +# ======== Start: Tensor Parallel Group Attributes ======== + +# Intra-layer model parallel group that the current rank belongs to. +_TENSOR_MODEL_PARALLEL_GROUP = None + +# Model parallel group (both intra- and pipeline) that the current rank belongs to. +_MODEL_PARALLEL_GROUP = None +# Data parallel group that the current rank belongs to. +_DATA_PARALLEL_GROUP = None + +# These values enable us to change the mpu sizes on the fly. 
+_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None +_MPU_TENSOR_MODEL_PARALLEL_RANK = None + + +def _init_tp_mesh_device(tensor_model_parallel_size=1, data_parallel_size=None): + """Initialize model data parallel groups.""" + + global _DATA_PARALLEL_GROUP + global _MODEL_PARALLEL_GROUP + global _TENSOR_MODEL_PARALLEL_GROUP + + if _TENSOR_MODEL_PARALLEL_GROUP is not None: + return + + if data_parallel_size is None: + data_parallel_size = dist.get_world_size() // tensor_model_parallel_size + + mesh_device = dist.initialize_mesh_device((data_parallel_size, tensor_model_parallel_size), + ("data_parallel", "tensor_parallel")) + _TENSOR_MODEL_PARALLEL_GROUP = mesh_device.get_group(mesh_dim="tensor_parallel") + _DATA_PARALLEL_GROUP = mesh_device.get_group(mesh_dim="data_parallel") + + # They are always equal only in 2D (DP + TP) parallelism. + # _MODEL_PARALLEL_GROUP is assigned the same value as _TENSOR_MODEL_PARALLEL_GROUP + # to allow for future potential changes. + _MODEL_PARALLEL_GROUP = _TENSOR_MODEL_PARALLEL_GROUP + + return _DATA_PARALLEL_GROUP, _MODEL_PARALLEL_GROUP + + +def get_tensor_model_parallel_group(): + """Get the tensor model parallel group the caller rank belongs to.""" + + assert _TENSOR_MODEL_PARALLEL_GROUP is not None, \ + 'intra_layer_model parallel group is not initialized' + return _TENSOR_MODEL_PARALLEL_GROUP + + +def get_model_parallel_group(): + """Get the model parallel group the caller rank belongs to.""" + + assert _MODEL_PARALLEL_GROUP is not None, \ + 'model parallel group is not initialized' + return _MODEL_PARALLEL_GROUP + + +def get_data_parallel_group(): + """Get the data parallel group the caller rank belongs to.""" + assert _DATA_PARALLEL_GROUP is not None, \ + 'data parallel group is not initialized' + return _DATA_PARALLEL_GROUP + + +def set_tensor_model_parallel_world_size(world_size): + """Set the tensor model parallel size""" + global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE + _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size + + +def get_tensor_model_parallel_world_size(): + """Return world size for the tensor model parallel group.""" + global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE + if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None: + return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE + return dist.get_world_size(group=get_tensor_model_parallel_group()) + + +def get_model_parallel_world_size(): + return get_tensor_model_parallel_world_size() + + +def set_tensor_model_parallel_rank(rank): + """Set tensor model parallel rank.""" + global _MPU_TENSOR_MODEL_PARALLEL_RANK + _MPU_TENSOR_MODEL_PARALLEL_RANK = rank + + +def get_tensor_model_parallel_rank(): + """Return my rank for the tensor model parallel group.""" + global _MPU_TENSOR_MODEL_PARALLEL_RANK + if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None: + return _MPU_TENSOR_MODEL_PARALLEL_RANK + return dist.get_rank(group=get_tensor_model_parallel_group()) + + +def get_model_parallel_rank(): + return get_tensor_model_parallel_rank() + + +def get_tensor_model_parallel_src_rank(): + """Calculate the global rank corresponding to the first local rank + in the tensor model parallel group.""" + global_rank = dist.get_rank() + local_world_size = get_tensor_model_parallel_world_size() + return (global_rank // local_world_size) * local_world_size + + +def get_data_parallel_world_size(): + """Return world size for the data parallel group.""" + return dist.get_world_size(group=get_data_parallel_group()) + + +def get_data_parallel_rank(): + """Return my rank for the data parallel group.""" + return 
dist.get_rank(group=get_data_parallel_group()) + + +# ======== End: Tensor Parallel Group Attributes ======== + + # Not currently used. Helper function to create a model (tensor) parallel group. def _create_model_parallel(model_parallel_size_): """ diff --git a/tests/unit/model_parallelism/test_autotp_training.py b/tests/unit/model_parallelism/test_autotp_training.py new file mode 100644 index 000000000000..fc1f0624ec87 --- /dev/null +++ b/tests/unit/model_parallelism/test_autotp_training.py @@ -0,0 +1,574 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import deepspeed.comm as dist +import torch +import math +from copy import deepcopy + +from unit.common import DistributedTest, preferred_dtype +import deepspeed +from deepspeed.accelerator import get_accelerator +from unit.simple_model import SimpleModel, random_dataloader +from deepspeed.utils import groups +from contextlib import contextmanager +from torch import nn +from deepspeed.module_inject.layers import LinearAllreduce, LinearLayer, set_autotp_mode +from unit.checkpoint.common import compare_lr_scheduler_states, compare_optimizer_states +import os + + +def skip_on_device(): + if get_accelerator().device_name() == 'xpu': + pytest.skip(f"XPU requires a higher version for test") + + +class SequentialLinearModel(torch.nn.Module): + + def __init__(self, hidden_dim, empty_grad=False, nlayers=1): + super(SequentialLinearModel, self).__init__() + self.linears = torch.nn.ModuleList( + [torch.nn.Linear(hidden_dim, hidden_dim, bias=None) for i in range(nlayers)]) + if empty_grad: + self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim, bias=None) + self.cross_entropy_loss = torch.nn.CrossEntropyLoss() + self.empty_grad = empty_grad + + def forward(self, x, y): + if len(self.linears) == 1: + x = self.linears[0](x) + else: + for i, l in enumerate(self.linears): + x = self.linears[i](x) + return self.cross_entropy_loss(x, y) + + +@contextmanager +def should_assert_with_msg(expected_message): + try: + yield + except AssertionError as e: + if dist.get_rank() == 0: + print(expected_message) + print(str(e)) + if str(e) == expected_message: + pass + else: + raise e + + +@pytest.mark.parametrize("tp_size", [2, 4]) +class TestTpParallelStates(DistributedTest): + world_size = 4 + + def test(self, tp_size: int): + skip_on_device() + set_autotp_mode(training=True) + + dp_size = 4 / tp_size + hidden_dim = 128 + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "tensor_parallel": { + "autotp_size": tp_size + }, + "zero_optimization": { + "stage": 0 + } + } + model = SimpleModel(hidden_dim=hidden_dim) + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + assert groups.get_tensor_model_parallel_world_size() == tp_size + assert groups.get_data_parallel_world_size() == dp_size + + +@pytest.mark.parametrize("tp_size", [2, 4]) +class TestTpDataloaderCorrectness(DistributedTest): + world_size = 4 + reuse_dist_env = True + + def test(self, tp_size: int): + skip_on_device() + hidden_dim = 128 + set_autotp_mode(training=True) + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-6 + } + }, + "tensor_parallel": { + "autotp_size": tp_size + }, + "zero_optimization": { + "stage": 0, + } + } + if preferred_dtype() is torch.float16: + config_dict["fp16"] = {"enabled": True} + elif preferred_dtype() is torch.bfloat16: + config_dict["bf16"] = {"enabled": 
True} + + model = SimpleModel(hidden_dim=hidden_dim) + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + data_loader = random_dataloader(model=model, + total_samples=3, + hidden_dim=hidden_dim, + device=model.device, + dtype=preferred_dtype()) + dist.barrier() + with should_assert_with_msg( + "Data inconsistency within the TP group. Please check the Dataloader implementation to ensure consistency." + ): + for batch in data_loader: + # batch[0].requires_grad = requires_grad + batch[0] += dist.get_rank() + model(batch[0], batch[1]) + + model = SimpleModel(hidden_dim=hidden_dim) + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + data_loader = random_dataloader(model=model, + total_samples=3, + hidden_dim=hidden_dim, + device=model.device, + dtype=preferred_dtype()) + for batch in data_loader: + dist.broadcast(batch[0], + src=groups.get_tensor_model_parallel_src_rank(), + group=groups.get_tensor_model_parallel_group()) + dist.broadcast(batch[1], + src=groups.get_tensor_model_parallel_src_rank(), + group=groups.get_tensor_model_parallel_group()) + model(batch[0], batch[1]) + + +def process_linear_layer(hidden_dim, input): + torch.manual_seed(42) + torch_linear = nn.Linear(hidden_dim, + hidden_dim, + dtype=preferred_dtype(), + device=get_accelerator().current_device(), + bias=None) + torch_out = torch_linear(input) + torch_loss = torch_out.sum() + torch_loss.backward() + return torch_linear, torch_out + + +@pytest.mark.sequential +@pytest.mark.parametrize("tp_size", [2, 4]) +class TestTpLayerFwdBwd(DistributedTest): + world_size = 4 + reuse_dist_env = True + + def testRowParallel(self, tp_size: int): + skip_on_device() + hidden_dim = 128 + batch_size_per_device = 1 + set_autotp_mode(training=True) + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-6 + } + }, + "tensor_parallel": { + "autotp_size": tp_size + }, + "zero_optimization": { + "stage": 0, + } + } + if preferred_dtype() is torch.float16: + config_dict["fp16"] = {"enabled": True} + elif preferred_dtype() is torch.bfloat16: + config_dict["bf16"] = {"enabled": True} + model = SequentialLinearModel(hidden_dim=hidden_dim) + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + input = torch.randn(batch_size_per_device, + hidden_dim, + dtype=preferred_dtype(), + requires_grad=True, + device=get_accelerator().current_device()) + + dist.broadcast(input, + groups.get_tensor_model_parallel_src_rank(), + group=groups.get_tensor_model_parallel_group()) + + torch_linear, torch_out = process_linear_layer(hidden_dim, input) + linear = LinearAllreduce(deepcopy(torch_linear), groups.get_tensor_model_parallel_group()) + + input_ = torch.chunk(input, tp_size, dim=-1)[groups.get_tensor_model_parallel_rank()] + out = linear(input_.to(get_accelerator().current_device())) + loss = out.sum() + loss.backward() + + torch_grad = torch.chunk(torch_linear.weight.grad, tp_size, dim=1)[groups.get_tensor_model_parallel_rank()] + assert torch.allclose(linear.weight.grad, torch_grad.to(get_accelerator().current_device()), atol=1e-3) + assert torch.allclose(out, torch_out.to(get_accelerator().current_device()), atol=1e-3) + + def testColumnParallel(self, tp_size: int): + skip_on_device() + hidden_dim = 128 + batch_size_per_device = 1 + set_autotp_mode(training=True) + config_dict = { + 
"train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-6 + } + }, + "tensor_parallel": { + "autotp_size": tp_size + }, + "zero_optimization": { + "stage": 0, + } + } + if preferred_dtype() is torch.float16: + config_dict["fp16"] = {"enabled": True} + elif preferred_dtype() is torch.bfloat16: + config_dict["bf16"] = {"enabled": True} + + model = SequentialLinearModel(hidden_dim=hidden_dim) + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + input = torch.randn(batch_size_per_device, + hidden_dim, + dtype=preferred_dtype(), + requires_grad=True, + device=get_accelerator().current_device()) + dist.broadcast(input, + groups.get_tensor_model_parallel_src_rank(), + group=groups.get_tensor_model_parallel_group()) + + torch_linear, torch_out = process_linear_layer(hidden_dim, input) + + linear = LinearLayer(deepcopy(torch_linear), groups.get_tensor_model_parallel_group()) + + out = linear(input.to(get_accelerator().current_device())) + loss = out.sum() + loss.backward() + + cur_device_out = torch.chunk(torch_out, tp_size, dim=-1)[groups.get_tensor_model_parallel_rank()] + torch_grad = torch.chunk(torch_linear.weight.grad, tp_size, dim=0)[groups.get_tensor_model_parallel_rank()] + assert torch.allclose(linear.weight.grad, torch_grad.to(get_accelerator().current_device()), atol=1e-3) + assert torch.allclose(cur_device_out.to(get_accelerator().current_device()).contiguous(), + out.contiguous(), + atol=1e-3) + + +@pytest.mark.sequential +class TestParamsGather(DistributedTest): + world_size = 4 + reuse_dist_env = True + + @pytest.mark.parametrize("layer_type", ["linear", "linearallreduce"]) + def test(self, layer_type): + skip_on_device() + tp_size = 4 + hidden_dim = 128 + set_autotp_mode(training=True) + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-6 + } + }, + "tensor_parallel": { + "autotp_size": tp_size + }, + "zero_optimization": { + "stage": 0, + } + } + if preferred_dtype() is torch.float16: + config_dict["fp16"] = {"enabled": True} + elif preferred_dtype() is torch.bfloat16: + config_dict["bf16"] = {"enabled": True} + + torch.manual_seed(42) + model = SequentialLinearModel(hidden_dim=hidden_dim) + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + + torch_linear = nn.Linear(hidden_dim, hidden_dim, dtype=preferred_dtype(), device="cpu", bias=None) + total_params = sum(p.numel() for p in torch_linear.parameters()) + + tp_layer = None + if layer_type == "linear": + tp_layer = LinearLayer(torch_linear, groups.get_tensor_model_parallel_group()) + elif layer_type == "linearallreduce": + tp_layer = LinearAllreduce(torch_linear, groups.get_tensor_model_parallel_group()) + else: + raise ValueError(f"Invalid linear type: {config_dict['linear_type']}") + + tp_params = sum(p.numel() for p in tp_layer.parameters()) + + assert total_params // tp_size == tp_params + for name, param in tp_layer.named_parameters(recurse=False): + param.gather_params([param]) + + is_same_weights = all( + torch.equal(param1, param2) for param1, param2 in zip(tp_layer.parameters(), torch_linear.parameters())) + + assert is_same_weights + + params1 = sum(p.numel() for p in tp_layer.parameters()) + assert total_params == params1 + + for name, param in tp_layer.named_parameters(recurse=False): + param.partition([param]) + + tp_params2 = sum(p.numel() for p in tp_layer.parameters()) 
+ + assert total_params // tp_size == tp_params2 + + +def dummy_init_engine(config): + # This is a dummy initialization function for the DeepSpeed engine. + # We only need to use the config to initialize the distributed settings for the test. + model = SequentialLinearModel(hidden_dim=8) + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config) + + +def prepare_tp_model(hidden_dim, nlayers, linear_indices, allreduce_indices, group, return_global_copy=False): + model = SequentialLinearModel(hidden_dim=hidden_dim, nlayers=nlayers).to(preferred_dtype()) + base_model = None + if return_global_copy: + base_model = deepcopy(model) + for i in linear_indices: + layer = LinearLayer(model.linears[i], group) + model.linears[i] = layer + + for i in allreduce_indices: + layer = LinearAllreduce(model.linears[i], group) + model.linears[i] = layer + + return model, base_model + + +@pytest.mark.parametrize("zero_stage", [0, 1]) +@pytest.mark.parametrize("tp_size", [2, 4]) +class TestSave(DistributedTest): + + world_size = 4 + reuse_dist_env = True + + def test_save_original_weight(self, tp_size: int, zero_stage: int): + skip_on_device() + hidden_dim = 64 + set_autotp_mode(training=True) + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-6 + } + }, + "tensor_parallel": { + "autotp_size": tp_size + }, + "zero_optimization": { + "stage": zero_stage, + } + } + if preferred_dtype() is torch.float16: + config_dict["fp16"] = {"enabled": True} + elif preferred_dtype() is torch.bfloat16: + config_dict["bf16"] = {"enabled": True} + dummy_init_engine(config_dict) + torch.manual_seed(42) + + model, base_model = prepare_tp_model(hidden_dim, + 8, [2, 5], [3, 6], + groups.get_tensor_model_parallel_group(), + return_global_copy=True) + model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) + + cur_params_numel = sum(p.numel() for p in model.parameters()) + base_params_numel = sum(p.numel() for p in base_model.parameters()) + assert cur_params_numel < base_params_numel + + tp_state_dict = model._consolidated_16bit_state_dict() + + def compare_state_dicts(state_dict1, state_dict2): + if state_dict1.keys() != state_dict2.keys(): + print("The state_dicts have different keys!") + return False + + for key in state_dict1: + if not torch.allclose(state_dict1[key], state_dict2[key], atol=1e-3): + assert state_dict1[key].device == "cpu" + print(f"Parameters for {key} are different!") + return False + + return True + + base_state_dict = base_model.state_dict() + if dist.get_rank() == 0: + # we should consider the case when zero3 is used in the future. + assert compare_state_dicts(base_state_dict, tp_state_dict), f"State_dict is not the same!" 
+ else: + assert tp_state_dict is None, "only rank0 should have the state_dict" + + def test_ckpt_save(self, tmpdir, tp_size: int, zero_stage: int): + skip_on_device() + hidden_dim = 64 + set_autotp_mode(training=True) + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3 + } + }, + "zero_optimization": { + "stage": zero_stage, + }, + "tensor_parallel": { + "autotp_size": tp_size + }, + "scheduler": { + "type": "WarmupLR", + "params": { + "warmup_min_lr": 0, + "warmup_max_lr": 0.001, + "warmup_num_steps": 1000 + } + } + } + + if preferred_dtype() is torch.float16: + config_dict["fp16"] = {"enabled": True} + elif preferred_dtype() is torch.bfloat16: + config_dict["bf16"] = {"enabled": True} + + dummy_init_engine(config_dict) + + trained_model, _ = prepare_tp_model(hidden_dim, 8, [2, 5], [3, 6], groups.get_tensor_model_parallel_group()) + loaded_model, _ = prepare_tp_model(hidden_dim, 8, [2, 5], [3, 6], groups.get_tensor_model_parallel_group()) + + trained_model, _, _, _ = deepspeed.initialize(model=trained_model, + model_parameters=trained_model.parameters(), + config=config_dict) + torch.manual_seed(42) + + data_loader = random_dataloader(model=trained_model, + total_samples=3, + hidden_dim=hidden_dim, + device=trained_model.device, + dtype=preferred_dtype()) + ckpt_path = os.path.join(tmpdir, 'tp_saved_checkpoint') + for i, batch in enumerate(data_loader): + batch[0].requires_grad = True + loss = trained_model(batch[0], batch[1]) + loss = loss + trained_model.backward(loss) + trained_model.step() + trained_model.save_checkpoint(ckpt_path) + + loaded_model, _, _, _ = deepspeed.initialize(model=loaded_model, + model_parameters=loaded_model.parameters(), + config=config_dict) + loaded_model.load_checkpoint(ckpt_path, load_optimizer_states=True, load_lr_scheduler_states=True) + compare_optimizer_states(trained_model, loaded_model, hidden_dim, fp16=(preferred_dtype() == torch.float16)) + compare_lr_scheduler_states(trained_model, loaded_model) + + +@pytest.mark.parametrize("zero_stage", [0, 1]) +@pytest.mark.parametrize("tp_size", [2, 4]) +class TestTpGradNorm(DistributedTest): + + world_size = 4 + reuse_dist_env = True + + def test(self, tp_size: int, zero_stage: int): + skip_on_device() + hidden_dim = 64 + set_autotp_mode(training=True) + config_dict = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-6 + } + }, + "tensor_parallel": { + "autotp_size": tp_size + }, + "zero_optimization": { + "stage": zero_stage, + } + } + if preferred_dtype() is torch.float16: + config_dict["fp16"] = {"enabled": True} + elif preferred_dtype() is torch.bfloat16: + if zero_stage == 0: + pytest.skip( + "This test produces data that overflows in bf16 and needs an overflow-skip mechanism in BF16_Optimizer, which is not yet implemented" + ) + config_dict["bf16"] = {"enabled": True} + + torch.manual_seed(42) + + dummy_init_engine(config=config_dict) + tp_model, base_model = prepare_tp_model(hidden_dim, + 8, [2, 5], [3, 6], + groups.get_tensor_model_parallel_group(), + return_global_copy=True) + + base_model, base_optimizer, _, _ = deepspeed.initialize(model=base_model, + model_parameters=base_model.parameters(), + config=config_dict) + data_loader = random_dataloader(model=base_model, + total_samples=20, + hidden_dim=hidden_dim, + device=base_model.device, + dtype=preferred_dtype()) + + for i, batch in enumerate(data_loader): + batch[0].requires_grad = True + loss = base_model(batch[0],
batch[1]) + loss = loss + base_model.backward(loss) + base_model.step() + + base_norm = base_optimizer._global_grad_norm + + base_model.destroy() + + tp_model, tp_optimizer, _, _ = deepspeed.initialize(model=tp_model, + model_parameters=tp_model.parameters(), + config=config_dict) + for i, batch in enumerate(data_loader): + batch[0].requires_grad = True + loss = tp_model(batch[0], batch[1]) + loss = loss + tp_model.backward(loss) + tp_model.step() + + tp_norm = tp_optimizer._global_grad_norm + + assert math.isclose(base_norm, tp_norm, abs_tol=1e-3) + tp_params_numel = sum(p.numel() for p in tp_model.parameters()) + base_params_numel = sum(p.numel() for p in base_model.parameters()) + assert tp_params_numel < base_params_numel
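
The tests above all enable tensor-parallel training through the same pattern: switch AutoTP into training mode, request a TP degree via the "tensor_parallel" section of the DeepSpeed config, and call deepspeed.initialize as usual. A minimal sketch of that pattern, assuming a 4-rank launch with the deepspeed launcher; the toy model and optimizer settings here are illustrative only, not part of this patch:

import torch
import deepspeed
from deepspeed.module_inject.layers import set_autotp_mode

# Enable AutoTP training mode, as the tests above do.
set_autotp_mode(training=True)

model = torch.nn.Linear(128, 128, bias=False)  # illustrative toy model
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-6}},
    "tensor_parallel": {"autotp_size": 2},  # with 4 ranks: TP groups of 2, data-parallel degree 2
    "zero_optimization": {"stage": 0},
    "bf16": {"enabled": True},
}

# After initialization, the TP/DP process groups queried through
# groups.get_tensor_model_parallel_*() / get_data_parallel_*() in the tests are available.
engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                               model_parameters=model.parameters(),
                                               config=ds_config)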
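
TestTpDataloaderCorrectness encodes a contract rather than a feature: every rank inside a TP group must see the same batch, and the engine asserts when that is violated. A small helper along the lines of the broadcasts in that test could enforce the contract when the dataloader is not already TP-aware; the name broadcast_batch_to_tp_group is mine, not a DeepSpeed API:

import deepspeed.comm as dist
from deepspeed.utils import groups

def broadcast_batch_to_tp_group(batch):
    # Make every tensor in the (input, label) pair identical across the current
    # TP group by broadcasting from the group's source rank, mirroring the
    # explicit dist.broadcast calls in TestTpDataloaderCorrectness.
    src = groups.get_tensor_model_parallel_src_rank()
    group = groups.get_tensor_model_parallel_group()
    for t in batch:
        dist.broadcast(t, src=src, group=group)
    return batch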
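
The forward/backward checks in TestTpLayerFwdBwd reduce to two pieces of sharding arithmetic: LinearLayer (column parallel) chunks the weight along the output dimension so each rank produces a slice of the output, while LinearAllreduce (row parallel) chunks the weight along the input dimension, feeds each rank a slice of the input, and sums the partial outputs (the all-reduce). A self-contained, single-process sketch of that arithmetic, with the TP ranks simulated by torch.chunk:

import torch

torch.manual_seed(42)
tp_size, hidden_dim = 4, 128
x = torch.randn(1, hidden_dim)
w = torch.randn(hidden_dim, hidden_dim)   # nn.Linear layout: [out_features, in_features]
dense_out = x @ w.t()                     # reference result of the unsharded linear

# Column parallel (LinearLayer): weight chunked along dim 0 (output features);
# each "rank" computes a slice of the output, concatenated along the last dim.
col_out = torch.cat([x @ shard.t() for shard in torch.chunk(w, tp_size, dim=0)], dim=-1)

# Row parallel (LinearAllreduce): weight chunked along dim 1 (input features);
# each "rank" sees a slice of the input, and the partial outputs are summed,
# which is what the all-reduce does across the real TP group.
row_out = sum(xs @ ws.t()
              for xs, ws in zip(torch.chunk(x, tp_size, dim=-1), torch.chunk(w, tp_size, dim=1)))

assert torch.allclose(col_out, dense_out, atol=1e-4)
assert torch.allclose(row_out, dense_out, atol=1e-4)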