mask_structure preservation test (#2284)
* test
* Preserve weight sparsity if greater than threshold
* Add argument to preserve sparsity mask in SparseGPT
* Fix case when mask is None
* Add test to check mask_structure: the initial mask structure should be preserved between consecutive runs
* Update tensor_follows_mask_structure to check for at least n zeros

---------

Co-authored-by: Sara Adkins <[email protected]>
1 parent 446555f · commit 440661b
Showing 4 changed files with 211 additions and 0 deletions.
.../obcq/obcq_configs/consec_runs/mask_structure/tiny_llama_mask_structure_preservation.yaml (9 additions, 0 deletions)
@@ -0,0 +1,9 @@
cadence: "commit"
test_type: "sanity"
model: "Xenova/llama2.c-stories15M"
dataset: open_platypus
initial_pruning_only_recipe: "tests/sparseml/transformers/obcq/recipes/sparse_with_mask_structure.yaml"
initial_sparsity: 0.5
recipe_mask_structure: "2:4"
subsequent_prune_and_quant_recipe: "tests/sparseml/transformers/obcq/recipes/additional_sparsity_with_quant.yaml"
final_sparsity: 0.7
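As a quick illustration of the "2:4" notation used by recipe_mask_structure above (not part of the diff): a tensor follows an "n:m" structure when every contiguous group of m weights contains at least n zeros. A minimal PyTorch check on a hand-built example row:

import torch

# a 1x8 row that satisfies "2:4": at least 2 zeros in every contiguous group of 4
row = torch.tensor([0.0, 1.3, 0.0, -0.7, 0.4, 0.0, 0.0, 2.1])
print((row.view(-1, 4) == 0).sum(dim=1))  # tensor([2, 2]) -> at least n=2 zeros per group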
tests/sparseml/transformers/obcq/recipes/additional_sparsity_with_quant.yaml (43 additions, 0 deletions)
@@ -0,0 +1,43 @@
test_stage:
  obcq_modifiers:
    SmoothQuantModifier:
      smoothing_strength: 0.5
      mappings: [
        [["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*input_layernorm"],
        [["re:.*gate_proj", "re:.*up_proj"], "re:.*post_attention_layernorm"]
      ]
    QuantizationModifier:
      ignore:
        - LlamaRotaryEmbedding
        - LlamaRMSNorm
        - SiLU
        - model.layers.0.mlp.down_proj
        - model.layers.1.mlp.down_proj
        - model.layers.2.mlp.down_proj
        - model.layers.3.mlp.down_proj
        - model.layers.4.mlp.down_proj
        - model.layers.5.mlp.down_proj
      post_oneshot_calibration: True
      scheme_overrides:
        Embedding:
          input_activations: null
          weights:
            num_bits: 8
            symmetric: False
    SparseGPTModifier:
      sparsity: 0.7
      block_size: 128
      sequential_update: False
      percdamp: 0.01
      mask_structure: "0:0"
      targets: [
        "model.layers.0",
      ]
      preserve_sparsity_mask: True
    GPTQModifier:
      sequential_update: False
      dampening_frac: 0.01
      targets: [
        "model.layers.0",
      ]
      block_size: 128
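The preserve_sparsity_mask: True flag on SparseGPTModifier above is what the new test exercises: zeros produced by the first pruning run should survive the second run, with additional zeros only added on top. The sketch below (not part of the diff) illustrates that invariant using simple magnitude pruning as a stand-in for SparseGPT's actual weight-selection criterion; magnitude_prune is a hypothetical helper written only for this illustration.

import torch

def magnitude_prune(weight: torch.Tensor, sparsity: float, frozen: torch.Tensor) -> torch.Tensor:
    """Zero the smallest-magnitude weights until `sparsity` is reached,
    never un-zeroing anything already marked in `frozen`."""
    out = weight.clone()
    out[frozen] = 0.0                                     # keep existing zeros
    n_extra = int(sparsity * out.numel()) - int(frozen.sum())
    if n_extra > 0:
        candidates = out.abs().masked_fill(frozen, float("inf")).flatten()
        idx = candidates.topk(n_extra, largest=False).indices
        out.view(-1)[idx] = 0.0
    return out

torch.manual_seed(0)
weight = torch.randn(4, 8)
first = magnitude_prune(weight, 0.5, frozen=torch.zeros_like(weight, dtype=torch.bool))
initial_mask = first == 0
second = magnitude_prune(first, 0.7, frozen=initial_mask)
final_mask = second == 0
assert torch.all(initial_mask <= final_mask)              # initial zeros are a subset of final zeros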
tests/sparseml/transformers/obcq/recipes/sparse_with_mask_structure.yaml (11 additions, 0 deletions)
@@ -0,0 +1,11 @@
test_stage:
  obcq_modifiers:
    SparseGPTModifier:
      sparsity: 0.5
      block_size: 128
      sequential_update: False
      percdamp: 0.01
      mask_structure: "2:4"
      targets: [
        "model.layers.0",
      ]
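For context (not part of the diff), a "2:4" pattern like the one mask_structure requests above is typically produced by keeping two weights in every contiguous group of four and zeroing the rest. The sketch below uses magnitude as the selection criterion purely for illustration; SparseGPTModifier chooses the pruned weights with its own error-based criterion, and apply_2_4_mask is a hypothetical helper, not part of sparseml.

import torch

def apply_2_4_mask(weight: torch.Tensor) -> torch.Tensor:
    # keep the 2 largest-magnitude entries in every contiguous group of 4, zero the rest
    groups = weight.reshape(-1, 4)
    keep = groups.abs().topk(k=2, dim=1).indices
    mask = torch.zeros_like(groups, dtype=torch.bool).scatter_(1, keep, True)
    return (groups * mask).reshape(weight.shape)

pruned = apply_2_4_mask(torch.randn(8, 8))
assert torch.all((pruned.reshape(-1, 4) == 0).sum(dim=1) >= 2)  # every group of 4 has >= 2 zeros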
tests/sparseml/transformers/obcq/test_mask_structure_preservation.py (148 additions, 0 deletions)
@@ -0,0 +1,148 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import unittest
from pathlib import Path

import pytest

import sparseml
from parameterized import parameterized_class
from tests.testing_utils import parse_params, requires_torch


MASK_STRUCTURE_CONFIGS_DIRECTORY = (
    "tests/sparseml/transformers/obcq/obcq_configs/consec_runs/mask_structure"
)


def tensor_follows_mask_structure(tensor, mask: str = "2:4"):
    """
    :param tensor: tensor to check
    :param mask: mask structure to check for, in the format "n:m"
    :return: True if the tensor follows the mask structure, False otherwise.
        Note, some weights can incidentally be zero, so we check for
        at least n zeros in each chunk of size m
    """
    import torch

    n, m = tuple(map(int, mask.split(":")))
    # Reshape the tensor into chunks of size m
    tensor = tensor.view(-1, m)

    # Count the number of zeros in each chunk
    zero_counts = (tensor == 0).sum(dim=1)

    # Check that the number of zeros in each chunk is at least n
    # (>= rather than == because some weights can incidentally be zero)
    return torch.all(zero_counts >= n)


@requires_torch
@pytest.mark.integration
@parameterized_class(parse_params(MASK_STRUCTURE_CONFIGS_DIRECTORY))
class TestMaskStructurePreserved(unittest.TestCase):
    """
    Tests that the mask structure is preserved across multiple runs of oneshot:
    the initial model is pruned using a mask_structure, and then the pruned model
    is further pruned and quantized.
    """

    model = None
    initial_pruning_only_recipe = None
    initial_sparsity = None
    recipe_mask_structure = None
    dataset = None
    subsequent_prune_and_quant_recipe = None
    final_sparsity = None

    def setUp(self) -> None:
        import torch

        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.output = "./oneshot_output"
        self.output_first = Path(self.output) / "test_1"
        self.output_second = Path(self.output) / "test_2"

    def test_mask_structure_preserved(self):
        """
        Checks that the mask structure is preserved across runs of oneshot,
        between the initial pruning and the subsequent pruning + quantization
        """
        import math

        import torch

        from sparseml.pytorch.model_load.helpers import get_session_model
        from sparseml.pytorch.utils.helpers import tensor_sparsity
        from sparseml.transformers import oneshot
        from sparseml.utils.pytorch import qat_active

        tolerance = 1e-3
        num_calibration_samples = 16

        oneshot(
            model=self.model,
            dataset=self.dataset,
            num_calibration_samples=num_calibration_samples,
            recipe=self.initial_pruning_only_recipe,
            output_dir=self.output_first,
            oneshot_device=self.device,
            clear_sparse_session=False,
        )
        first_tiny_model = get_session_model()
        targetted_layer = first_tiny_model.model.layers[0].self_attn.k_proj
        target_layer_sparsity = tensor_sparsity(targetted_layer.weight)
        initial_mask = first_tiny_model.model.layers[0].self_attn.k_proj.weight == 0

        # sparsity is as expected, i.e. close to self.initial_sparsity
        assert math.isclose(
            target_layer_sparsity.item(), self.initial_sparsity, rel_tol=tolerance
        )
        # mask structure is as expected, i.e. same as self.recipe_mask_structure
        assert tensor_follows_mask_structure(initial_mask, self.recipe_mask_structure)

        sparseml.reset_session()

        oneshot(
            model=self.output_first,
            dataset=self.dataset,
            num_calibration_samples=num_calibration_samples,
            recipe=self.subsequent_prune_and_quant_recipe,
            output_dir=self.output_second,
            oneshot_device=self.device,
            clear_sparse_session=False,
        )

        second_tiny_model = get_session_model()

        # model is loaded
        assert second_tiny_model is not None

        targetted_layer = second_tiny_model.model.layers[0].self_attn.k_proj.module
        target_layer_sparsity = tensor_sparsity(targetted_layer.weight)

        # sparsity is as expected, i.e. close to self.final_sparsity
        assert math.isclose(
            target_layer_sparsity.item(), self.final_sparsity, rel_tol=tolerance
        )
        # qat should be active, since the second recipe has quantization
        assert qat_active(second_tiny_model)

        # original mask structure is preserved; additional zeros are
        # added on top of the initial mask
        final_mask = targetted_layer.weight == 0
        assert torch.all(initial_mask <= final_mask)
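A quick standalone usage of the tensor_follows_mask_structure helper defined above (not part of the diff), assuming the function is in scope, showing the "at least n zeros per chunk of m" semantics; the second group of four below has three zeros and still satisfies "2:4":

import torch

example = torch.tensor([0.0, 0.0, 1.2, -0.3, 0.0, 0.0, 0.0, 0.5])
assert tensor_follows_mask_structure(example, mask="2:4")          # groups have 2 and 3 zeros
assert not tensor_follows_mask_structure(torch.ones(8), mask="2:4")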