diff --git a/MANIFEST.in b/MANIFEST.in index ab79573ef96c..8d84aee0faf4 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -2,8 +2,8 @@ include *.txt README.md include deepspeed/inference/v2/kernels/ragged_ops/libs/*.so include deepspeed/inference/v2/kernels/cutlass_ops/libs/*.so recursive-include requirements *.txt -recursive-include deepspeed *.cpp *.h *.cu *.hip *.tr *.cuh *.cc *.json -recursive-include csrc *.cpp *.h *.cu *.tr *.cuh *.cc +recursive-include deepspeed *.cpp *.h *.hpp *.cu *.hip *.tr *.cuh *.cc *.json +recursive-include csrc *.cpp *.h *.hpp *.cu *.tr *.cuh *.cc recursive-include op_builder *.py recursive-include benchmarks *.py recursive-include accelerator *.py diff --git a/deepspeed/linear/optimized_linear.py b/deepspeed/linear/optimized_linear.py index 138bd493ffc7..e982785a8122 100644 --- a/deepspeed/linear/optimized_linear.py +++ b/deepspeed/linear/optimized_linear.py @@ -85,7 +85,7 @@ def __init__(self, self.bias = bias self.lora_config = lora_config self.quantization_config = quantization_config - device = get_accelerator().current_device() if device is None else device + device = get_accelerator().current_device_name() if device is None else device assert self.lora_config is not None, "DSOptimizedLinear requires a LoRA config" self.zero_shards = self.lora_config.base_weight_sharding diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py index 88f7086518e8..3429ceb0a4ee 100644 --- a/deepspeed/module_inject/auto_tp.py +++ b/deepspeed/module_inject/auto_tp.py @@ -13,7 +13,7 @@ from deepspeed import comm as dist from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce from deepspeed.accelerator import get_accelerator -from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw +from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw, shard_chunk_mlp from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list @@ -133,7 +133,8 @@ def is_load_module(module): load_layers = [nn.Linear, nn.Embedding, nn.LayerNorm] load_layer_names = [ "LPLayerNorm", "SharedEmbedding", "OPTLearnedPositionalEmbedding", "LlamaRMSNorm", "FalconLinear", - "MistralRMSNorm", "T5LayerNorm", "MixtralRMSNorm" + "MistralRMSNorm", "T5LayerNorm", "MixtralRMSNorm", "Phi3RotaryEmbedding", "Phi3SuScaledRotaryEmbedding", + "Phi3RMSNorm" ] return module.__class__ in load_layers or module._get_name() in load_layer_names @@ -306,6 +307,8 @@ def tp_parser(model): # Mixtral-7x8b used w2*act(w1*w3) linear. need to replace w2 to linearallreduce. elif 'w2' in layer and 'Mixtral' in str(type(module)): gem_list = gem_list + [layer] + elif 'self_attn.dense' in layer and 'Phi' in str(type(module)): + gem_list = gem_list + [layer] layer_list = [] if gem_list != []: @@ -328,6 +331,10 @@ def _replace(self, child, name, conv_linear_layer): # For mixtral-7x8b, need to skip MoE gate linear replace. if name == "block_sparse_moe.gate": return child + # for phi3. + if 'gate_up_proj' in name: + weight, bias = shard_chunk_mlp(child.weight.data, child.bias, dist.get_rank(), dist.get_world_size()) + return LinearLayer(weight=weight, bias=bias) if name in self.all_reduce_linears: # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size] # else [weight_shape[0], weight_shape[1] // mp_size] diff --git a/deepspeed/module_inject/fusedqkv_utils.py b/deepspeed/module_inject/fusedqkv_utils.py index cf087c16da8a..33d36fbfae54 100644 --- a/deepspeed/module_inject/fusedqkv_utils.py +++ b/deepspeed/module_inject/fusedqkv_utils.py @@ -4,7 +4,7 @@ # DeepSpeed Team import torch from deepspeed.utils.logging import warning_once -from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list, get_num_kv_heads, get_n_embd +from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list, get_num_kv_heads, get_n_embd, get_num_attention_heads def split_by_qkvlist_and_refuse(qkv_list, split_size, split_dim=0, cat_dim=0): @@ -42,6 +42,7 @@ def prepare_tp_fused_qkvw(module, src, mp_size, gpu_index): "FalconDecoderLayer": 'bloomtype', "GPTBigCodeBlock": 'bigcodetype', "DecoderLayer": 'glmtype', + "Phi3DecoderLayer": "phi3type" } def _codegen_type_transpose(input, mp_size, codegen_mp_num=4): @@ -93,6 +94,20 @@ def _bigcode_type_transpose(input, mp_size): split_q = q.split(get_shard_size_list(shape[0], mp_size), dim=0) return torch.cat((split_q[gpu_index], kv), dim=0) + def _phi3_type_transpose(input, mp_size): + num_kv_heads = get_num_kv_heads() + num_heads = get_num_attention_heads() + hidden_size = input.shape[1] + head_dim = hidden_size // num_heads + q_pos = input.shape[0] - 2 * num_kv_heads * head_dim + q = input[:q_pos] + k = input[q_pos:q_pos + num_kv_heads * head_dim] + v = input[q_pos + num_kv_heads * head_dim:] + split_q = q.split(get_shard_size_list(q.shape[0], mp_size), dim=0) + split_k = k.split(get_shard_size_list(k.shape[0], mp_size), dim=0) + split_v = v.split(get_shard_size_list(v.shape[0], mp_size), dim=0) + return torch.cat((split_q[gpu_index], split_k[gpu_index], split_v[gpu_index]), dim=0) + def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None, module=None): # suppose num_heads=n, q(n)_w means the n-th q head linear weight, the weight format are as following @@ -110,6 +125,8 @@ def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None, module=None): return _qwen_type_transpose(src, mp_size, module) elif fused_qkv_type == 'bigcodetype': return _bigcode_type_transpose(src, mp_size) + elif fused_qkv_type == 'phi3type': + return _phi3_type_transpose(src, mp_size) raise ValueError("unknown fused_qkv_type") @@ -123,3 +140,24 @@ def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None, module=None): warning_once(f"Unrecognized fusedkqv weight type, default to using bloom type," f"please check in prepare_tp_fused_qkvw() to avoid potential calculation errors") return _bloom_type_transpose(src, mp_size) + + +# For phi3 with chunk mlp, adjust the weight order. +def shard_chunk_mlp( + weight, + bias, + rank, + world_size, +): + weight_gate, weight_states = weight.chunk(2, dim=0) + total_size = weight_gate.shape[0] + split_weight_gate = weight_gate.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0) + split_weight_states = weight_states.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0) + shard_weight = torch.cat((split_weight_gate[rank], split_weight_states[rank]), dim=0) + if bias is not None: + bias_gate, bias_states = bias.chunk(2, dim=0) + split_bias_gate = bias_gate.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0) + split_bias_states = bias_states.split(get_shard_size_list(total_size, world_size, "mlp"), dim=0) + return shard_weight, torch.cat((split_bias_gate[rank], split_bias_states[rank]), dim=0) + + return shard_weight, None diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index e1703562d180..3029a79698dc 100644 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -16,7 +16,7 @@ from .auto_tp import AutoTP, ReplaceWithTensorSlicing, Loading from deepspeed import comm as dist -from deepspeed.module_inject.tp_shard import set_num_kv_heads, set_n_embd +from deepspeed.module_inject.tp_shard import set_num_kv_heads, set_n_embd, set_num_attention_heads from .load_checkpoint import load_model_with_checkpoint import time @@ -290,6 +290,10 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None): # 4.2 set n_embd set_n_embd(n_embd) + # 4.3 set attention_heads + if hasattr(model_config, 'num_attention_heads'): + set_num_attention_heads(getattr(model_config, 'num_attention_heads')) + # 5. Set linear policies _autotp.update_linear_policies() diff --git a/deepspeed/module_inject/tp_shard.py b/deepspeed/module_inject/tp_shard.py index 79c19b5f1272..6758c7a657f6 100644 --- a/deepspeed/module_inject/tp_shard.py +++ b/deepspeed/module_inject/tp_shard.py @@ -12,6 +12,11 @@ def set_num_kv_heads(num): num_kv_heads = num +def set_num_attention_heads(num): + global num_attention_heads + num_attention_heads = num + + def set_n_embd(num): global n_embd n_embd = num @@ -22,6 +27,11 @@ def get_num_kv_heads(): return num_kv_heads +def get_num_attention_heads(): + global num_attention_heads + return num_attention_heads + + def get_shard_size(total_size, mp_size, name=None, rank=None): global num_kv_heads last_linear = ["lm_head", "embed_out"] diff --git a/deepspeed/monitor/comet.py b/deepspeed/monitor/comet.py new file mode 100644 index 000000000000..d8bc4017800f --- /dev/null +++ b/deepspeed/monitor/comet.py @@ -0,0 +1,92 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import TYPE_CHECKING, Any, Tuple, List, Dict, Optional + +from .utils import check_comet_availability +from .monitor import Monitor + +import deepspeed.comm as dist + +if TYPE_CHECKING: + import comet_ml + from .config import CometConfig + +Name = str +Value = Any +GlobalSamples = int +Event = Tuple[Name, Value, GlobalSamples] + + +class CometMonitor(Monitor): + + def __init__(self, comet_config: "CometConfig"): + super().__init__(comet_config) + check_comet_availability() + import comet_ml + + self.enabled = comet_config.enabled + self._samples_log_interval = comet_config.samples_log_interval + self._experiment: Optional["comet_ml.ExperimentBase"] = None + + if self.enabled and dist.get_rank() == 0: + self._experiment = comet_ml.start( + api_key=comet_config.api_key, + project=comet_config.project, + workspace=comet_config.workspace, + experiment_key=comet_config.experiment_key, + mode=comet_config.mode, + online=comet_config.online, + ) + + if comet_config.experiment_name is not None: + self._experiment.set_name(comet_config.experiment_name) + + self._events_log_scheduler = EventsLogScheduler(comet_config.samples_log_interval) + + @property + def experiment(self) -> Optional["comet_ml.ExperimentBase"]: + return self._experiment + + @property + def samples_log_interval(self) -> int: + return self._samples_log_interval + + def write_events(self, event_list: List[Event]) -> None: + if not self.enabled or dist.get_rank() != 0: + return None + + for event in event_list: + name = event[0] + value = event[1] + engine_global_samples = event[2] + + if self._events_log_scheduler.needs_logging(name, engine_global_samples): + self._experiment.__internal_api__log_metric__( + name=name, + value=value, + step=engine_global_samples, + ) + + +class EventsLogScheduler: + + def __init__(self, samples_log_interval: int): + self._samples_log_interval = samples_log_interval + self._last_logged_events_samples: Dict[str, int] = {} + + def needs_logging(self, name: str, current_sample: int) -> bool: + if name not in self._last_logged_events_samples: + self._last_logged_events_samples[name] = current_sample + return True + + last_logged_sample = self._last_logged_events_samples[name] + samples_delta = current_sample - last_logged_sample + + if samples_delta >= self._samples_log_interval: + self._last_logged_events_samples[name] = current_sample + return True + + return False diff --git a/deepspeed/monitor/config.py b/deepspeed/monitor/config.py index 5a8ca6ecf5cd..d422d3b1b9bb 100644 --- a/deepspeed/monitor/config.py +++ b/deepspeed/monitor/config.py @@ -3,12 +3,14 @@ # DeepSpeed Team +from typing import Optional + from deepspeed.pydantic_v1 import root_validator from deepspeed.runtime.config_utils import DeepSpeedConfigModel def get_monitor_config(param_dict): - monitor_dict = {key: param_dict.get(key, {}) for key in ("tensorboard", "wandb", "csv_monitor")} + monitor_dict = {key: param_dict.get(key, {}) for key in ("tensorboard", "wandb", "csv_monitor", "comet")} return DeepSpeedMonitorConfig(**monitor_dict) @@ -60,12 +62,75 @@ class CSVConfig(DeepSpeedConfigModel): """ Name for the current job. This will become a new directory inside `output_path`. """ +class CometConfig(DeepSpeedConfigModel): + """ + Sets parameters for Comet monitor. For logging data Comet uses + experiment object. + https://www.comet.com/docs/v2/api-and-sdk/python-sdk/reference/Experiment/ + """ + + enabled: bool = False + """ Whether logging to Comet is enabled. Requires `comet_ml` package is installed. """ + + samples_log_interval: int = 100 + """ Metrics will be submitted to Comet after processing every `samples_log_intervas` samples""" + + project: Optional[str] = None + """ + Comet project name. Can be set through .comet.config file or environment variable COMET_PROJECT_NAME + https://www.comet.com/docs/v2/guides/experiment-management/configure-sdk/#explore-comet-configuration-options + """ + + workspace: Optional[str] = None + """ + Comet workspace name. Can be set through .comet.config file or environment variable COMET_WORKSPACE + https://www.comet.com/docs/v2/guides/experiment-management/configure-sdk/#explore-comet-configuration-options + """ + + api_key: Optional[str] = None + """ + Comet API key. Can be set through .comet.config file or environment variable COMET_API_KEY + https://www.comet.com/docs/v2/guides/experiment-management/configure-sdk/#explore-comet-configuration-options + """ + + experiment_name: Optional[str] = None + """ + The name for comet experiment to be used for logging. + Can be set through .comet.config file or environment variable COMET_EXPERIMENT_NAME + https://www.comet.com/docs/v2/guides/experiment-management/configure-sdk/#explore-comet-configuration-options + """ + + experiment_key: Optional[str] = None + """ + The key for comet experiment to be used for logging. Must be an alphanumeric string whose length is between 32 and 50 characters. + Can be set through .comet.config or environment variable COMET_EXPERIMENT_KEY + https://www.comet.com/docs/v2/guides/experiment-management/configure-sdk/#explore-comet-configuration-options + """ + + online: Optional[bool] = None + """ + If True, the data will be logged to Comet server, otherwise it will be stored locally in offline experiment + Defaults to True. + """ + + mode: Optional[str] = None + """ + Control how the Comet experiment is started, 3 options are possible.: + - "get": Continue logging to an existing experiment identified by the `experiment_key` value. + - "create": Always creates of a new experiment, useful for HPO sweeps. + - "get_or_create" (default): Starts a fresh experiment if required, or persists logging to an existing one. + """ + + class DeepSpeedMonitorConfig(DeepSpeedConfigModel): """Sets parameters for various monitoring methods.""" tensorboard: TensorBoardConfig = {} """ TensorBoard monitor, requires `tensorboard` package is installed. """ + comet: CometConfig = {} + """ Comet monitor, requires `comet_ml` package is installed """ + wandb: WandbConfig = {} """ WandB monitor, requires `wandb` package is installed. """ @@ -75,5 +140,5 @@ class DeepSpeedMonitorConfig(DeepSpeedConfigModel): @root_validator def check_enabled(cls, values): values["enabled"] = values.get("tensorboard").enabled or values.get("wandb").enabled or values.get( - "csv_monitor").enabled + "csv_monitor").enabled or values.get("comet") return values diff --git a/deepspeed/monitor/monitor.py b/deepspeed/monitor/monitor.py index 5a32b8bbcadd..e7e26dc483d9 100644 --- a/deepspeed/monitor/monitor.py +++ b/deepspeed/monitor/monitor.py @@ -24,6 +24,7 @@ def write_events(self, event_list): from .wandb import WandbMonitor from .tensorboard import TensorBoardMonitor from .csv_monitor import csvMonitor +from .comet import CometMonitor class MonitorMaster(Monitor): @@ -33,6 +34,7 @@ def __init__(self, monitor_config): self.tb_monitor = None self.wandb_monitor = None self.csv_monitor = None + self.comet_monitor = None self.enabled = monitor_config.enabled if dist.get_rank() == 0: @@ -42,6 +44,8 @@ def __init__(self, monitor_config): self.wandb_monitor = WandbMonitor(monitor_config.wandb) if monitor_config.csv_monitor.enabled: self.csv_monitor = csvMonitor(monitor_config.csv_monitor) + if monitor_config.comet.enabled: + self.comet_monitor = CometMonitor(monitor_config.comet) def write_events(self, event_list): if dist.get_rank() == 0: @@ -51,3 +55,5 @@ def write_events(self, event_list): self.wandb_monitor.write_events(event_list) if self.csv_monitor is not None: self.csv_monitor.write_events(event_list) + if self.comet_monitor is not None: + self.comet_monitor.write_events(event_list) diff --git a/deepspeed/monitor/utils.py b/deepspeed/monitor/utils.py index 265fc9811553..f5530e8532e1 100644 --- a/deepspeed/monitor/utils.py +++ b/deepspeed/monitor/utils.py @@ -3,6 +3,8 @@ # DeepSpeed Team +from packaging import version as pkg_version + def check_tb_availability(): try: @@ -22,3 +24,14 @@ def check_wandb_availability(): 'If you want to use wandb logging, please `pip install wandb` and follow the instructions at https://docs.wandb.ai/quickstart' ) raise + + +def check_comet_availability(): + try: + import comet_ml + comet_version = pkg_version.parse(comet_ml.__version__) + if comet_version < pkg_version.Version("3.41.0"): + raise ImportError("`comet_ml` must have at least version 3.41.0") + except ImportError: + print('If you want to use comet logging, please `pip install "comet_ml>=3.41.0"`') + raise diff --git a/deepspeed/runtime/compiler.py b/deepspeed/runtime/compiler.py index b5e4e33425d0..66fe29fbbea2 100644 --- a/deepspeed/runtime/compiler.py +++ b/deepspeed/runtime/compiler.py @@ -83,84 +83,85 @@ def validate_enabled(cls, field_value, values): return field_value -class CompiledModuleWrapper(torch.nn.Module): - - def __init__(self, module, compile_config: Union[CompileConfig, None] = None): - super().__init__() - - assert is_compile_supported(), "torch.compile is not supported on this version of PyTorch." - - modules = self.__dict__.get('_modules') - modules['wrapped'] = module - self.__dict__['wrapped'] = module - self._is_compiled = False - self._backend = get_backend_fn(compile_config.backend) - self._compile_kwargs = compile_config.kwargs - self._compiler_fn = None - - def __getattr__(self, name): - return getattr(self.__dict__['wrapped'], name) - - def set_backend(self, backend: Union[str, Callable]): - """Set the backend for torch.compile. - - Args: - backend (Union[str, Callable]): backend name or a function that takes a torch.nn.Module and returns a compiled module. - You can directly pass a function that works as a backend. - See also `backend` field in `CompileConfig` for more details. - """ - self._backend = get_backend_fn(backend) - - def set_torch_compile_kwargs(self, kwargs: Dict[str, Union[str, Any]]) -> None: - """Set kwargs for torch.compile. Kwargs that are set in DeepSpeed config will be overwritten. - You can also pass a backend name with "backend" key to change the backend. - - Args: - kwargs (Dict[str, Union[str, Any]]): kwargs passed to torch.compile. - """ - - if "backend" in kwargs: - raise ValueError("backend cannot be set as compile kwargs. Use set_backend instead.") - self._compile_kwargs.update(kwargs) - - def set_compiler_fn(self, compiler_fn: Callable) -> None: - """Set a function to be used for compiling the module. - This function should take a torch.nn.Module as input and return a compiled module. - Note that other compile options are ignored when a compiler_fn is set. - - Example: - ```python - def my_compiler_fn(module: torch.nn.Module): - ... - return torch.compile(module, ...) - - engine.set_compiler_fn(my_compiler_fn) - ``` - """ - self._compiler_fn = compiler_fn - - def forward(self, *args, **kwargs) -> Any: - if not self.is_compiled: - if self._compiler_fn is None: - self.__dict__['wrapped'] = torch.compile(self.wrapped, backend=self._backend, **self._compile_kwargs) - else: - self.__dict__['wrapped'] = self._compiler_fn(self.wrapped) - self._is_compiled = True - - return self.__dict__['wrapped'](*args, **kwargs) - - @property - def is_compiled(self) -> bool: - return self._is_compiled - - @property - def backend(self) -> Union[str, Callable]: - return self._backend - - @property - def torch_compile_kwargs(self) -> Dict[str, Any]: - return self._compile_kwargs - - @property - def compiler_fn(self) -> Union[Callable, None]: - return self._compiler_fn +def CompiledModuleWrapper(mod, compile_config: Union[CompileConfig, None] = None): + + class wrapper(mod.__class__): + + def __init__(self, module, compile_config: Union[CompileConfig, None] = None): + self.__dict__ = module.__dict__.copy() + + assert is_compile_supported(), "torch.compile is not supported on this version of PyTorch." + + self.__dict__['wrapped'] = module + self._is_compiled = False + self._backend = get_backend_fn(compile_config.backend) + self._compile_kwargs = compile_config.kwargs + self._compiler_fn = None + + def set_backend(self, backend: Union[str, Callable]): + """Set the backend for torch.compile. + + Args: + backend (Union[str, Callable]): backend name or a function that takes a torch.nn.Module and returns a compiled module. + You can directly pass a function that works as a backend. + See also `backend` field in `CompileConfig` for more details. + """ + self._backend = get_backend_fn(backend) + + def set_torch_compile_kwargs(self, kwargs: Dict[str, Union[str, Any]]) -> None: + """Set kwargs for torch.compile. Kwargs that are set in DeepSpeed config will be overwritten. + You can also pass a backend name with "backend" key to change the backend. + + Args: + kwargs (Dict[str, Union[str, Any]]): kwargs passed to torch.compile. + """ + + if "backend" in kwargs: + raise ValueError("backend cannot be set as compile kwargs. Use set_backend instead.") + self._compile_kwargs.update(kwargs) + + def set_compiler_fn(self, compiler_fn: Callable) -> None: + """Set a function to be used for compiling the module. + This function should take a torch.nn.Module as input and return a compiled module. + Note that other compile options are ignored when a compiler_fn is set. + + Example: + ```python + def my_compiler_fn(module: torch.nn.Module): + ... + return torch.compile(module, ...) + + engine.set_compiler_fn(my_compiler_fn) + ``` + """ + self._compiler_fn = compiler_fn + + def forward(self, *args, **kwargs) -> Any: + if not self.is_compiled: + if self._compiler_fn is None: + self.__dict__['wrapped'] = torch.compile(self.wrapped, + backend=self._backend, + **self._compile_kwargs) + else: + self.__dict__['wrapped'] = self._compiler_fn(self.wrapped) + self._is_compiled = True + + return self.__dict__['wrapped'](*args, **kwargs) + + @property + def is_compiled(self) -> bool: + return self._is_compiled + + @property + def backend(self) -> Union[str, Callable]: + return self._backend + + @property + def torch_compile_kwargs(self) -> Dict[str, Any]: + return self._compile_kwargs + + @property + def compiler_fn(self) -> Union[Callable, None]: + return self._compiler_fn + + return wrapper(mod, compile_config) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 9a2b943b0992..13f335cae6d5 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -469,13 +469,6 @@ def __getattr__(self, name): return getattr(self, name) elif name in dir(_module): return getattr(_module, name) - elif isinstance(_module, CompiledModuleWrapper): - try: - return getattr(_module, name) - except AttributeError: - raise AttributeError( - f"None of {type(self).__name__}, CompiledModuleWrapper, or the wrapped model has the attribute '{name}'" - ) else: raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") @@ -1270,7 +1263,7 @@ def _configure_optimizer(self, client_optimizer, model_parameters): else: self.optimizer = basic_optimizer - log_dist("DeepSpeed Final Optimizer = {}".format(self.optimizer_name()), ranks=[0]) + log_dist("DeepSpeed Final Optimizer = {}".format(self.optimizer.__class__.__name__), ranks=[0]) self.compression_scheduler = self._configure_compression_scheduler() self.quantizer = self._configure_quantization() diff --git a/deepspeed/runtime/fp16/fused_optimizer.py b/deepspeed/runtime/fp16/fused_optimizer.py index bf1693307ea7..49093bb73c8f 100755 --- a/deepspeed/runtime/fp16/fused_optimizer.py +++ b/deepspeed/runtime/fp16/fused_optimizer.py @@ -241,7 +241,7 @@ def _get_norm_mask_idx(self, group): group_mask_idx_list.append([grad_flat_st_idx, grad_flat_en_idx]) grad_flat_st_idx = grad_flat_en_idx - return torch.tensor(group_mask_idx_list, device=get_accelerator().current_device()) + return torch.tensor(group_mask_idx_list, device=get_accelerator().current_device_name()) def step(self, closure=None): """ diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index 1dda7f1aad32..be8fe1a368c6 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -67,9 +67,7 @@ class PipelineEngine(DeepSpeedEngine): def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs): super().__init__(*super_args, **super_kwargs) - assert isinstance(self.module, PipelineModule) \ - or (hasattr(self.module, 'wrapped') and isinstance(self.module.wrapped, PipelineModule)), \ - "model must base PipelineModule" + assert isinstance(self.module, PipelineModule), "model must base PipelineModule" assert self.zero_optimization_stage( ) < ZeroStageEnum.gradients, "ZeRO-2 and ZeRO-3 are incompatible with pipeline parallelism" diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index 7744b2ee8b98..2c01c3475a70 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -171,7 +171,7 @@ def get_norm_with_moe_layers_fast(all_groups_norm, group): # This implementation standardizes the grad_norm across ranks. A more precise implementation can be found in 'get_norm_with_moe_layers'. # Need to allreduce (avg) the norms across different ranks because moe params will not be synced during allreduce scaled_norm = all_groups_norm * 1.0 / float(dist.get_world_size(group=group)) - scaled_norm_tensor = torch.tensor(scaled_norm, device=get_accelerator().current_device(), dtype=torch.float) + scaled_norm_tensor = torch.tensor(scaled_norm, device=get_accelerator().current_device_name(), dtype=torch.float) dist.all_reduce(scaled_norm_tensor, group=group) all_groups_norm = scaled_norm_tensor.item() #print(f"old = {all_groups_norm_old} and new = {all_groups_norm} at rank: {deepspeed.comm.get_rank()}") @@ -424,9 +424,11 @@ def get_flattened_grad_norm(parameters, norm_type=2, mpu=None, grad_norm_mask=No # # mask_tensor_ = torch.zeros_like(p, device=p.device, dtype=bool) # # for mask_idx in grad_norm_mask[idx]: # # mask_tensor_[mask_idx[0]:mask_idx[1]] = True - cum_sum_pairs = torch.tensor([1, -1], device=get_accelerator().current_device(), + cum_sum_pairs = torch.tensor([1, -1], device=get_accelerator().current_device_name(), dtype=p.dtype).repeat(grad_norm_mask[idx].shape[0], 1) - mask_tensor = torch.zeros(p.shape[0] + 1, device=get_accelerator().current_device(), dtype=p.dtype) + mask_tensor = torch.zeros(p.shape[0] + 1, + device=get_accelerator().current_device_name(), + dtype=p.dtype) mask_tensor = mask_tensor.scatter_(0, grad_norm_mask[idx].view(-1), cum_sum_pairs.view(-1)).cumsum(0).bool()[:-1] diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index c6ff216edfcb..13ca29c9fceb 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -1409,7 +1409,7 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params): norm_is_nan = total_norm.isnan() inf_or_nan = norm_is_nan.logical_or(norm_is_inf) - err = torch.tensor(-1.0, device=self.device, dtype=torch.float) + err = torch.tensor(-1.0, device=inf_or_nan.device, dtype=torch.float) total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm return total_norm diff --git a/docs/_data/navigation.yml b/docs/_data/navigation.yml index 217d56c14812..3bd3e451ab49 100755 --- a/docs/_data/navigation.yml +++ b/docs/_data/navigation.yml @@ -41,7 +41,7 @@ lnav: - title: 'Flops Profiler' url: /docs/config-json/#flops-profiler - title: 'Monitoring' - url: /docs/config-json/#monitoring-module-tensorboard-wandb-csv + url: /docs/config-json/#monitoring-module - title: 'Communication Logging' url: /docs/config-json/#communication-logging - title: 'Model Compression' diff --git a/docs/_pages/config-json.md b/docs/_pages/config-json.md index abe314cbb1a6..adb2f1679ea0 100755 --- a/docs/_pages/config-json.md +++ b/docs/_pages/config-json.md @@ -1139,15 +1139,16 @@ DeepSpeed Data Efficiency Library includes two techniques: curriculum learning a | ---------------------------------------------------------------------------------------------------------------------------- | ------- | | List of which step to change difficulty level. One of the `schedule_config` when the `fixed_discrete` schedule_type is used. | N/A | -### Monitoring Module (TensorBoard, WandB, CSV) +### Monitoring Module **Note:** Deepspeed logs to TensorBoard through PyTorch. Logging to TensorBoard requires that the `tensorboard` package is installed (read more in the [PyTorch documentation](https://pytorch.org/docs/1.8.0/tensorboard.html)). {: .notice--warning} **Note:** Logging to WandB requires that the `wandb` package is installed (read more in the [WandB documentation](https://docs.wandb.ai/quickstart)). {: .notice--warning} +**Note:** Logging to Comet requires that the `comet_ml` package is installed (read more in the [Comet documentation](https://www.comet.com/docs/v2/guides/quickstart/#1-install-and-configure-the-comet-ml-sdk)). +{: .notice--warning} - -Deepspeed's Monitor module can log training details into a [Tensorboard](https://www.tensorflow.org/tensorboard)-compatible file, to [WandB](https://wandb.ai/site), or to simple CSV files. Below is an overview of what DeepSpeed will log automatically. +Deepspeed's Monitor module can log training details into a [Tensorboard](https://www.tensorflow.org/tensorboard)-compatible file, to [WandB](https://wandb.ai/site), to [Comet](https://www.comet.com/site/?utm_source=deepseed&utm_medium=docs&utm_content=docs) or to simple CSV files. Below is an overview of what DeepSpeed will log automatically. | Field | Description |Conditions | | ------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ----- | @@ -1201,6 +1202,36 @@ Example of **wandb** configuration: } ``` +**comet**: [dictionary] + +| Fields | Value | Default | +|--- |--- |--- | +| enabled | Whether logging to [Comet](https://www.comet.com/site/) is enabled. | `false` | +| workspace | Comet workspace name. | `None` | +| project | Comet project name. | `None` | +| samples_log_interval | Metrics will be submitted to Comet after processing every `samples_log_intervas` samples. | `100` | +| experiment_name | The name for comet experiment to be used for logging. | `None` | +| api_key | Comet API key. It's not recommended to save the Comet API Key in code. | `None` | +| experiment_key | The key for comet experiment to be used for logging. Must be an alphanumeric string whose length is between 32 and 50 characters. | `None` | +| online | If True, the data will be logged to Comet server, otherwise it will be stored locally in offline experiment. Default is `True`. | `None` | +| mode | Control how the Comet experiment is started. "get": Continue logging to an existing experiment identified by the `experiment_key` value. "create": Always creates of a new experiment, useful for HPO sweeps. "get_or_create" (default): Starts a fresh experiment if required, or persists logging to an existing one. | `None` | + + +Example of **comet** configuration: + +```json +"comet": { + "enabled": true, + "workspace": "my_workspace", + "project": "my_project", + "samples_log_interval": 50, + "experiment_name": "llama-fine-tuning", + "experiment_key": "0c4a1c4a90664f2a8084e600b19a9d7", + "online": false, + "mode": "get", +} +``` + **csv_monitor**: [dictionary] | Fields | Value |Default | diff --git a/docs/_tutorials/monitor.md b/docs/_tutorials/monitor.md index a9c111f8eeec..572e3f4558a7 100644 --- a/docs/_tutorials/monitor.md +++ b/docs/_tutorials/monitor.md @@ -11,7 +11,7 @@ In this tutorial, we introduce the DeepSpeed Monitor and provide examples of its ## Overview -Monitoring model and system metrics during training is vital to ensure hardware resources are fully utilized. The DeepSpeed Monitor enables live logging of metrics through one or more monitoring backends such as PyTorch's [TensorBoard](https://pytorch.org/docs/1.8.0/tensorboard.html), [WandB](https://docs.wandb.ai/quickstart), and simple CSV files. +Monitoring model and system metrics during training is vital to ensure hardware resources are fully utilized. The DeepSpeed Monitor enables live logging of metrics through one or more monitoring backends such as PyTorch's [TensorBoard](https://pytorch.org/docs/1.8.0/tensorboard.html), [WandB](https://docs.wandb.ai/quickstart), [Comet](https://www.comet.com/site/?utm_source=deepseed&utm_medium=docs&utm_content=tutorial) and simple CSV files. Below is a live monitoring view for TensorBoard: @@ -21,16 +21,20 @@ Below is a live monitoring view for WandB: ![WandB Example Output](/assets/images/wandb_monitor.PNG){: .align-center} +Below is a live monitoring view for Comet: + +![CometML Example Output](/assets/images/comet_monitor.png){: .align-center} + ## Usage -The DeepSpeed Monitor is configured within the deepspeed [configuration file](/docs/config-json/#monitoring-module-tensorboard-wandb-csv). DeepSpeed will automatically monitor key training metrics, including those tracked with the `wall_clock_breakdown` configuration option. In addition, users can log their own custom events and metrics. +The DeepSpeed Monitor is configured within the deepspeed [configuration file](/docs/config-json/#monitoring-module). DeepSpeed will automatically monitor key training metrics, including those tracked with the `wall_clock_breakdown` configuration option. In addition, users can log their own custom events and metrics. - [Automatic Monitoring](#automatic-monitoring) - [Custom Monitoring](#custom-monitoring) ### Automatic Monitoring -When using DeepSpeed for model training, the Monitor can be configured in the DeepSpeed [configuration file](/docs/config-json/#monitoring-module-tensorboard-wandb-csv). No explicit API calls are needed to use the Monitor. The Monitor can be enabled by adding the following field to DeepSpeed's configuration json file. Refer to [Monitoring](/docs/config-json/#monitoring-module-tensorboard-wandb-csv) for details. +When using DeepSpeed for model training, the Monitor can be configured in the DeepSpeed [configuration file](/docs/config-json/#monitoring-module). No explicit API calls are needed to use the Monitor. The Monitor can be enabled by adding the following field to DeepSpeed's configuration json file. Refer to [Monitoring](/docs/config-json/#monitoring-module) for details. ```json { @@ -45,6 +49,11 @@ When using DeepSpeed for model training, the Monitor can be configured in the De "group": "my_group", "project": "my_project" } + "comet": { + "enabled": true, + "project": "my_project", + "experiment_name": "my_experiment" + } "csv_monitor": { "enabled": true, "output_path": "output/ds_logs/", diff --git a/docs/assets/images/comet_monitor.png b/docs/assets/images/comet_monitor.png new file mode 100644 index 000000000000..83564cd5f1eb Binary files /dev/null and b/docs/assets/images/comet_monitor.png differ diff --git a/docs/code-docs/source/monitor.rst b/docs/code-docs/source/monitor.rst index d286af23f09e..694c72b9b870 100644 --- a/docs/code-docs/source/monitor.rst +++ b/docs/code-docs/source/monitor.rst @@ -29,6 +29,11 @@ WandB .. _WandbConfig: .. autopydantic_model:: deepspeed.monitor.config.WandbConfig +Comet +----- +.. _CometConfig: +.. autopydantic_model:: deepspeed.monitor.config.CometConfig + CSV Monitor ----------- .. _CSVConfig: diff --git a/op_builder/hpu/fused_adam.py b/op_builder/hpu/fused_adam.py index d77228317ddb..5acb121668e3 100644 --- a/op_builder/hpu/fused_adam.py +++ b/op_builder/hpu/fused_adam.py @@ -4,10 +4,88 @@ # DeepSpeed Team -from .builder import CPUOpBuilder +try: + # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed + # if successful this also means we're doing a local install and not JIT compile path + from op_builder import __deepspeed__ # noqa: F401 # type: ignore + from op_builder.builder import OpBuilder +except ImportError: + from deepspeed.ops.op_builder.builder import OpBuilder +try: + import torch + import math +except ImportError as e: + pass -class FusedAdamBuilder(CPUOpBuilder): + +class HPUFusedAdam: + htcore = None + is_lazy_mode = None + + @staticmethod + def multi_tensor_adam(chunk_size, noop_flag_buffer, tensor_lists, lr, beta1, beta2, epsilon, step, adam_w_mode, + bias_correction, weight_decay, *args): + + if HPUFusedAdam.htcore is None: + from habana_frameworks.torch import core as htcore + from habana_frameworks.torch.utils.internal import is_lazy + HPUFusedAdam.htcore = htcore + HPUFusedAdam.is_lazy_mode = is_lazy() + + htcore = HPUFusedAdam.htcore + + htcore.step_closure._mark_step_if_lazy() + step_size = lr + if bias_correction: + bias_correction1 = 1.0 - pow(beta1, step) + bias_correction2 = 1.0 - pow(beta2, step) + step_size = step_size * math.sqrt(bias_correction2) / bias_correction1 + + neg_step = -step_size + neg_step_t = (torch.tensor([neg_step], dtype=torch.float, + requires_grad=False).to(tensor_lists[1][0].dtype).to(tensor_lists[1][0].device, + non_blocking=True)) + + weight_decay = weight_decay if adam_w_mode else 0 + + # since lr is fed into the kernel as tensor, perform the scalar multiplication of wd here + # NOTE: TODO if lr is updated every step, then we need to convert it as tensor and + # perform weight decay unconditonally. + modified_wd = 1.0 - weight_decay * lr + + if HPUFusedAdam.is_lazy_mode: + torch.ops.hpu.optimizer_adamw( + tensor_lists[0], + tensor_lists[1], + tensor_lists[2], + tensor_lists[3], + neg_step_t, + beta1, + beta2, + epsilon, + modified_wd, + ) + else: + modified_wd_t = (torch.tensor([modified_wd], dtype=torch.float, requires_grad=False).to( + tensor_lists[1][0].dtype).to(tensor_lists[1][0].device, non_blocking=True)) + torch.ops.hpu.optimizer_adamw( + tensor_lists[0], + tensor_lists[1], + tensor_lists[2], + tensor_lists[3], + neg_step_t, + beta1, + beta2, + epsilon, + modified_wd_t, + modified_wd != 1.0, + ) + + htcore.step_closure._mark_step_if_lazy() + + +class FusedAdamBuilder(OpBuilder): BUILD_VAR = "DS_BUILD_FUSED_ADAM" NAME = "fused_adam" @@ -18,12 +96,10 @@ def absolute_name(self): return f'deepspeed.ops.adam.{self.NAME}_op' def sources(self): - return ['csrc/cpu/adam/fused_adam.cpp', 'csrc/adam/cpu_adam_impl.cpp'] - - def cxx_args(self): - args = super().cxx_args() - args += ['-DENABLE_BFLOAT16'] - return args + return [] def include_paths(self): - return ['csrc/includes'] + return [] + + def load(self, verbose=True): + return HPUFusedAdam diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt index eb6bfc811e85..c0fc5dba9d33 100644 --- a/requirements/requirements-dev.txt +++ b/requirements/requirements-dev.txt @@ -1,5 +1,6 @@ accelerate clang-format==16.0.2 +comet_ml>=3.41.0 deepspeed-kernels ; sys_platform == 'linux' docutils<0.18 future diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 80c9f9b3287a..05f88337f3a9 100755 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,10 +1,10 @@ hjson ninja numpy +nvidia-ml-py packaging>=20.0 psutil py-cpuinfo pydantic -pynvml torch tqdm diff --git a/tests/unit/elasticity/test_elastic.py b/tests/unit/elasticity/test_elastic.py index 63633a51914b..92e1520b2c7c 100644 --- a/tests/unit/elasticity/test_elastic.py +++ b/tests/unit/elasticity/test_elastic.py @@ -150,6 +150,7 @@ def test_proper_mbsz(ds_config): class TestNonElasticBatchParams(DistributedTest): world_size = 2 + @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME], reason="lamb is not compatible") def test(self): config_dict = { "train_batch_size": 2, @@ -182,6 +183,7 @@ def test(self): class TestNonElasticBatchParamsWithOverride(DistributedTest): world_size = 2 + @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME], reason="lamb is not compatible") def test(self): if not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME]: pytest.skip("This op had not been implemented on this system.", allow_module_level=True) @@ -215,6 +217,7 @@ def test(self): class TestElasticConfigChanged(DistributedTest): world_size = 2 + @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME], reason="lamb is not compatible") def test(self): config_dict = { "train_batch_size": 2, diff --git a/tests/unit/moe/test_moe.py b/tests/unit/moe/test_moe.py index d39f9fe3d651..fdff9430a4e6 100644 --- a/tests/unit/moe/test_moe.py +++ b/tests/unit/moe/test_moe.py @@ -177,7 +177,7 @@ class TestTopk(DistributedTest): world_size = 2 def test(self): - device = get_accelerator().current_device() + device = get_accelerator().current_device_name() if dist.get_rank() == 0: logits = torch.rand(2, 2, device=device) elif dist.get_rank() == 1: diff --git a/tests/unit/monitor/test_monitor.py b/tests/unit/monitor/test_monitor.py index 3e04bebfb6c1..d4b3cf43921d 100644 --- a/tests/unit/monitor/test_monitor.py +++ b/tests/unit/monitor/test_monitor.py @@ -7,10 +7,14 @@ from deepspeed.monitor.wandb import WandbMonitor from deepspeed.monitor.csv_monitor import csvMonitor from deepspeed.monitor.config import DeepSpeedMonitorConfig +from deepspeed.monitor.comet import CometMonitor from unit.common import DistributedTest +from unittest.mock import Mock, patch from deepspeed.runtime.config import DeepSpeedConfig +import deepspeed.comm as dist + class TestTensorBoard(DistributedTest): world_size = 2 @@ -97,3 +101,66 @@ def test_empty_csv_monitor(self): assert csv_monitor.enabled == defaults.enabled assert csv_monitor.output_path == defaults.output_path assert csv_monitor.job_name == defaults.job_name + + +class TestCometMonitor(DistributedTest): + world_size = 2 + + def test_comet_monitor(self): + import comet_ml + mock_experiment = Mock() + mock_start = Mock(return_value=mock_experiment) + + config_dict = { + "train_batch_size": 2, + "comet": { + "enabled": True, + "samples_log_interval": 42, + "workspace": "some-workspace", + "project": "some-project", + "api_key": "some-api-key", + "experiment_name": "some-experiment-name", + "experiment_key": "some-experiment-key", + "mode": "get_or_create", + "online": True + } + } + + ds_config = DeepSpeedConfig(config_dict) + + with patch.object(comet_ml, "start", mock_start): + comet_monitor = CometMonitor(ds_config.monitor_config.comet) + + assert comet_monitor.enabled is True + assert comet_monitor.samples_log_interval == 42 + + # experiment should be initialized via comet_ml.start only if rank == 0 + if dist.get_rank() == 0: + mock_start.assert_called_once_with( + api_key="some-api-key", + project="some-project", + workspace="some-workspace", + experiment_key="some-experiment-key", + mode="get_or_create", + online=True, + ) + + mock_experiment.set_name.assert_called_once_with("some-experiment-name") + assert comet_monitor.experiment is mock_experiment + else: + mock_start.assert_not_called() + + def test_empty_comet(self): + import comet_ml + mock_start = Mock() + + config_dict = {"train_batch_size": 2, "comet": {}} + ds_config = DeepSpeedConfig(config_dict) + + with patch.object(comet_ml, "start", mock_start): + comet_monitor = CometMonitor(ds_config.monitor_config.comet) + + defaults = DeepSpeedMonitorConfig().comet + assert comet_monitor.enabled == defaults.enabled + assert comet_monitor.samples_log_interval == defaults.samples_log_interval + mock_start.assert_not_called() diff --git a/tests/unit/ops/adam/test_cpu_adam.py b/tests/unit/ops/adam/test_cpu_adam.py index 9a6ff6689446..785cf786acc3 100644 --- a/tests/unit/ops/adam/test_cpu_adam.py +++ b/tests/unit/ops/adam/test_cpu_adam.py @@ -11,7 +11,7 @@ import deepspeed from deepspeed.accelerator import get_accelerator from deepspeed.ops.adam import FusedAdam -from deepspeed.ops.op_builder import CPUAdamBuilder +from deepspeed.ops.op_builder import CPUAdamBuilder, FusedAdamBuilder from unit.common import DistributedTest if not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: @@ -62,6 +62,8 @@ class TestCPUAdam(DistributedTest): set_dist_env = False @pytest.mark.skipif(not get_accelerator().is_available(), reason="only supported in CUDA environments.") + @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedAdamBuilder.NAME], + reason="FusedAdam is not compatible") def test_fused_adam_equal(self, dtype, model_size): if ("amd" in pytest.cpu_vendor) and (dtype == torch.half): pytest.skip("cpu-adam with half precision not supported on AMD CPUs") diff --git a/tests/unit/ops/adam/test_hybrid_adam.py b/tests/unit/ops/adam/test_hybrid_adam.py index c7ef4890b322..9003e02588c1 100644 --- a/tests/unit/ops/adam/test_hybrid_adam.py +++ b/tests/unit/ops/adam/test_hybrid_adam.py @@ -12,7 +12,7 @@ import deepspeed from deepspeed.accelerator import get_accelerator from deepspeed.ops.adam import FusedAdam, DeepSpeedCPUAdam -from deepspeed.ops.op_builder import CPUAdamBuilder +from deepspeed.ops.op_builder import CPUAdamBuilder, FusedAdamBuilder from unit.common import DistributedTest if not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: @@ -43,6 +43,8 @@ class TestHybridAdam(DistributedTest): set_dist_env = False @pytest.mark.skipif(not get_accelerator().is_available(), reason="only supported in CUDA environments.") + @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedAdamBuilder.NAME], + reason="FusedAdam is not compatible") def test_hybrid_adam_equal(self, dtype, model_size): if ("amd" in pytest.cpu_vendor) and (dtype == torch.half): pytest.skip("cpu-adam with half precision not supported on AMD CPUs") diff --git a/tests/unit/runtime/half_precision/test_dynamic_loss_scale.py b/tests/unit/runtime/half_precision/test_dynamic_loss_scale.py index f350e08e68a7..38c539c1cc6c 100644 --- a/tests/unit/runtime/half_precision/test_dynamic_loss_scale.py +++ b/tests/unit/runtime/half_precision/test_dynamic_loss_scale.py @@ -10,6 +10,7 @@ import numpy as np from unit.common import DistributedTest from unit.simple_model import SimpleModel +from deepspeed.ops.op_builder import FusedLambBuilder def run_model_step(model, gradient_list): @@ -152,6 +153,7 @@ def test_some_overflow(self): assert optim.cur_iter == expected_iteration +@pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME], reason="lamb is not compatible") class TestUnfused(DistributedTest): world_size = 1 diff --git a/tests/unit/runtime/half_precision/test_fp16.py b/tests/unit/runtime/half_precision/test_fp16.py index 5b300053d2a8..cf7a1d8a8183 100644 --- a/tests/unit/runtime/half_precision/test_fp16.py +++ b/tests/unit/runtime/half_precision/test_fp16.py @@ -12,7 +12,7 @@ from unit.simple_model import SimpleModel, SimpleOptimizer, random_dataloader, SimpleMoEModel, sequence_dataloader from deepspeed.utils.torch import required_torch_version from deepspeed.accelerator import get_accelerator -from deepspeed.ops.op_builder import CPUAdamBuilder +from deepspeed.ops.op_builder import CPUAdamBuilder, FusedLambBuilder from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer try: @@ -22,7 +22,11 @@ _amp_available = False amp_available = pytest.mark.skipif(not _amp_available, reason="apex/amp is not installed") +if torch.half not in get_accelerator().supported_dtypes(): + pytest.skip(f"fp16 not supported, valid dtype: {get_accelerator().supported_dtypes()}", allow_module_level=True) + +@pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME], reason="lamb is not compatible") class TestLambFP32GradClip(DistributedTest): world_size = 2 @@ -55,6 +59,7 @@ def test(self): model.step() +@pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME], reason="lamb is not compatible") class TestLambFP16(DistributedTest): world_size = 2 @@ -231,6 +236,7 @@ def mock_unscale_and_clip_grads(grads_groups_flat, total_norm, apply_scale=True) engine.backward(loss) engine.step() + @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME], reason="lamb is not compatible") @pytest.mark.parametrize("fused_lamb_legacy", [(False), (True)]) def test_lamb_gradnorm(self, monkeypatch, fused_lamb_legacy: bool): if not get_accelerator().is_fp16_supported(): @@ -495,6 +501,7 @@ def test_adam_basic(self): model.backward(loss) model.step() + @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME], reason="lamb is not compatible") def test_lamb_basic(self): if not get_accelerator().is_fp16_supported(): pytest.skip("fp16 is not supported") diff --git a/tests/unit/runtime/test_ds_initialize.py b/tests/unit/runtime/test_ds_initialize.py index 169096a6d4e5..9ff99f169f7a 100644 --- a/tests/unit/runtime/test_ds_initialize.py +++ b/tests/unit/runtime/test_ds_initialize.py @@ -20,6 +20,7 @@ from deepspeed.runtime.utils import see_memory_usage from deepspeed.utils.torch import required_torch_version from deepspeed.accelerator import get_accelerator +from deepspeed.ops.op_builder import FusedAdamBuilder @pytest.mark.parametrize('zero_stage', [0, 3]) @@ -67,6 +68,9 @@ def test(self, optimizer_type): def _optimizer_callable(params) -> Optimizer: return AdamW(params=params) + if (optimizer_type is None) and (not deepspeed.ops.__compatible_ops__[FusedAdamBuilder.NAME]): + pytest.skip("FusedAdam is not compatible") + hidden_dim = 10 model = SimpleModel(hidden_dim) @@ -95,6 +99,8 @@ def _optimizer_callable(params) -> Optimizer: class TestConfigOptimizer(DistributedTest): world_size = 1 + @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedAdamBuilder.NAME], + reason="FusedAdam is not compatible") def test(self, client_parameters): ds_config = {"train_batch_size": 1, "optimizer": {"type": "Adam", "params": {"lr": 0.001}}}