From 5df0d8aae3a4bf83378953d6e1ee007fff61310a Mon Sep 17 00:00:00 2001
From: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>
Date: Thu, 30 May 2024 17:06:55 +0200
Subject: [PATCH 1/7] Warn when the `model.dtype` is different from its `config.dtype` (#2312)

---
 src/sparseml/transformers/sparsification/sparse_model.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/src/sparseml/transformers/sparsification/sparse_model.py b/src/sparseml/transformers/sparsification/sparse_model.py
index 3132411d332..c5e17764874 100644
--- a/src/sparseml/transformers/sparsification/sparse_model.py
+++ b/src/sparseml/transformers/sparsification/sparse_model.py
@@ -111,6 +111,14 @@ def skip(*args, **kwargs):
         model = super(AutoModelForCausalLM, cls).from_pretrained(
             pretrained_model_name_or_path, *model_args, **kwargs
         )
+        if model.dtype != model.config.torch_dtype:
+            _LOGGER.warning(
+                f"The dtype of the loaded model: {model.dtype} is different "
+                "from the dtype specified in the model config: "
+                f"{model.config.torch_dtype}. "
+                "To load the model in the format that it was previously saved in, "
+                "set torch_dtype=`auto` in the SparseAutoModel creation call."
+            )
         logger.setLevel(level=restore_log_level)
         # override the PreTrainedModel instance with compression save function
         modify_save_pretrained(model)

From 5caa5574cdb03e7d09e0f842b2d350bc102e28e5 Mon Sep 17 00:00:00 2001
From: Eldar Kurtic
Date: Mon, 3 Jun 2024 16:40:08 +0200
Subject: [PATCH 2/7] Fix import typo (#2311)

Co-authored-by: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>
---
 examples/llama7b_sparse_quantized/README.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/llama7b_sparse_quantized/README.md b/examples/llama7b_sparse_quantized/README.md
index 59a8b98bca6..779696ba599 100644
--- a/examples/llama7b_sparse_quantized/README.md
+++ b/examples/llama7b_sparse_quantized/README.md
@@ -40,7 +40,7 @@ run the following:
 
 ```
 import torch
-from sparseml import SparseAutoModelForCausalLM
+from sparseml.transformers import SparseAutoModelForCausalLM
 
 model = SparseAutoModelForCausalLM.from_pretrained(output_dir, torch_dtype=torch.bfloat16)
 model.save_pretrained(compressed_output_dir, save_compressed=True)
@@ -49,4 +49,4 @@
 
 ### Custom Quantization
 The current repo supports multiple quantization techniques configured using a recipe. Supported strategies are `tensor`, `group` and `channel`. The above recipe (`2:4_w4a16_recipe.yaml`) uses channel-wise quantization specified by `strategy: "channel"` in its config group.
-To use quantize per tensor, change strategy from `channel` to `tensor`. To use group size quantization, change from `channel` to `group` and specify its value, say 128, by including `group_size: 128`. Group size quantization example is shown in `2:4_w4a16_group-128_recipe.yaml`
\ No newline at end of file
+To quantize per tensor, change strategy from `channel` to `tensor`. To use group size quantization, change from `channel` to `group` and specify its value, say 128, by including `group_size: 128`. A group size quantization example is shown in `2:4_w4a16_group-128_recipe.yaml`
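The warning added in PATCH 1/7 and the README fixed in PATCH 2/7 both revolve around the `torch_dtype` argument of `SparseAutoModelForCausalLM.from_pretrained`. A minimal sketch of the two call styles, assuming `output_dir` is a placeholder for a checkpoint directory written by an earlier `save_pretrained` call:

```python
import torch

from sparseml.transformers import SparseAutoModelForCausalLM

output_dir = "./llama7b_sparse_output"  # hypothetical checkpoint location

# Explicit dtype: if this disagrees with config.torch_dtype, the new warning fires
model = SparseAutoModelForCausalLM.from_pretrained(
    output_dir, torch_dtype=torch.bfloat16
)

# torch_dtype="auto": load in the dtype the checkpoint was saved in,
# which is what the warning message recommends
model = SparseAutoModelForCausalLM.from_pretrained(output_dir, torch_dtype="auto")
```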
From ef0232e43a1dce62735397a4045b9faf179ac288 Mon Sep 17 00:00:00 2001
From: Sara Adkins
Date: Tue, 4 Jun 2024 09:55:27 -0400
Subject: [PATCH 3/7] update tests (#2314)

---
 .../sparsification/test_compress_tensor_utils.py | 16 +++++++---------
 1 file changed, 7 insertions(+), 9 deletions(-)

diff --git a/tests/sparseml/transformers/sparsification/test_compress_tensor_utils.py b/tests/sparseml/transformers/sparsification/test_compress_tensor_utils.py
index 6c76e4a9360..a42c87d43e6 100644
--- a/tests/sparseml/transformers/sparsification/test_compress_tensor_utils.py
+++ b/tests/sparseml/transformers/sparsification/test_compress_tensor_utils.py
@@ -20,11 +20,8 @@
 from transformers import AutoConfig
 
 import sparseml
-from compressed_tensors import (
-    COMPRESSION_CONFIG_NAME,
-    QUANTIZATION_CONFIG_NAME,
-    SPARSITY_CONFIG_NAME,
-)
+from compressed_tensors import COMPRESSION_CONFIG_NAME
+from compressed_tensors.compressors import ModelCompressor
 from compressed_tensors.config import BitmaskConfig, DenseSparsityConfig
 from compressed_tensors.quantization import (
     QuantizationStatus,
@@ -96,7 +93,7 @@ def test_sparse_model_reload(compressed, config, dtype, tmp_path):
 
     config = AutoConfig.from_pretrained(tmp_path / "compress_out")
     compression_config = getattr(config, COMPRESSION_CONFIG_NAME, None)
-    sparsity_config = compression_config.get(SPARSITY_CONFIG_NAME, None)
+    sparsity_config = ModelCompressor.parse_sparsity_config(compression_config)
     assert (
         sparsity_config["format"] == "dense"
         if (not compressed and config is None)
@@ -146,7 +143,8 @@ def test_dense_model_save(tmp_path, skip_compression_stats, save_compressed):
 
     # for models with 0% sparsity no sparsity config is saved regardless
     config = AutoConfig.from_pretrained(tmp_path / "dense_out")
-    sparsity_config = getattr(config, SPARSITY_CONFIG_NAME, None)
+    compression_config = getattr(config, COMPRESSION_CONFIG_NAME, None)
+    sparsity_config = ModelCompressor.parse_sparsity_config(compression_config)
     assert sparsity_config is None
 
     shutil.rmtree(tmp_path)
@@ -203,7 +201,7 @@ def test_quant_model_reload(format, dtype, tmp_path):
 
     config = AutoConfig.from_pretrained(tmp_path / "compress_out")
     compression_config = getattr(config, COMPRESSION_CONFIG_NAME, None)
-    quant_config = compression_config.get(QUANTIZATION_CONFIG_NAME, None)
+    quant_config = ModelCompressor.parse_quantization_config(compression_config)
    assert quant_config["format"] == format
 
     dense_model = SparseAutoModelForCausalLM.from_pretrained(
@@ -273,7 +271,7 @@ def test_quant_infer_format(status, expected_format, expected_dtype, tmp_path):
 
     config = AutoConfig.from_pretrained(tmp_path / "compress_out")
     compression_config = getattr(config, COMPRESSION_CONFIG_NAME, None)
-    quant_config = compression_config.get(QUANTIZATION_CONFIG_NAME, None)
+    quant_config = ModelCompressor.parse_quantization_config(compression_config)
     assert quant_config["quantization_status"] == status.value
     assert quant_config["format"] == expected_format
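The updated tests above share one pattern: the saved config's compression section is read back and parsed through `ModelCompressor` rather than indexed as a raw dict. A sketch of that flow outside of pytest, where `save_dir` is a placeholder for a directory produced by `model.save_pretrained(...)`:

```python
from transformers import AutoConfig

from compressed_tensors import COMPRESSION_CONFIG_NAME
from compressed_tensors.compressors import ModelCompressor

save_dir = "compress_out"  # hypothetical save_pretrained output directory

config = AutoConfig.from_pretrained(save_dir)
compression_config = getattr(config, COMPRESSION_CONFIG_NAME, None)

# parse_* helpers replace the old dict lookups keyed on the sparsity/quantization names
sparsity_config = ModelCompressor.parse_sparsity_config(compression_config)
quant_config = ModelCompressor.parse_quantization_config(compression_config)

if sparsity_config is not None:
    print("sparsity format:", sparsity_config["format"])
if quant_config is not None:
    print("quantization format:", quant_config["format"])
```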
From 38a1214e484e38a4c17b74078247d571332be9c9 Mon Sep 17 00:00:00 2001
From: dbogunowicz <97082108+dbogunowicz@users.noreply.github.com>
Date: Wed, 5 Jun 2024 16:14:45 +0200
Subject: [PATCH 4/7] Check for 2:4 structure when saving `SparseAutoModel` (#2317)

* Update sparsity_config.py

* Create helpers.py

* cleanup

---------

Co-authored-by: bogunowicz@arrival.com
---
 .../transformers/compression/helpers.py      | 101 ++++++++++++++++++
 .../compression/sparsity_config.py           |  40 ++++---
 2 files changed, 128 insertions(+), 13 deletions(-)
 create mode 100644 src/sparseml/transformers/compression/helpers.py

diff --git a/src/sparseml/transformers/compression/helpers.py b/src/sparseml/transformers/compression/helpers.py
new file mode 100644
index 00000000000..efd883960fb
--- /dev/null
+++ b/src/sparseml/transformers/compression/helpers.py
@@ -0,0 +1,101 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional
+
+import torch
+from tqdm import tqdm
+
+from sparseml.pytorch.utils import get_linear_layers
+
+
+__all__ = [
+    "tensor_follows_mask_structure",
+    "infer_sparsity_structure_from_stage_modifiers",
+    "infer_sparsity_structure_from_model",
+]
+
+
+def tensor_follows_mask_structure(tensor, mask: str = "2:4") -> bool:
+    """
+    :param tensor: tensor to check
+    :param mask: mask structure to check for, in the format "n:m"
+    :return: True if the tensor follows the mask structure, False otherwise.
+        Note, some weights can incidentally be zero, so we check for
+        at least n zeros in each chunk of size m
+    """
+
+    n, m = tuple(map(int, mask.split(":")))
+    # Reshape the tensor into chunks of size m
+    tensor = tensor.view(-1, m)
+
+    # Count the number of zeros in each chunk
+    zero_counts = (tensor == 0).sum(dim=1)
+
+    # Check if the number of zeros in each chunk is at least n
+    # Greater than sign is needed as some weights can incidentally
+    # be zero
+    return torch.all(zero_counts >= n).item()
+
+
+def infer_sparsity_structure_from_stage_modifiers(
+    stage_modifiers: List["StageModifier"],  # noqa E501
+) -> Optional[str]:
+    """
+    Determines the sparsity structure, if any exists, given the
+    list of stage modifiers
+
+    :param stage_modifiers: non-empty list of stage modifiers
+    :return: sparsity structure as a string or None
+    """
+    for stage in stage_modifiers:
+        if stage.applied:
+            for modifier in stage.modifiers:
+                if hasattr(modifier, "mask_structure"):
+                    sparsity_structure = modifier.mask_structure
+                    return sparsity_structure
+    return None
+
+
+def infer_sparsity_structure_from_model(model: torch.nn.Module) -> Optional[str]:
+    """
+    Determines the sparsity structure, if any exists, given the model
+
+    :param model: model to check for sparsity structure
+    :return: sparsity structure as a string or None
+    """
+
+    # check for the common sparsity structures
+    structures = {"2:4"}
+    for sparsity_structure in structures:
+        linear_modules = get_linear_layers(model)
+        linear_modules_with_sparsity_structure = [
+            tensor_follows_mask_structure(layer.weight)
+            for layer in tqdm(
+                linear_modules.values(),
+                desc="Checking whether model follows "
+                f"{sparsity_structure} sparsity structure",
+            )
+        ]
+        # if the majority of the linear modules follow the sparsity structure
+        # we can assume that the model follows the sparsity structure
+        # (taking into consideration the fact that some Linear layers like the
+        # embedding layer might not be sparse)
+        if (
+            sum(linear_modules_with_sparsity_structure)
+            > len(linear_modules_with_sparsity_structure) * 0.8
+        ):
+            return sparsity_structure
+
+    return None
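A short usage sketch of the helpers introduced in the new file above (the import path mirrors the file's location; the toy tensor and model are illustrative only):

```python
import torch

from sparseml.transformers.compression.helpers import (
    infer_sparsity_structure_from_model,
    tensor_follows_mask_structure,
)

# a 1x8 weight with two zeros in every chunk of four satisfies a 2:4 mask
weight = torch.tensor([[0.0, 0.0, 0.3, -1.2, 0.7, 0.0, 0.0, 2.1]])
assert tensor_follows_mask_structure(weight, mask="2:4")

# on a full model the check runs over all Linear layers; "2:4" is returned only
# when more than 80% of them follow the pattern, otherwise None
toy_model = torch.nn.Sequential(torch.nn.Linear(8, 8))  # dense placeholder model
print(infer_sparsity_structure_from_model(toy_model))  # None for a dense model
```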
diff --git a/src/sparseml/transformers/compression/sparsity_config.py b/src/sparseml/transformers/compression/sparsity_config.py
index b5f69cb83e1..f8bc477366c 100644
--- a/src/sparseml/transformers/compression/sparsity_config.py
+++ b/src/sparseml/transformers/compression/sparsity_config.py
@@ -21,6 +21,10 @@
 from compressed_tensors import CompressionFormat, SparsityCompressionConfig
 from compressed_tensors.quantization.utils import is_model_quantized
 from sparseml.pytorch.utils import ModuleSparsificationInfo
+from sparseml.transformers.compression.helpers import (
+    infer_sparsity_structure_from_model,
+    infer_sparsity_structure_from_stage_modifiers,
+)
 
 
 class SparsityConfigMetadata:
@@ -47,26 +51,34 @@ def infer_global_sparsity(
         return global_sparsity
 
     @staticmethod
-    def infer_sparsity_structure() -> str:
+    def infer_sparsity_structure(model: Optional[Module] = None) -> str:
         """
-        Determines what sparsity structure, if any, was applied in the currently active
-        sparse session
+        Determines what sparsity structure, if any, was applied.
+
+        First, there is an attempt to deduce the sparsity structure
+        from the currently active sparse session.
+
+        If that fails, the sparsity structure is inferred from the
+        model (if provided).
+
+        Finally, if both fail, the sparsity structure is set to
+        "unstructured"
 
         :return: sparsity structure as a string
         """
+        sparsity_structure = None
+
         current_session = sparseml.active_session()
         stage_modifiers = current_session.lifecycle.modifiers
-        sparsity_structure = "unstructured"
+        if stage_modifiers:
+            sparsity_structure = infer_sparsity_structure_from_stage_modifiers(
+                stage_modifiers
+            )
 
-        # check for applied pruning modifiers
-        for stage in stage_modifiers:
-            if stage.applied:
-                for modifier in stage.modifiers:
-                    if hasattr(modifier, "mask_structure"):
-                        sparsity_structure = modifier.mask_structure
-                        break
+        if model and sparsity_structure is None:
+            sparsity_structure = infer_sparsity_structure_from_model(model)
 
-        return sparsity_structure
+        return sparsity_structure or "unstructured"
 
     @staticmethod
     def from_pretrained(
@@ -91,7 +103,9 @@
         if global_sparsity < 0.05:
             return None
 
-        sparsity_structure = SparsityConfigMetadata.infer_sparsity_structure()
+        sparsity_structure = SparsityConfigMetadata.infer_sparsity_structure(
+            model=model
+        )
         if is_model_quantized(model):
             # compressing a sparse quantized model is not supported yet
             format = CompressionFormat.dense.value

From 3cd9a8ce512da119a8cd47a50a75a4dee33d83d2 Mon Sep 17 00:00:00 2001
From: Benjamin Fineran
Date: Fri, 7 Jun 2024 12:45:31 -0400
Subject: [PATCH 5/7] fix save path in llama7b_w4a16_quantization.ipynb (#2321)

---
 examples/llama7b_w4a16_quantization.ipynb | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/llama7b_w4a16_quantization.ipynb b/examples/llama7b_w4a16_quantization.ipynb
index 194215891fa..4ee88ff0b05 100644
--- a/examples/llama7b_w4a16_quantization.ipynb
+++ b/examples/llama7b_w4a16_quantization.ipynb
@@ -153,7 +153,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "model.save_pretrained(\"/network/sadkins/llama1.1b_W4A16_channel_packed\", save_compressed=True)"
+    "model.save_pretrained(\"llama1.1b_W4A16_channel_packed\", save_compressed=True)"
    ]
   }
 ],
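Relating back to the PATCH 4/7 changes above: `SparsityConfigMetadata.infer_sparsity_structure` now falls back from the active session to the model weights before defaulting to `"unstructured"`. A rough sketch of the call under that assumption (toy model only; a real run would pass the loaded `SparseAutoModel`):

```python
import torch

from sparseml.transformers.compression.sparsity_config import SparsityConfigMetadata

toy_model = torch.nn.Sequential(torch.nn.Linear(64, 64))  # placeholder for a real LLM

# Order of precedence implemented above:
#   1. mask_structure from applied stage modifiers in the active sparse session
#   2. otherwise, a scan of the model's Linear weights (e.g. the 2:4 check)
#   3. otherwise, "unstructured"
print(SparsityConfigMetadata.infer_sparsity_structure(model=toy_model))
```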
From 934f0d8b9b12845fa9b82fed87d4b54cdfec7a3d Mon Sep 17 00:00:00 2001
From: Sara Adkins
Date: Mon, 10 Jun 2024 11:39:47 -0400
Subject: [PATCH 6/7] Update Quantization Logging to New Framework (#2313)

* use new quant framework for logging

* fix legacy compatibility

* fix
---
 src/sparseml/pytorch/utils/helpers.py         | 31 ++++++-------------
 .../transformers/finetune/session_mixin.py    | 11 +++++--
 2 files changed, 18 insertions(+), 24 deletions(-)

diff --git a/src/sparseml/pytorch/utils/helpers.py b/src/sparseml/pytorch/utils/helpers.py
index 4b495afe497..e9c603355de 100644
--- a/src/sparseml/pytorch/utils/helpers.py
+++ b/src/sparseml/pytorch/utils/helpers.py
@@ -20,7 +20,6 @@
 import os
 import random
 import re
-import warnings
 from collections import OrderedDict, namedtuple
 from contextlib import contextmanager
 from copy import deepcopy
@@ -30,7 +29,7 @@
 import torch
 from packaging import version
 from torch import Tensor
-from torch.nn import Linear, Module, Parameter
+from torch.nn import Embedding, Linear, Module, Parameter
 from torch.nn.modules.conv import Conv2d, Conv3d, _ConvNd
 from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader
@@ -780,6 +779,7 @@ def get_prunable_layers(module: Module) -> List[Tuple[str, Module]]:
         for (name, mod) in module.named_modules()
         if (
             isinstance(mod, Linear)
+            or isinstance(mod, Embedding)
             or isinstance(mod, _ConvNd)
             or (QATLinear and isinstance(mod, QATLinear))
             or (QATConv2d and isinstance(mod, QATConv2d))
@@ -793,7 +793,7 @@ def get_quantizable_layers(module: Module) -> List[Tuple[str, Module]]:
     """
     :param module: the module to get the quantizable layers from
     :return: a list containing the names and modules of the quantizable layers
-        (Linear, Conv2d, Conv3d)
+        (Embedding, Linear, Conv2d, Conv3d)
     """
     if QATLinear is None:
         raise ImportError(
@@ -806,6 +806,7 @@ def get_quantizable_layers(module: Module) -> List[Tuple[str, Module]]:
         for (name, mod) in module.named_modules()
         if (
             isinstance(mod, Linear)
+            or isinstance(mod, Embedding)
             or isinstance(mod, Conv2d)
             or (QATConv3d and isinstance(mod, Conv3d))
         )
@@ -816,29 +817,15 @@ def get_quantized_layers(module: Module) -> List[Tuple[str, Module]]:
     """
     :param module: the module to get the quantized layers from
     :return: a list containing the names and modules of the quantized layers
-        (Linear, Conv2d, Conv3d)
+        (Embedding, Linear, Conv2d, Conv3d)
     """
-    if QATLinear is None:
-        raise ImportError(
-            "PyTorch version is not setup for Quantization. "
-            "Please install a QAT compatible version of PyTorch"
-        )
 
     quantized_layers = []
     for (name, mod) in module.named_modules():
-        if (
-            (QATLinear and isinstance(mod, QATLinear))
-            or (QATConv2d and isinstance(mod, QATConv2d))
-            or (QATConv3d and isinstance(mod, QATConv3d))
-        ):
-            quantized_layers.append((name, mod))
-
-        elif isinstance(mod, Conv3d) and not QATConv3d:
-            warnings.warn(
-                "Pytorch version is not setup for Conv3D Quantization. "
-                "Quantization of Conv3D layers will be skipped",
-                UserWarning,
-            )
+        if hasattr(mod, "quantization_scheme"):
+            weight_scheme = getattr(mod.quantization_scheme, "weights", None)
+            if weight_scheme is not None and hasattr(mod, "weight"):
+                quantized_layers.append((name, mod))
 
     return quantized_layers
 
diff --git a/src/sparseml/transformers/finetune/session_mixin.py b/src/sparseml/transformers/finetune/session_mixin.py
index 149b59be7cd..7436261980e 100644
--- a/src/sparseml/transformers/finetune/session_mixin.py
+++ b/src/sparseml/transformers/finetune/session_mixin.py
@@ -500,15 +500,22 @@ def log_model_sparsification(self):
             f"Sparsification info for {type(self.model).__name__}: "
             f"{sparsification_info.params_total} total params. "
         )
+        sparsity_percent_formatted = "{:.2f}".format(
+            sparsification_info.params_prunable_sparse_percent
+        )
         _LOGGER.info(
             f"There are {sparsification_info.params_prunable_total} prunable "
-            f"params which have {sparsification_info.params_prunable_sparse_percent} "
+            f"params which have {sparsity_percent_formatted}% "
             "avg sparsity."
         )
+
+        quant_percent_formatted = "{:.2f}".format(
+            sparsification_info.params_quantized_percent
+        )
         _LOGGER.info(
             f"There are {sparsification_info.params_quantizable} quantizable "
             f"params, with a quantization percentage of "
-            f"{sparsification_info.params_quantized_percent}."
+            f"{quant_percent_formatted}%."
         )
 
     def _prepare_model_for_fsdp(self):

From e255b17765add46053a2669086cbc95b3fff406c Mon Sep 17 00:00:00 2001
From: Sara Adkins
Date: Tue, 11 Jun 2024 15:04:28 -0400
Subject: [PATCH 7/7] Fix for Sparsity Persist (#2323)

* fix sparsity persist

* helper moved to compressed-tensors
---
 .../quantization/gptq/utils/gptq_wrapper.py   | 43 +++++++++----------
 .../obcq/test_mask_structure_preservation.py  | 24 +----------
 2 files changed, 21 insertions(+), 46 deletions(-)

diff --git a/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py b/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py
index 73321c0d0aa..ded28b4123b 100644
--- a/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py
+++ b/src/sparseml/modifiers/quantization/gptq/utils/gptq_wrapper.py
@@ -103,6 +103,14 @@ def fasterprune(
             W = W.t()
         W = W.float()
 
+        sparsity = tensor_sparsity(W)
+        preserve_zeros = sparsity >= SPARSITY_THRESHOLD
+        W_nz_mask = (
+            (~torch.isclose(W, torch.zeros(1, device=W.device).float())).float()
+            if preserve_zeros
+            else None
+        )
+
         tick = time.time()
 
         dead = torch.diag(self.H) == 0
@@ -119,17 +127,6 @@ def fasterprune(
         self.H = torch.linalg.cholesky(self.H, upper=True)
         Hinv = self.H
 
-        sparsity = tensor_sparsity(W)
-        mask = (
-            torch.where(
-                W == 0,
-                torch.tensor(1, dtype=torch.bool),
-                torch.tensor(0, dtype=torch.bool),
-            )
-            if sparsity >= SPARSITY_THRESHOLD
-            else None
-        )
-
         # See section 3.4 of https://arxiv.org/abs/2203.07259
         for i1 in range(0, self.columns, blocksize):
             i2 = min(i1 + blocksize, self.columns)
@@ -141,21 +138,13 @@ def fasterprune(
             Losses1 = torch.zeros_like(W1)
             Hinv1 = Hinv[i1:i2, i1:i2]
 
-            if sparsity >= SPARSITY_THRESHOLD:
-                tmp = (
-                    (~mask[:, i1:i2])
-                    * W1**2
-                    / (torch.diag(Hinv1).reshape((1, -1))) ** 2
-                )
-                thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)]
-                mask1 = tmp <= thresh
+            if preserve_zeros:
+                W1_nz_mask = W_nz_mask[:, i1:i2]
 
             for i in range(count):
                 w = W1[:, i]
                 d = Hinv1[i, i]
                 q = w.clone()
-                if sparsity >= SPARSITY_THRESHOLD:
-                    q[mask1[:, i]] = 0
 
                 if hasattr(self.layer, "weight_fake_quant"):
                     scale = self.layer.weight_fake_quant.scale
@@ -216,13 +205,21 @@ def fasterprune(
                 Losses1[:, i] = (w - q) ** 2 / d**2
 
                 err1 = (w - q) / d
-                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
+                w1_err = err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
+                if preserve_zeros:
+                    W1[:, i:] -= w1_err * W1_nz_mask[:, i:]
+                else:
+                    W1[:, i:] -= w1_err
                 Err1[:, i] = err1
 
             W[:, i1:i2] = Q1
             Losses += torch.sum(Losses1, 1) / 2
 
-            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
+            w_err = Err1.matmul(Hinv[i1:i2, i2:])
+            if preserve_zeros:
+                W[:, i2:] -= w_err * W_nz_mask[:, i2:]
+            else:
+                W[:, i2:] -= w_err
 
         _LOGGER.info("time %.2f" % (time.time() - tick))
         _LOGGER.info("error %.2f" % torch.sum(Losses).item())
diff --git a/tests/sparseml/transformers/obcq/test_mask_structure_preservation.py b/tests/sparseml/transformers/obcq/test_mask_structure_preservation.py
index a068c391431..eca6f5d2379 100644
--- a/tests/sparseml/transformers/obcq/test_mask_structure_preservation.py
+++ b/tests/sparseml/transformers/obcq/test_mask_structure_preservation.py
@@ -19,6 +19,7 @@
 import pytest
 
 import sparseml
+from compressed_tensors.compressors.utils import tensor_follows_mask_structure
 from parameterized import parameterized_class
 from tests.testing_utils import parse_params, requires_torch
 
@@ -28,29 +29,6 @@
 )
 
 
-def tensor_follows_mask_structure(tensor, mask: str = "2:4"):
-    """
-    :param tensor: tensor to check
-    :param mask: mask structure to check for, in the format "n:m"
-    :return: True if the tensor follows the mask structure, False otherwise.
-        Note, some weights can incidentally be zero, so we check for
-        atleast n zeros in each chunk of size m
-    """
-    import torch
-
-    n, m = tuple(map(int, mask.split(":")))
-    # Reshape the tensor into chunks of size m
-    tensor = tensor.view(-1, m)
-
-    # Count the number of zeros in each chunk
-    zero_counts = (tensor == 0).sum(dim=1)
-
-    # Check if the number of zeros in each chunk atleast n
-    # Greater than sign is needed as some weights can incidentally
-    # be zero
-    return torch.all(zero_counts >= n)
-
-
 @requires_torch
 @pytest.mark.integration
 @parameterized_class(parse_params(MASK_STRUCTURE_CONFIGS_DIRECTORY))
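Stripped of the GPTQ machinery, the idea behind the PATCH 7/7 change is that positions which are already zero should receive no error-propagation update, so previously pruned weights stay pruned after quantization. A self-contained toy illustration of that masking (not the actual `fasterprune` code):

```python
import torch

W = torch.tensor([[0.0, 1.5, 0.0, -2.0]])  # weights with an existing sparsity pattern
update = torch.full_like(W, 0.1)           # stand-in for the err1 / Hinv error term

# mask of currently non-zero positions, mirroring W_nz_mask above
W_nz_mask = (~torch.isclose(W, torch.zeros(1))).float()

preserve_zeros = True
if preserve_zeros:
    W -= update * W_nz_mask                # zeroed weights stay exactly zero
else:
    W -= update

print(W)  # tensor([[ 0.0000,  1.4000,  0.0000, -2.1000]])
```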