diff --git a/src/sparseml/pytorch/utils/helpers.py b/src/sparseml/pytorch/utils/helpers.py
index 4b495afe497..e9c603355de 100644
--- a/src/sparseml/pytorch/utils/helpers.py
+++ b/src/sparseml/pytorch/utils/helpers.py
@@ -20,7 +20,6 @@
 import os
 import random
 import re
-import warnings
 from collections import OrderedDict, namedtuple
 from contextlib import contextmanager
 from copy import deepcopy
@@ -30,7 +29,7 @@
 import torch
 from packaging import version
 from torch import Tensor
-from torch.nn import Linear, Module, Parameter
+from torch.nn import Embedding, Linear, Module, Parameter
 from torch.nn.modules.conv import Conv2d, Conv3d, _ConvNd
 from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader
@@ -780,6 +779,7 @@ def get_prunable_layers(module: Module) -> List[Tuple[str, Module]]:
         for (name, mod) in module.named_modules()
         if (
             isinstance(mod, Linear)
+            or isinstance(mod, Embedding)
             or isinstance(mod, _ConvNd)
             or (QATLinear and isinstance(mod, QATLinear))
             or (QATConv2d and isinstance(mod, QATConv2d))
@@ -793,7 +793,7 @@ def get_quantizable_layers(module: Module) -> List[Tuple[str, Module]]:
     """
     :param module: the module to get the quantizable layers from
     :return: a list containing the names and modules of the quantizable layers
-        (Linear, Conv2d, Conv3d)
+        (Embedding, Linear, Conv2d, Conv3d)
     """
     if QATLinear is None:
         raise ImportError(
@@ -806,6 +806,7 @@ def get_quantizable_layers(module: Module) -> List[Tuple[str, Module]]:
         for (name, mod) in module.named_modules()
         if (
             isinstance(mod, Linear)
+            or isinstance(mod, Embedding)
             or isinstance(mod, Conv2d)
             or (QATConv3d and isinstance(mod, Conv3d))
         )
@@ -816,29 +817,15 @@ def get_quantized_layers(module: Module) -> List[Tuple[str, Module]]:
     """
     :param module: the module to get the quantized layers from
     :return: a list containing the names and modules of the quantized layers
-        (Linear, Conv2d, Conv3d)
+        (Embedding, Linear, Conv2d, Conv3d)
     """
-    if QATLinear is None:
-        raise ImportError(
-            "PyTorch version is not setup for Quantization. "
-            "Please install a QAT compatible version of PyTorch"
-        )
 
     quantized_layers = []
     for (name, mod) in module.named_modules():
-        if (
-            (QATLinear and isinstance(mod, QATLinear))
-            or (QATConv2d and isinstance(mod, QATConv2d))
-            or (QATConv3d and isinstance(mod, QATConv3d))
-        ):
-            quantized_layers.append((name, mod))
-
-        elif isinstance(mod, Conv3d) and not QATConv3d:
-            warnings.warn(
-                "Pytorch version is not setup for Conv3D Quantization. "
-                "Quantization of Conv3D layers will be skipped",
-                UserWarning,
-            )
+        if hasattr(mod, "quantization_scheme"):
+            weight_scheme = getattr(mod.quantization_scheme, "weights", None)
+            if weight_scheme is not None and hasattr(mod, "weight"):
+                quantized_layers.append((name, mod))
 
     return quantized_layers
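
Note: the rewritten get_quantized_layers above no longer keys off QAT wrapper classes; any module carrying a quantization_scheme attribute whose weights field is non-None, and which has a weight tensor, is counted as quantized. A minimal sketch of that detection behavior follows; the SimpleNamespace stand-in for the scheme object is an assumption for illustration only (in practice the attribute is attached during quantization), not the real scheme class:

    from types import SimpleNamespace

    from torch.nn import Embedding, Linear, Module

    from sparseml.pytorch.utils.helpers import get_quantized_layers


    class TinyModel(Module):
        def __init__(self):
            super().__init__()
            self.embed = Embedding(10, 4)
            self.fc = Linear(4, 2)


    model = TinyModel()

    # Stand-in for the scheme attached during quantization; only the
    # presence of a non-None `weights` field matters to the detection.
    model.embed.quantization_scheme = SimpleNamespace(weights={"num_bits": 8})

    # `embed` is reported (scheme with weights plus a `weight` tensor);
    # `fc` is skipped because no scheme was attached to it.
    print(get_quantized_layers(model))  # [('embed', Embedding(10, 4))]
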
" ) + sparsity_percent_formatted = "{:.2f}".format( + sparsification_info.params_prunable_sparse_percent + ) _LOGGER.info( f"There are {sparsification_info.params_prunable_total} prunable " - f"params which have {sparsification_info.params_prunable_sparse_percent} " + f"params which have {sparsity_percent_formatted}% " "avg sparsity." ) + + quant_percent_formatted = "{:.2f}".format( + sparsification_info.params_quantized_percent + ) _LOGGER.info( f"There are {sparsification_info.params_quantizable} quantizable " f"params, with a quantization percentage of " - f"{sparsification_info.params_quantized_percent}." + f"{quant_percent_formatted}%." ) def _prepare_model_for_fsdp(self):