diff --git a/src/compressed_tensors/quantization/observers/helpers.py b/src/compressed_tensors/quantization/observers/helpers.py index 875a05b3..ec474303 100644 --- a/src/compressed_tensors/quantization/observers/helpers.py +++ b/src/compressed_tensors/quantization/observers/helpers.py @@ -13,7 +13,7 @@ # limitations under the License. from collections import Counter -from typing import Optional, Tuple +from typing import Tuple import torch from compressed_tensors.quantization.quant_args import ( diff --git a/src/compressed_tensors/utils/helpers.py b/src/compressed_tensors/utils/helpers.py index e1587ada..82e11ccf 100644 --- a/src/compressed_tensors/utils/helpers.py +++ b/src/compressed_tensors/utils/helpers.py @@ -24,6 +24,7 @@ "tensor_follows_mask_structure", "replace_module", "is_compressed_tensors_config", + "getattr_chain", ] FSDP_WRAPPER_NAME = "_fsdp_wrapped_module" @@ -119,3 +120,35 @@ def is_compressed_tensors_config(compression_config: Any) -> bool: return isinstance(compression_config, CompressedTensorsConfig) except ImportError: return False + + +def getattr_chain(obj: Any, chain_str: str, *args, **kwargs) -> Any: + """ + Chain multiple getattr calls, separated by `.` + + :param obj: base object whose attributes are being retrieved + :param chain_str: attribute names separated by `.` + :param default: default value, throw error otherwise + + """ + if len(args) >= 1: + has_default = True + default = args[0] + elif "default" in kwargs: + has_default = True + default = kwargs["default"] + else: + has_default = False + + attr_names = chain_str.split(".") + + res = obj + for attr_name in attr_names: + if not hasattr(res, attr_name): + if has_default: + return default + else: + raise AttributeError(f"{res} object has no attribute {attr_name}") + res = getattr(res, attr_name) + + return res diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index 9dd7b22d..b5a5fe0c 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -12,8 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib +from functools import wraps +from typing import Optional + import torch -from torch.nn import Module +import warnings +from compressed_tensors.utils.helpers import getattr_chain + + +try: + from accelerate.hooks import AlignDevicesHook + from accelerate.utils import OffloadedWeightsLoader, PrefixedDataset, set_module_tensor_to_device + _has_accelerate = True +except ImportError: + _has_accelerate = False + AlignDevicesHook = None + OffloadedWeightsLoader = None + PrefixedDataset = None __all__ = [ @@ -25,18 +41,40 @@ ] -def is_module_offloaded(module: Module) -> bool: +# upstream candidate +def has_offloaded_params(module: torch.nn.Module) -> bool: """ - :param module: layer to check - :return: True if layer is offloaded from GPU, False otherwise + Checks if a module has offloaded parameters by checking if the given module + has a AlignDevicesHook attached with offloading enabled + + Args: + module (`torch.nn.Module`): The module to check for an offload hook. + + Returns: + bool: `True` if the module has an offload hook and offloading is enabled, + `False` otherwise. """ - return hasattr(module, "_hf_hook") and module._hf_hook.offload + return ( + hasattr(module, "_hf_hook") and + isinstance(module._hf_hook, AlignDevicesHook) and + module._hf_hook.offload + ) + +# depreciation candidate +@wraps(has_offloaded_params) +def is_module_offloaded(module: torch.nn.Module) -> bool: + if not _has_accelerate: + return False -def get_execution_device(module: Module) -> torch.device: + return has_offloaded_params(module) + + +# depreciation candidate +def get_execution_device(module: torch.nn.Module) -> torch.device: """ - :param module: layer to check - :return: device layer is loaded onto during forward pass + :param module: module to check + :return: device module is loaded onto during forward pass """ if is_module_offloaded(module): return module._hf_hook.execution_device @@ -49,26 +87,36 @@ def get_execution_device(module: Module) -> torch.device: return device -def get_offloaded_device(module: Module) -> torch.device: +# upstream candidate +def _infer_offload_device(module: torch.nn.Module) -> torch.device: + if not has_offloaded_params(module): + raise ValueError("Cannot infer offload device from non-offloaded module") + + first_key = next(module._hf_hook.weights_map.keys(), None) + if first_key is None: + raise ValueError("Cannot infer offload device from empty weights map") + + prefix_dataset = module._hf_hook.weights_map.dataset + return prefix_dataset[first_key].device + +# depreciation candidate +def get_offloaded_device(module: torch.nn.Module) -> torch.device: """ - :param module: layer to check - :return: device layer is offloaded to onto after forward pass + :param module: module to check + :return: device module is offloaded to onto after forward pass """ - if is_module_offloaded(module): - first_key = list(module._hf_hook.weights_map.keys())[0] - prefix_dataset = module._hf_hook.weights_map.dataset - return prefix_dataset[first_key].device - return next(module.parameters()).device + return _infer_offload_device(module) -def update_prefix_dict(module: Module, key: str, data: torch.Tensor): +# depreciation candidate +def update_prefix_dict(module: torch.nn.Module, key: str, data: torch.Tensor): """ Updates the offloaded state dict for a given module. Parameter named key is replaced by data. This is neccesary because parameter updates for offloaded modules do not persist automatically between loads. This function only affects the offloaded state dict and not the current state of the loaded module. - :param module: layer containing the parameter to update + :param module: module containing the parameter to update :param key: name of parameter to update :param data: tensor to update parameter with in the offloaded state dict """ @@ -78,39 +126,153 @@ def update_prefix_dict(module: Module, key: str, data: torch.Tensor): prefix_dict.dataset[f"{prefix_dict.prefix}{key}"] = data +# upstream candidate? +def update_offload_parameter( + module: torch.nn.Module, + name: str, + data: Optional[torch.Tensor] = None, + offload_device: Optional[torch.device] = None, +): + """ + :param module: module containing the parameter to update + :param name: name of module parameter to update + :param data: tensor to update parameter with + :param offload_device: offload device for newly registered parameters + """ + param = getattr(module, name) + if param.device == "meta" or data is not None and data.device == "meta": + raise ValueError("Cannot copy data to/from meta device. Consider calling with align_module(module)") + + if data is not None: + if param.data.dtype != data.dtype: + warnings.warn("TODO") + + param.data.copy_(data) + + if has_offloaded_params(module): + weights_map = module._hf_hook.weights_map + + # for upstreaming, probably better to modify the weight map types so that they can be written to? + if isinstance(weights_map, PrefixedDataset): + prefix_dict = getattr_chain(module, "module._hf_hook.weights_map.dataset", None) + if prefix_dict is not None: + prefix = module._hf_hook.weights_map.prefix + key = f"{prefix}{name}" + + offload_device = ( + prefix_dict[key].device if key in prefix_dict + else offload_device if offload_device is not None + else _infer_offload_device(module) + ) + prefix_dict[key] = param.data.to(device=offload_device) + + if isinstance(weights_map, OffloadedWeightsLoader): + raise NotImplementedError() + + else: + raise NotImplementedError() + +# depreciation candidate def update_parameter_data( - module: Module, new_param_data: torch.Tensor, param_name: str + module: torch.nn.Module, new_param_data: torch.Tensor, param_name: str ): + param = getattr(module, param_name) + new_param_data = new_param_data.to(device=param.device, dtype=param.dtype) + update_offload_parameter(module, param_name, new_param_data) + + +@contextlib.contextmanager +def align_module(module: torch.nn.Module, execution_device: Optional[torch.device] = None): """ - Updates the paramter value named param_name for a given module. This function - updates both the current loaded module state and the offloaded state dict if - the module is offloaded. This is neccesary because parameter updates for offloaded - modules do not persist automatically between loads. - - :param module: layer containing the parameter to update - :param new_param_data: tensor to update parameter with - :param param_name: name of layer parameter to update + Moves a module's parameters to the specified execution device. + + Args: + module (torch.nn.Module): Module with parameters to align. + execution_device (Optional[torch.device]): If provided, overrides the + module's execution device within the context. + + Yields: + None: Yields control while the module's parameters are aligned to the execution device. """ - if not hasattr(module, param_name): - return + if has_offloaded_params(module): + if execution_device is not None: + original_device = module._hf_hook.execution_device + module._hf_hook.execution_device = execution_device - device = next(module.parameters()).device + module._hf_hook.pre_forward(module) + yield + module._hf_hook.post_forward(module, None) - offloaded = False - if is_module_offloaded(module): - offload_device = get_offloaded_device(module) - offloaded = True - - parameter = getattr(module, param_name, None) - if parameter is None: - raise ValueError("Attempted to update uninitialized parameter") - - dtype = parameter.dtype - parameter.data = new_param_data.to(device).to(dtype) - - if offloaded: - prefix_dict = module._hf_hook.weights_map.dataset - prefix = module._hf_hook.weights_map.prefix - prefix_dict[f"{prefix}{param_name}"] = new_param_data.to(offload_device).to( - dtype - ) + if execution_device is not None: + module._hf_hook.execution_device = original_device + + elif execution_device is not None: + devices = {} + for name, param in module.named_parameters(): + devices[name] = param.device + set_module_tensor_to_device( + module, + name, + execution_device, + ) + + yield + + for name, param in module.named_parameters(): + set_module_tensor_to_device( + module, + name, + devices[name], + ) + + else: + yield + + + +@contextlib.contextmanager +def modify_offload_module( + module: torch.nn.Module, + execution_device: Optional[torch.device] = None, + offload_device: Optional[torch.device] = None, +): + with align_module(module, execution_device): + yield + + # there is little performance gain from checking if a parameter's data + # has been modified before copying since the new data must be copied + # to the offload device anyways; just update all module parameters + for name, param in module.named_parameters(): + update_offload_parameter(module, name, param.data, offload_device) + + +# upstream candidate? +def register_offload_parameter( + module: torch.nn.Module, + name: str, + parameter: torch.nn.Parameter, + offload_device: Optional[torch.device] = None, +): + module.register_parameter(name, parameter) + update_offload_parameter(module, name, parameter.data, offload_device) + + +# upstream candidate? +def delete_offload_parameter(module: torch.nn.Module, name: str): + delattr(module, name) + + if has_offloaded_params(module): + weights_map = module._hf_hook.weights_map + + # for upstreaming, probably better to modify the weight map types so that they can be written to? + if isinstance(weights_map, PrefixedDataset): + dataset = weights_map.dataset + prefix = weights_map.prefix + if dataset is not None: + del dataset[f"{prefix}{name}"] + + elif isinstance(weights_map, OffloadedWeightsLoader): + raise NotImplementedError() + + elif weights_map is not None: + raise NotImplementedError(f"Cannot delete parameter from weights_map of type {type(weights_map)}") \ No newline at end of file