diff --git a/tests/sparseml/transformers/finetune/test_finetune.py b/tests/sparseml/transformers/finetune/test_finetune.py index 823ade908a4..9342e0c499b 100644 --- a/tests/sparseml/transformers/finetune/test_finetune.py +++ b/tests/sparseml/transformers/finetune/test_finetune.py @@ -98,7 +98,7 @@ def test_oneshot_and_finetune_with_tokenizer(tmp_path: Path): def test_oneshot_then_finetune(tmp_path: Path): - recipe_str = "tests/sparseml/transformers/obcq/test_tiny2.yaml" + recipe_str = "tests/sparseml/transformers/obcq/recipes/test_tiny2.yaml" model = "Xenova/llama2.c-stories15M" device = "cuda:0" if not torch.cuda.is_available(): diff --git a/tests/sparseml/transformers/finetune/test_finetune_helpers.py b/tests/sparseml/transformers/finetune/test_finetune_helpers.py index 3fde66276d9..7a20644f715 100644 --- a/tests/sparseml/transformers/finetune/test_finetune_helpers.py +++ b/tests/sparseml/transformers/finetune/test_finetune_helpers.py @@ -26,7 +26,7 @@ def test_apply_recipe_structure(): model = AutoModelForCausalLM.from_pretrained(model_path) assert not qat_active(model) - recipe_with_quant = "tests/sparseml/transformers/obcq/quant_and_sparse.yaml" + recipe_with_quant = "tests/sparseml/transformers/obcq/recipes/quant_and_sparse.yaml" apply_recipe_structure_to_model(model, recipe_with_quant, model_path) assert qat_active(model) diff --git a/tests/sparseml/transformers/obcq/obcq_configs/completion/gpu/llama_7b_quant.yaml b/tests/sparseml/transformers/obcq/obcq_configs/completion/gpu/llama_7b_quant.yaml new file mode 100644 index 00000000000..88a9535e7d9 --- /dev/null +++ b/tests/sparseml/transformers/obcq/obcq_configs/completion/gpu/llama_7b_quant.yaml @@ -0,0 +1,8 @@ +cadence: "nightly" +test_type: "regression" +model: "zoo:llama2-7b-llama2_pretrain-base" +dataset: open_platypus +recipe: "tests/sparseml/transformers/obcq/recipes/quant.yaml" +device: "cuda:1" +num_samples: 512 +perplexity: 20 \ No newline at end of file diff --git a/tests/sparseml/transformers/obcq/obcq_configs/completion/gpu/llama_7b_quant_and_sparse.yaml b/tests/sparseml/transformers/obcq/obcq_configs/completion/gpu/llama_7b_quant_and_sparse.yaml new file mode 100644 index 00000000000..3f430569b9f --- /dev/null +++ b/tests/sparseml/transformers/obcq/obcq_configs/completion/gpu/llama_7b_quant_and_sparse.yaml @@ -0,0 +1,8 @@ +cadence: "nightly" +test_type: "regression" +model: "zoo:llama2-7b-llama2_pretrain-base" +dataset: open_platypus +recipe: "tests/sparseml/transformers/obcq/recipes/quant_and_sparse.yaml" +device: "cuda:0" +num_samples: 512 +perplexity: 20 \ No newline at end of file diff --git a/tests/sparseml/transformers/obcq/obcq_configs/completion/tiny_llama_quant.yaml b/tests/sparseml/transformers/obcq/obcq_configs/completion/tiny_llama_quant.yaml new file mode 100644 index 00000000000..1f327e17b96 --- /dev/null +++ b/tests/sparseml/transformers/obcq/obcq_configs/completion/tiny_llama_quant.yaml @@ -0,0 +1,7 @@ +cadence: "commit" +test_type: "sanity" +model: "Xenova/llama2.c-stories15M" +dataset: open_platypus +recipe: "tests/sparseml/transformers/obcq/recipes/quant.yaml" +num_samples: 32 +perplexity: 5000 \ No newline at end of file diff --git a/tests/sparseml/transformers/obcq/obcq_configs/completion/tiny_llama_quant_and_sparse.yaml b/tests/sparseml/transformers/obcq/obcq_configs/completion/tiny_llama_quant_and_sparse.yaml new file mode 100644 index 00000000000..982679355ef --- /dev/null +++ b/tests/sparseml/transformers/obcq/obcq_configs/completion/tiny_llama_quant_and_sparse.yaml @@ -0,0 +1,7 @@ +cadence: "commit" +test_type: "sanity" +model: "Xenova/llama2.c-stories15M" +dataset: open_platypus +recipe: "tests/sparseml/transformers/obcq/recipes/quant_and_sparse.yaml" +num_samples: 32 +perplexity: 5000 \ No newline at end of file diff --git a/tests/sparseml/transformers/obcq/obcq_configs/consec_runs/gpu/llama_consec_runs.yaml b/tests/sparseml/transformers/obcq/obcq_configs/consec_runs/gpu/llama_consec_runs.yaml new file mode 100644 index 00000000000..5a4fe232844 --- /dev/null +++ b/tests/sparseml/transformers/obcq/obcq_configs/consec_runs/gpu/llama_consec_runs.yaml @@ -0,0 +1,7 @@ +cadence: "nightly" +test_type: "regression" +model: "zoo:llama2-7b-llama2_pretrain-base" +dataset: open_platypus +first_recipe: "tests/sparseml/transformers/obcq/recipes/quant_and_sparse.yaml" +second_recipe: "tests/sparseml/transformers/obcq/recipes/additional_sparsity.yaml" +device: "auto" \ No newline at end of file diff --git a/tests/sparseml/transformers/obcq/obcq_configs/consec_runs/tiny_llama_consec_runs.yaml b/tests/sparseml/transformers/obcq/obcq_configs/consec_runs/tiny_llama_consec_runs.yaml new file mode 100644 index 00000000000..e1073aae904 --- /dev/null +++ b/tests/sparseml/transformers/obcq/obcq_configs/consec_runs/tiny_llama_consec_runs.yaml @@ -0,0 +1,6 @@ +cadence: "commit" +test_type: "sanity" +model: "Xenova/llama2.c-stories15M" +dataset: open_platypus +first_recipe: "tests/sparseml/transformers/obcq/recipes/quant_and_sparse.yaml" +second_recipe: "tests/sparseml/transformers/obcq/recipes/additional_sparsity.yaml" \ No newline at end of file diff --git a/tests/sparseml/transformers/obcq/obcq_configs/repeat_quants/tiny_llama_repeat_quant.yaml b/tests/sparseml/transformers/obcq/obcq_configs/repeat_quants/tiny_llama_repeat_quant.yaml new file mode 100644 index 00000000000..5bef2cae22d --- /dev/null +++ b/tests/sparseml/transformers/obcq/obcq_configs/repeat_quants/tiny_llama_repeat_quant.yaml @@ -0,0 +1,25 @@ +cadence: "commit" +test_type: "sanity" +model: "Xenova/llama2.c-stories15M" +dataset: open_platypus +first_recipe: | + first_stage: + quant_modifiers: + QuantizationModifier: + ignore: + - LlamaRotaryEmbedding + - LlamaRMSNorm + - SiLU + scheme_overrides: + Embedding: + input_activations: null + +second_recipe: | + second_stage: + quant_modifiers: + QuantizationModifier: + ignore: + - LlamaRotaryEmbedding + - LlamaRMSNorm + - SiLU + - Embedding \ No newline at end of file diff --git a/tests/sparseml/transformers/obcq/obcq_configs/separate_quants/tiny_llama_separate_quant.yaml b/tests/sparseml/transformers/obcq/obcq_configs/separate_quants/tiny_llama_separate_quant.yaml new file mode 100644 index 00000000000..1b7cab983f4 --- /dev/null +++ b/tests/sparseml/transformers/obcq/obcq_configs/separate_quants/tiny_llama_separate_quant.yaml @@ -0,0 +1,32 @@ +cadence: "commit" +test_type: "sanity" +model: "Xenova/llama2.c-stories15M" +dataset: open_platypus +first_recipe: | + first_stage: + quant_modifiers: + QuantizationModifier: + ignore: + - LlamaRotaryEmbedding + - LlamaRMSNorm + - SiLU + - Linear + scheme_overrides: + Embedding: + input_activations: null +second_recipe: | + second_stage: + quant_modifiers: + QuantizationModifier: + ignore: + - LlamaRotaryEmbedding + - LlamaRMSNorm + - SiLU + - Embedding + - MatMulLeftInput_QK + - MatMulRightInput_QK + - MatMulOutput_QK + - MatMulLeftInput_PV + - MatMulRightInput_PV + - MatMulOutput_PV + - QuantizableMatMul \ No newline at end of file diff --git a/tests/sparseml/transformers/obcq/obcq_configs/sparse/gpu/llama_7b_sparse.yaml b/tests/sparseml/transformers/obcq/obcq_configs/sparse/gpu/llama_7b_sparse.yaml new file mode 100644 index 00000000000..1316258fbda --- /dev/null +++ b/tests/sparseml/transformers/obcq/obcq_configs/sparse/gpu/llama_7b_sparse.yaml @@ -0,0 +1,7 @@ +cadence: "nightly" +test_type: "regression" +model: "zoo:llama2-7b-llama2_pretrain-base" +dataset: open_platypus +recipe: "tests/sparseml/transformers/obcq/recipes/sparse.yaml" +sparsity: 0.3 +device: "cuda:0" \ No newline at end of file diff --git a/tests/sparseml/transformers/obcq/obcq_configs/sparse/gpu/llama_7b_sparse_auto.yaml b/tests/sparseml/transformers/obcq/obcq_configs/sparse/gpu/llama_7b_sparse_auto.yaml new file mode 100644 index 00000000000..f019dbc4212 --- /dev/null +++ b/tests/sparseml/transformers/obcq/obcq_configs/sparse/gpu/llama_7b_sparse_auto.yaml @@ -0,0 +1,7 @@ +cadence: "nightly" +test_type: "regression" +model: "zoo:llama2-7b-llama2_pretrain-base" +dataset: open_platypus +recipe: "tests/sparseml/transformers/obcq/recipes/sparse.yaml" +sparsity: 0.3 +device: "auto" \ No newline at end of file diff --git a/tests/sparseml/transformers/obcq/obcq_configs/sparse/tiny_llama_sparse.yaml b/tests/sparseml/transformers/obcq/obcq_configs/sparse/tiny_llama_sparse.yaml new file mode 100644 index 00000000000..ffc16506405 --- /dev/null +++ b/tests/sparseml/transformers/obcq/obcq_configs/sparse/tiny_llama_sparse.yaml @@ -0,0 +1,6 @@ +cadence: "commit" +test_type: "sanity" +model: "Xenova/llama2.c-stories15M" +dataset: open_platypus +recipe: "tests/sparseml/transformers/obcq/recipes/sparse.yaml" +sparsity: 0.3 \ No newline at end of file diff --git a/tests/sparseml/transformers/obcq/additional_sparsity.yaml b/tests/sparseml/transformers/obcq/recipes/additional_sparsity.yaml similarity index 100% rename from tests/sparseml/transformers/obcq/additional_sparsity.yaml rename to tests/sparseml/transformers/obcq/recipes/additional_sparsity.yaml diff --git a/tests/sparseml/transformers/obcq/quant.yaml b/tests/sparseml/transformers/obcq/recipes/quant.yaml similarity index 100% rename from tests/sparseml/transformers/obcq/quant.yaml rename to tests/sparseml/transformers/obcq/recipes/quant.yaml diff --git a/tests/sparseml/transformers/obcq/quant_and_sparse.yaml b/tests/sparseml/transformers/obcq/recipes/quant_and_sparse.yaml similarity index 100% rename from tests/sparseml/transformers/obcq/quant_and_sparse.yaml rename to tests/sparseml/transformers/obcq/recipes/quant_and_sparse.yaml diff --git a/tests/sparseml/transformers/obcq/sparse.yaml b/tests/sparseml/transformers/obcq/recipes/sparse.yaml similarity index 100% rename from tests/sparseml/transformers/obcq/sparse.yaml rename to tests/sparseml/transformers/obcq/recipes/sparse.yaml diff --git a/tests/sparseml/transformers/obcq/test_tiny2.yaml b/tests/sparseml/transformers/obcq/recipes/test_tiny2.yaml similarity index 100% rename from tests/sparseml/transformers/obcq/test_tiny2.yaml rename to tests/sparseml/transformers/obcq/recipes/test_tiny2.yaml diff --git a/tests/sparseml/transformers/obcq/test_consecutive_runs.py b/tests/sparseml/transformers/obcq/test_consecutive_runs.py new file mode 100644 index 00000000000..7bcfc8b7efe --- /dev/null +++ b/tests/sparseml/transformers/obcq/test_consecutive_runs.py @@ -0,0 +1,145 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import shutil +import unittest +from pathlib import Path + +import pytest +import yaml + +from parameterized import parameterized_class +from tests.testing_utils import parse_params, requires_gpu, requires_torch + + +CONFIGS_DIRECTORY = "tests/sparseml/transformers/obcq/obcq_configs/consec_runs" +GPU_CONFIGS_DIRECTORY = "tests/sparseml/transformers/obcq/obcq_configs/consec_runs/gpu" + + +class TestConsecutiveRuns(unittest.TestCase): + def _test_consecutive_runs( + self, tolerance: float, num_calibration_samples: int = 16 + ): + import math + + import sparseml.core.session as session_manager + from sparseml.pytorch.model_load.helpers import get_session_model + from sparseml.pytorch.utils.helpers import tensor_sparsity + from sparseml.transformers import oneshot + from sparseml.utils.pytorch import qat_active + + # test recipe with 50% sparsity, quantization and smoothquant + oneshot( + model=self.model, + dataset=self.dataset, + num_calibration_samples=num_calibration_samples, + recipe=self.first_recipe, + output_dir=self.output_first, + oneshot_device=self.device, + clear_sparse_session=False, + ) + first_tiny_model = get_session_model() + layer_0_sparse = tensor_sparsity( + first_tiny_model.model.layers[0].self_attn.k_proj.module.weight + ) + assert math.isclose(layer_0_sparse.item(), 0.5, rel_tol=tolerance) + assert qat_active(first_tiny_model) + + session = session_manager.active_session() + session_recipe = session.lifecycle.recipe_container.compiled_recipe + stages = [stage.group for stage in session_recipe.stages] + self.assertEqual(len(stages), 1) + session.reset() + + # reload saved model and up sparsity to 0.7 + oneshot( + model=self.output_first, + dataset=self.dataset, + num_calibration_samples=num_calibration_samples, + recipe=self.second_recipe, + output_dir=self.output_second, + oneshot_device=self.device, + clear_sparse_session=False, + ) + + second_tiny_model = get_session_model() + layer_0_sparse = tensor_sparsity( + second_tiny_model.model.layers[0].self_attn.k_proj.module.weight + ) + assert math.isclose(layer_0_sparse.item(), 0.7, rel_tol=tolerance) + assert qat_active(second_tiny_model) + + session = session_manager.active_session() + session_recipe = session.lifecycle.recipe_container.compiled_recipe + stages = [stage.group for stage in session_recipe.stages] + self.assertEqual(len(stages), 2) + + recipe_path = self.output_second / "recipe.yaml" + recipe_data = yaml.safe_load(recipe_path.read_text()) + stage_keys = recipe_data.keys() + self.assertEqual(len(stage_keys), 2) + self.assertIn("test_stage_0", stage_keys) + self.assertIn("test_stage_1", stage_keys) + + def tearDown(self): + shutil.rmtree(self.output) + + +@requires_torch +@pytest.mark.integration +@parameterized_class(parse_params(CONFIGS_DIRECTORY)) +class TestConsecutiveRunsSmall(TestConsecutiveRuns): + model = None + first_recipe = None + second_recipe = None + dataset = None + + def setUp(self): + import torch + + self.device = "cuda:0" if torch.cuda.is_available() else "cpu" + self.output = "./oneshot_output" + self.output_first = Path(self.output) / "test_1" + self.output_second = Path(self.output) / "test_2" + + def test_consecutive_runs_small(self): + self._test_consecutive_runs(tolerance=1e-3) + + +@requires_gpu +@requires_torch +@pytest.mark.integration +@parameterized_class(parse_params(GPU_CONFIGS_DIRECTORY)) +class TestConsecutiveRunsGPU(TestConsecutiveRuns): + # Will be populated using the config files + model = None + first_recipe = None + second_recipe = None + dataset = None + device = None + + def setUp(self): + from sparseml.transformers import SparseAutoModelForCausalLM + + if "zoo:" in self.model: + self.model = SparseAutoModelForCausalLM.from_pretrained( + self.model, device_map=self.device + ) + + self.output = "./oneshot_output" + self.output_first = Path(self.output) / "test_1" + self.output_second = Path(self.output) / "test_2" + + def test_consecutive_runs_gpu(self): + self._test_consecutive_runs(tolerance=1e-0, num_calibration_samples=16) diff --git a/tests/sparseml/transformers/obcq/test_obcq.py b/tests/sparseml/transformers/obcq/test_obcq.py deleted file mode 100644 index 6f0f0108db2..00000000000 --- a/tests/sparseml/transformers/obcq/test_obcq.py +++ /dev/null @@ -1,216 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math - -import pytest -import torch - -from sparseml.core import ModifiableModel -from sparseml.core.framework import Framework -from sparseml.core.state import State -from sparseml.modifiers.obcq import SparseGPTModifier -from sparseml.modifiers.obcq.pytorch import SparseGPTModifierPyTorch -from sparseml.pytorch.model_load.helpers import get_session_model -from sparseml.pytorch.utils.helpers import tensor_sparsity -from sparseml.transformers import SparseAutoModelForCausalLM, oneshot -from sparseml.transformers.sparsification.modification.modifying_llama import ( - LlamaAttentionWithQuantizableMatmuls, -) - - -@pytest.mark.parametrize( - "recipe_file_path", - [ - "tests/sparseml/transformers/obcq/sparse.yaml", - "tests/sparseml/transformers/obcq/quant.yaml", - "tests/sparseml/transformers/obcq/quant_and_sparse.yaml", - ], -) -def test_obcq_tinystories(tmp_path, recipe_file_path): - tiny_model_path = "Xenova/llama2.c-stories15M" - device = "cuda:0" if torch.cuda.is_available() else "cpu" - model = SparseAutoModelForCausalLM.from_pretrained( - tiny_model_path, device_map=device - ) - - oneshot( - model=model, - dataset="open_platypus", - oneshot_device=device, - recipe=recipe_file_path, - max_seq_length=128, - num_calibration_samples=64, - pad_to_max_length=False, - output_dir=tmp_path / "temp_output", - ) - - is_model_quantized = "quant" in recipe_file_path - # if quantization recipe has been applied to the model, - # assert that the attention modules - # (6 of them for the tested tiny llama model), - # have been swapped for LlamaAttentionWithQuantizableMatmuls - assert is_model_quantized == ( - sum( - module.__class__.__name__ - == LlamaAttentionWithQuantizableMatmuls.__name__ # noqa E501 - for module in model.modules() - ) - == 6 - ) - - -def test_lm_head_target(): - tiny_model_path = "Xenova/llama2.c-stories15M" - device = "cuda:0" - if not torch.cuda.is_available(): - device = "cpu" - model = SparseAutoModelForCausalLM.from_pretrained(tiny_model_path) - - kwargs = { - "sparsity": 0.5, - "block_size": 128, - "quantize": False, - "targets": [ - "model.layers.0", - "model.layers.1", - "model.layers.2", - "model.layers.3", - "model.layers.4", - "model.layers.5", - ], - } - - sparsegpt_modifier_no_head = SparseGPTModifier( - framework=Framework.pytorch, **kwargs - ) - state = State(framework=Framework.pytorch) - state.update(model=model, device=device) - sparsegpt_modifier_no_head.initialize_compression(state.model) - - kwargs["targets"].append("lm_head") - sparsegpt_modifier_head = SparseGPTModifier(framework=Framework.pytorch, **kwargs) - sparsegpt_modifier_head.initialize_compression(state.model) - - # check we pick up the lm_head layer - layers_no_head = len(sparsegpt_modifier_no_head.compressible_layers_) - layers_head = len(sparsegpt_modifier_head.compressible_layers_) - assert layers_head == layers_no_head + 1 - - # check that the - - -def test_sparsities(): - tiny_model_path = "Xenova/llama2.c-stories15M" - recipe = "tests/sparseml/transformers/obcq/sparse.yaml" - device = "cuda:0" - if not torch.cuda.is_available(): - device = "cpu" - - # test recipe with 50% sparsity, quantization and smoothquant - oneshot( - model=tiny_model_path, - dataset="open_platypus", - oneshot_device=device, - recipe=recipe, - max_seq_length=128, - num_calibration_samples=64, - pad_to_max_length=False, - clear_sparse_session=False, - ) - - model = get_session_model() - - lm_head_sparsity = tensor_sparsity(model.lm_head.weight) - assert math.isclose(lm_head_sparsity.item(), 0.3, rel_tol=1e-4) - layer_1_sparse = tensor_sparsity(model.model.layers[1].self_attn.k_proj.weight) - assert math.isclose(layer_1_sparse.item(), 0.3, rel_tol=1e-4) - layer_2_dense = tensor_sparsity(model.model.layers[2].self_attn.k_proj.weight) - assert math.isclose(layer_2_dense.item(), 0.0, rel_tol=1e-4) - - -def test_sgpt_defaults(): - kwargs = {"sparsity": 0.5} - sparsegpt_modifier_only_sparsity = SparseGPTModifier( - framework=Framework.pytorch, **kwargs - ) - assert not sparsegpt_modifier_only_sparsity.quantize - assert sparsegpt_modifier_only_sparsity.block_size == 128 - assert sparsegpt_modifier_only_sparsity.sparsity == 0.5 - - kwargs = {"quantize": True} - sparsegpt_modifier_only_quant = SparseGPTModifier( - framework=Framework.pytorch, **kwargs - ) - assert sparsegpt_modifier_only_quant.quantize - assert sparsegpt_modifier_only_quant.block_size == 128 - assert sparsegpt_modifier_only_quant.sparsity == 0.0 - - # fail if we don't pass a sparsity or enable quantization - kwargs = {} - sparsegpt_invalid = SparseGPTModifier(framework=Framework.pytorch, **kwargs) - state_test = State(framework=Framework.pytorch) - sparsegpt_invalid.initialized_structure_ = True - with pytest.raises(ValueError): - sparsegpt_invalid.on_initialize(state=state_test) - - -def test_fake_quant_wrapper(tmp_path): - from sparseml.transformers import oneshot - - model_name = "roneneldan/TinyStories-1M" - dataset_name = "open_platypus" - overwrite_output_dir = True - precision = "bfloat16" # unsupported by native FakeQuantize - oneshot_device = "cuda:0" # unsupported by native FakeQuantize - output_dir = tmp_path / "temp_output" - recipe = """ - first_stage: - quant_modifiers: - QuantizationModifier: - ignore: - - Embedding - scheme_overrides: - LayerNorm: - input_activations: null - output_activations: null - """ - num_calibration_samples = 8 - - oneshot( - model=model_name, - dataset=dataset_name, - output_dir=output_dir, - overwrite_output_dir=overwrite_output_dir, - precision=precision, - recipe=recipe, - oneshot_device=oneshot_device, - num_calibration_samples=num_calibration_samples, - ) - - -def test_infer_targets(): - model = SparseAutoModelForCausalLM.from_pretrained("Xenova/llama2.c-stories15M") - modifiable_model = ModifiableModel(framework=Framework.pytorch, model=model) - targets = modifiable_model.get_no_split_params() - assert len(targets) == 1 - assert targets[0] == "LlamaDecoderLayer" - - modifier = SparseGPTModifierPyTorch(sparsity=0.5) - modifier.targets = targets - modifier.model = modifiable_model - compressible_layers = modifier.compressible_layers() - - # 15M model should have 6 transformer layers - assert len(compressible_layers) == 6 diff --git a/tests/sparseml/transformers/obcq/test_obcq_completion.py b/tests/sparseml/transformers/obcq/test_obcq_completion.py new file mode 100644 index 00000000000..c929356d871 --- /dev/null +++ b/tests/sparseml/transformers/obcq/test_obcq_completion.py @@ -0,0 +1,162 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import shutil +import unittest + +import pytest + +from parameterized import parameterized_class +from tests.testing_utils import parse_params, requires_gpu, requires_torch + + +CONFIGS_DIRECTORY = "tests/sparseml/transformers/obcq/obcq_configs/completion" +GPU_CONFIGS_DIRECTORY = "tests/sparseml/transformers/obcq/obcq_configs/completion/gpu" + + +class TestOBCQCompletion(unittest.TestCase): + """ + Test for oneshot for quantization and quantization + sparsity. Sparsity-only tests + can be found under `test_obcq_sparsity.py` + """ + + def labeled_dataloader(self, dataset_name, model_name): + from torch.utils.data import DataLoader + from transformers import DefaultDataCollator + + from sparseml.transformers import SparseAutoTokenizer + from sparseml.transformers.finetune.data import TextGenerationDataset + from sparseml.transformers.finetune.data.data_args import DataTrainingArguments + + tokenizer = SparseAutoTokenizer.from_pretrained(model_name) + data_args = DataTrainingArguments( + dataset=dataset_name, + max_seq_length=512, + pad_to_max_length=False, + ) + dataset_manager = TextGenerationDataset.load_from_registry( + data_args.dataset, + data_args=data_args, + split="train", + tokenizer=tokenizer, + ) + calib_dataset = dataset_manager.tokenize_and_process( + dataset_manager.get_raw_dataset() + ) + data_loader = DataLoader( + calib_dataset, batch_size=1, collate_fn=DefaultDataCollator() + ) + + return data_loader + + def _test_oneshot_completion(self, model_name: str = None): + import torch + + from sparseml.pytorch.model_load.helpers import get_session_model + from sparseml.pytorch.utils import tensors_to_device + from sparseml.transformers import oneshot + + oneshot( + model=self.model, + dataset=self.dataset, + oneshot_device=self.device, + recipe=self.recipe, + max_seq_length=512, + num_calibration_samples=self.num_samples, + pad_to_max_length=False, + output_dir=self.output, + clear_sparse_session=False, + ) + + first_tiny_model = get_session_model() + + dataset = "open_platypus" + + iter = 10 + if model_name: + dataloader = self.labeled_dataloader(dataset, model_name) + else: + dataloader = self.labeled_dataloader(dataset, self.model) + + total_new_ppl = 0.0 + for idx, sample in enumerate(dataloader): + if idx >= iter: + break + + with torch.no_grad(): + new_output = first_tiny_model( + **(tensors_to_device(sample, self.device)) + ) + new_ppl = torch.exp(new_output.loss) + total_new_ppl += new_ppl + + avg_new_ppl = total_new_ppl / iter + self.assertLess(avg_new_ppl, self.perplexity) + + def tearDown(self): + shutil.rmtree(self.output) + + +@requires_torch +@pytest.mark.integration +@parameterized_class(parse_params(CONFIGS_DIRECTORY)) +class TestOBCQCompletionSmall(TestOBCQCompletion): + + model = None + dataset = None + recipe = None + sparsity = None + num_samples = None + perplexity = None + + def setUp(self): + import torch + + self.device = "cuda:0" if torch.cuda.is_available() else "cpu" + self.output = "./oneshot_output" + + def test_obcq_completion_small(self): + self._test_oneshot_completion() + + +@requires_torch +@requires_gpu +@pytest.mark.integration +@parameterized_class(parse_params(GPU_CONFIGS_DIRECTORY)) +class TestOBCQCompletionGPU(TestOBCQCompletion): + + model = None + dataset = None + recipe = None + sparsity = None + device = None + num_samples = None + perplexity = None + + def setUp(self): + from sparseml.transformers import SparseAutoModelForCausalLM + + self.model_name = None + self.output = "./oneshot_output" + + # Temporary fix as oneshot seems to not work with zoo: models + # Need to keep th model name for the perplexity calculation post oneshot + if "zoo:" in self.model: + self.model_name = self.model + self.model = SparseAutoModelForCausalLM.from_pretrained( + self.model, device_map=self.device + ) + + def test_oneshot_completion_gpu(self): + self._test_oneshot_completion(model_name=self.model_name) diff --git a/tests/sparseml/transformers/obcq/test_obcq_fake_quant_wrapper.py b/tests/sparseml/transformers/obcq/test_obcq_fake_quant_wrapper.py new file mode 100644 index 00000000000..6fafab075b7 --- /dev/null +++ b/tests/sparseml/transformers/obcq/test_obcq_fake_quant_wrapper.py @@ -0,0 +1,64 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import shutil +import unittest + +import pytest + +from tests.testing_utils import requires_torch + + +@pytest.mark.integration +@requires_torch +class TestFakeQuantWrapper(unittest.TestCase): + def setUp(self): + import torch + + self.output = "./oneshot_output" + self.model = "roneneldan/TinyStories-1M" + self.dataset = "open_platypus" + self.precision = "bfloat16" # unsupported by native FakeQuantize + self.device = ( + "cuda:0" if torch.cuda.is_available() else "cpu" + ) # unsupported by native FakeQuantize + + self.recipe = """ + first_stage: + quant_modifiers: + QuantizationModifier: + ignore: + - Embedding + scheme_overrides: + LayerNorm: + input_activations: null + output_activations: null + """ + + def test_fake_quant_wrapper(self): + from sparseml.transformers import oneshot + + oneshot( + model=self.model, + dataset=self.dataset, + output_dir=self.output, + overwrite_output_dir=True, + precision=self.precision, + recipe=self.recipe, + oneshot_device=self.device, + num_calibration_samples=9, + ) + + def tearDown(self): + shutil.rmtree(self.output) diff --git a/tests/sparseml/transformers/obcq/test_obcq_infer_targets.py b/tests/sparseml/transformers/obcq/test_obcq_infer_targets.py new file mode 100644 index 00000000000..a2f6ebb9aa3 --- /dev/null +++ b/tests/sparseml/transformers/obcq/test_obcq_infer_targets.py @@ -0,0 +1,46 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import pytest + +from tests.testing_utils import requires_torch + + +@pytest.mark.integration +@requires_torch +class TestInferTargets(unittest.TestCase): + def setUp(self): + from sparseml.core import ModifiableModel + from sparseml.core.framework import Framework + from sparseml.transformers import SparseAutoModelForCausalLM + + model = SparseAutoModelForCausalLM.from_pretrained("Xenova/llama2.c-stories15M") + self.modifiable_model = ModifiableModel( + framework=Framework.pytorch, model=model + ) + self.targets = self.modifiable_model.get_no_split_params() + + def test_infer_targets(self): + from sparseml.modifiers.obcq.pytorch import SparseGPTModifierPyTorch + + self.assertEqual(len(self.targets), 1) + self.assertEqual(self.targets[0], "LlamaDecoderLayer") + + modifier = SparseGPTModifierPyTorch(sparsity=0.5) + modifier.targets = self.targets + modifier.model = self.modifiable_model + compressible_layers = modifier.compressible_layers() + self.assertEqual(len(compressible_layers), 6) diff --git a/tests/sparseml/transformers/obcq/test_obcq_lm_head.py b/tests/sparseml/transformers/obcq/test_obcq_lm_head.py new file mode 100644 index 00000000000..372496e06cd --- /dev/null +++ b/tests/sparseml/transformers/obcq/test_obcq_lm_head.py @@ -0,0 +1,71 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import pytest + +from tests.testing_utils import requires_torch + + +@pytest.mark.integration +@requires_torch +class TestLMHead(unittest.TestCase): + def setUp(self): + import torch + + from sparseml.transformers import SparseAutoModelForCausalLM + + self.device = "cuda:0" if torch.cuda.is_available() else "cpu" + + self.model = SparseAutoModelForCausalLM.from_pretrained( + "Xenova/llama2.c-stories15M", device_map=self.device + ) + self.kwargs = { + "sparsity": 0.5, + "block_size": 128, + "quantize": False, + "targets": [ + "model.layers.0", + "model.layers.1", + "model.layers.2", + "model.layers.3", + "model.layers.4", + "model.layers.5", + ], + } + + def test_lm_head_target(self): + from sparseml.core.framework import Framework + from sparseml.core.state import State + from sparseml.modifiers.obcq import SparseGPTModifier + + sparsegpt_modifier_no_head = SparseGPTModifier( + framework=Framework.pytorch, **self.kwargs + ) + + state = State(framework=Framework.pytorch) + state.update(model=self.model, device=self.device) + sparsegpt_modifier_no_head.initialize_compression(state.model) + + self.kwargs["targets"].append("lm_head") + sparsegpt_modifier_head = SparseGPTModifier( + framework=Framework.pytorch, **self.kwargs + ) + sparsegpt_modifier_head.initialize_compression(state.model) + + # check we pick up the lm_head layer + layers_no_head = len(sparsegpt_modifier_no_head.compressible_layers_) + layers_head = len(sparsegpt_modifier_head.compressible_layers_) + self.assertEqual(layers_head, layers_no_head + 1) diff --git a/tests/sparseml/transformers/obcq/test_obcq_sparsity.py b/tests/sparseml/transformers/obcq/test_obcq_sparsity.py new file mode 100644 index 00000000000..da46a87d65a --- /dev/null +++ b/tests/sparseml/transformers/obcq/test_obcq_sparsity.py @@ -0,0 +1,125 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import shutil +import unittest + +import pytest + +from parameterized import parameterized_class +from tests.testing_utils import parse_params, requires_gpu, requires_torch + + +CONFIGS_DIRECTORY = "tests/sparseml/transformers/obcq/obcq_configs/sparse" +GPU_CONFIGS_DIRECTORY = "tests/sparseml/transformers/obcq/obcq_configs/sparse/gpu" + + +@requires_torch +@pytest.mark.integration +@parameterized_class(parse_params(CONFIGS_DIRECTORY)) +class TestSparsities(unittest.TestCase): + model = None + dataset = None + recipe = None + sparsity = None + + def setUp(self): + import torch + + self.device = "cuda:0" if torch.cuda.is_available() else "cpu" + self.output = "./oneshot_output" + + def test_sparsities(self): + from sparseml.pytorch.model_load.helpers import get_session_model + from sparseml.pytorch.utils.helpers import tensor_sparsity + from sparseml.transformers import oneshot + + oneshot( + model=self.model, + dataset=self.dataset, + oneshot_device=self.device, + recipe=self.recipe, + max_seq_length=128, + num_calibration_samples=64, + pad_to_max_length=False, + clear_sparse_session=False, + output_dir=self.output, + ) + + model = get_session_model() + + lm_head_sparsity = tensor_sparsity(model.lm_head.weight) + assert math.isclose(lm_head_sparsity.item(), self.sparsity, rel_tol=1e-4) + layer_1_sparse = tensor_sparsity(model.model.layers[1].self_attn.k_proj.weight) + assert math.isclose(layer_1_sparse.item(), self.sparsity, rel_tol=1e-4) + layer_2_dense = tensor_sparsity(model.model.layers[2].self_attn.k_proj.weight) + assert math.isclose(layer_2_dense.item(), 0.0, rel_tol=1e-4) + + def tearDown(self): + shutil.rmtree(self.output) + + +@requires_gpu +@requires_torch +@pytest.mark.integration +@parameterized_class(parse_params(GPU_CONFIGS_DIRECTORY)) +class TestSparsitiesGPU(unittest.TestCase): + model = None + dataset = None + recipe = None + sparsity = None + device = None + + def setUp(self): + from sparseml.transformers import SparseAutoModelForCausalLM + + self.output = "./oneshot_output" + + if "zoo:" in self.model: + self.model = SparseAutoModelForCausalLM.from_pretrained( + self.model, device_map=self.device + ) + + def test_sparsities_gpu(self): + from sparseml.pytorch.model_load.helpers import get_session_model + from sparseml.pytorch.utils.helpers import tensor_sparsity + from sparseml.transformers import oneshot + + oneshot( + model=self.model, + dataset=self.dataset, + oneshot_device=self.device, + recipe=self.recipe, + max_seq_length=128, + num_calibration_samples=64, + pad_to_max_length=False, + clear_sparse_session=False, + output_dir=self.output, + ) + + model = get_session_model() + + lm_head_sparsity = tensor_sparsity(model.lm_head.weight) + assert math.isclose(lm_head_sparsity.item(), self.sparsity, rel_tol=1e-4) + layer_1_sparse = tensor_sparsity(model.model.layers[1].self_attn.k_proj.weight) + assert math.isclose(layer_1_sparse.item(), self.sparsity, rel_tol=1e-4) + layer_2_dense = tensor_sparsity(model.model.layers[2].self_attn.k_proj.weight) + assert math.isclose(layer_2_dense.item(), 0.0, abs_tol=1e-4) + + def tearDown(self): + import torch + + shutil.rmtree(self.output) + torch.cuda.empty_cache() diff --git a/tests/sparseml/transformers/obcq/test_repeat_quant_fails.py b/tests/sparseml/transformers/obcq/test_repeat_quant_fails.py new file mode 100644 index 00000000000..96be0d721e5 --- /dev/null +++ b/tests/sparseml/transformers/obcq/test_repeat_quant_fails.py @@ -0,0 +1,74 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import shutil +import unittest +from pathlib import Path + +import pytest + +from parameterized import parameterized_class +from tests.testing_utils import parse_params, requires_torch + + +CONFIGS_DIRECTORY = "tests/sparseml/transformers/obcq/obcq_configs/repeat_quants" + + +@requires_torch +@pytest.mark.integration +@parameterized_class(parse_params(CONFIGS_DIRECTORY)) +class TestRepeatQuants(unittest.TestCase): + model = None + first_recipe = None + second_recipe = None + dataset = None + + def setUp(self): + import torch + + self.device = "cuda:0" if torch.cuda.is_available() else "cpu" + self.output = "./oneshot_output" + self.output_first = Path(self.output) / "test_1" + self.output_second = Path(self.output) / "test_2" + + def test_fail_on_repeated_quant(self): + import sparseml.core.session as session_manager + from sparseml.transformers import oneshot + + oneshot( + model=self.model, + dataset=self.dataset, + num_calibration_samples=4, + oneshot_device=self.device, + recipe=self.first_recipe, + output_dir=self.output_first, + clear_sparse_session=False, + ) + + session = session_manager.active_session() + session.reset() + + # When trying to re-quantize with the second recipe, we should error out + # to avoid nested quantizations + with pytest.raises(RuntimeError): + oneshot( + model=self.output_first, + dataset=self.dataset, + num_calibration_samples=4, + oneshot_device=self.device, + recipe=self.second_recipe, + ) + + def tearDown(self): + shutil.rmtree(self.output) diff --git a/tests/sparseml/transformers/obcq/test_repeats.py b/tests/sparseml/transformers/obcq/test_repeats.py deleted file mode 100644 index f7267ac3d4d..00000000000 --- a/tests/sparseml/transformers/obcq/test_repeats.py +++ /dev/null @@ -1,224 +0,0 @@ -# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math - -import pytest -import torch -import yaml - -import sparseml.core.session as session_manager -from sparseml.pytorch.model_load.helpers import get_session_model -from sparseml.pytorch.utils.helpers import tensor_sparsity -from sparseml.transformers import oneshot -from sparseml.utils.pytorch import qat_active - - -try: - from torch import quantization as torch_quantization -except Exception: - torch_quantization = None - - -def test_consecutive_runs(tmp_path): - tiny_model_path = "Xenova/llama2.c-stories15M" - first_recipe = "tests/sparseml/transformers/obcq/quant_and_sparse.yaml" - second_recipe = "tests/sparseml/transformers/obcq/additional_sparsity.yaml" - device = "cuda:0" - if not torch.cuda.is_available(): - device = "cpu" - - # test recipe with 50% sparsity, quantization and smoothquant - oneshot( - model=tiny_model_path, - dataset="open_platypus", - num_calibration_samples=16, - recipe=first_recipe, - output_dir=tmp_path / "test1", - oneshot_device=device, - clear_sparse_session=False, - ) - first_tiny_model = get_session_model() - layer_0_sparse = tensor_sparsity( - first_tiny_model.model.layers[0].self_attn.k_proj.module.weight - ) - assert math.isclose(layer_0_sparse.item(), 0.5, rel_tol=1e-3) - assert qat_active(first_tiny_model) - - session = session_manager.active_session() - session_recipe = session.lifecycle.recipe_container.compiled_recipe - stages = [stage.group for stage in session_recipe.stages] - assert len(stages) == 1 - session.reset() - - # reload saved model and up sparsity to 0.7 - oneshot( - model=tmp_path / "test1", - dataset="open_platypus", - num_calibration_samples=16, - recipe=second_recipe, - output_dir=tmp_path / "test2", - oneshot_device=device, - clear_sparse_session=False, - ) - - second_tiny_model = get_session_model() - layer_0_sparse = tensor_sparsity( - second_tiny_model.model.layers[0].self_attn.k_proj.module.weight - ) - assert math.isclose(layer_0_sparse.item(), 0.7, rel_tol=1e-3) - assert qat_active(second_tiny_model) - - session = session_manager.active_session() - session_recipe = session.lifecycle.recipe_container.compiled_recipe - stages = [stage.group for stage in session_recipe.stages] - assert len(stages) == 2 - - recipe_path = tmp_path / "test2" / "recipe.yaml" - recipe_data = yaml.safe_load(recipe_path.read_text()) - stage_keys = recipe_data.keys() - assert len(stage_keys) == 2 - assert "test_stage_0" in stage_keys - assert "test_stage_1" in stage_keys - - -def test_fail_on_repeated_quant(tmp_path): - first_recipe_str = """ - first_stage: - quant_modifiers: - QuantizationModifier: - ignore: - - LlamaRotaryEmbedding - - LlamaRMSNorm - - SiLU - scheme_overrides: - Embedding: - input_activations: null - """ - - second_recipe_str = """ - second_stage: - quant_modifiers: - QuantizationModifier: - ignore: - - LlamaRotaryEmbedding - - LlamaRMSNorm - - SiLU - - Embedding - """ - - tiny_model_path = "Xenova/llama2.c-stories15M" - device = "cuda:0" - if not torch.cuda.is_available(): - device = "cpu" - - oneshot( - model=tiny_model_path, - dataset="open_platypus", - num_calibration_samples=4, - oneshot_device=device, - recipe=first_recipe_str, - output_dir=tmp_path / "test", - clear_sparse_session=False, - ) - - session = session_manager.active_session() - session.reset() - - # When trying to re-quantize with the second recipe, we should error out - # to avoid nested quantizations - with pytest.raises(RuntimeError): - oneshot( - model=tmp_path / "test", - dataset="open_platypus", - num_calibration_samples=4, - oneshot_device=device, - recipe=second_recipe_str, - ) - - -def test_separate_quants_allowed(tmp_path): - first_recipe_str = """ - first_stage: - quant_modifiers: - QuantizationModifier: - ignore: - - LlamaRotaryEmbedding - - LlamaRMSNorm - - SiLU - - Linear - scheme_overrides: - Embedding: - input_activations: null - """ - - second_recipe_str = """ - second_stage: - quant_modifiers: - QuantizationModifier: - ignore: - - LlamaRotaryEmbedding - - LlamaRMSNorm - - SiLU - - Embedding - - MatMulLeftInput_QK - - MatMulRightInput_QK - - MatMulOutput_QK - - MatMulLeftInput_PV - - MatMulRightInput_PV - - MatMulOutput_PV - - QuantizableMatMul - """ - - tiny_model_path = "Xenova/llama2.c-stories15M" - device = "cuda:0" - if not torch.cuda.is_available(): - device = "cpu" - - oneshot( - model=tiny_model_path, - dataset="open_platypus", - num_calibration_samples=16, - recipe=first_recipe_str, - output_dir=tmp_path / "test1", - oneshot_device=device, - clear_sparse_session=False, - ) - # only embedding quantized after first recipe - first_model = get_session_model() - assert not isinstance( - first_model.model.layers[0].mlp.down_proj, torch_quantization.QuantWrapper - ) - assert hasattr(first_model.model.embed_tokens, "quantization_scheme") - session = session_manager.active_session() - session.reset() - - # When trying to re-quantize with the second recipe, we should error out - # to avoid nested quantizations - oneshot( - model=tmp_path / "test1", - dataset="open_platypus", - num_calibration_samples=16, - recipe=second_recipe_str, - output_dir=tmp_path / "test2", - oneshot_device=device, - clear_sparse_session=False, - ) - - second_model = get_session_model() - # linear and embeddings should be quantized now - assert isinstance( - second_model.model.layers[0].mlp.down_proj, torch_quantization.QuantWrapper - ) - assert hasattr(second_model.model.embed_tokens, "quantization_scheme") diff --git a/tests/sparseml/transformers/obcq/test_separate_quants_allowed.py b/tests/sparseml/transformers/obcq/test_separate_quants_allowed.py new file mode 100644 index 00000000000..72e3ac9f1db --- /dev/null +++ b/tests/sparseml/transformers/obcq/test_separate_quants_allowed.py @@ -0,0 +1,92 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import shutil +import unittest +from pathlib import Path + +import pytest + +from parameterized import parameterized_class +from tests.testing_utils import parse_params, requires_torch + + +CONFIGS_DIRECTORY = "tests/sparseml/transformers/obcq/obcq_configs/separate_quants" + + +@requires_torch +@pytest.mark.integration +@parameterized_class(parse_params(CONFIGS_DIRECTORY)) +class TestSeparateQuants(unittest.TestCase): + model = None + first_recipe = None + second_recipe = None + dataset = None + + def setUp(self): + import torch + + self.device = "cuda:0" if torch.cuda.is_available() else "cpu" + self.output = "./oneshot_output" + self.output_first = Path(self.output) / "test_1" + self.output_second = Path(self.output) / "test_2" + + def test_fail_on_repeated_quant(self): + import sparseml.core.session as session_manager + from sparseml.pytorch.model_load.helpers import get_session_model + from sparseml.transformers import oneshot + + try: + from torch import quantization as torch_quantization + except Exception: + torch_quantization = None + + oneshot( + model=self.model, + dataset=self.dataset, + num_calibration_samples=16, + oneshot_device=self.device, + recipe=self.first_recipe, + output_dir=self.output_first, + clear_sparse_session=False, + ) + + first_model = get_session_model() + + assert not isinstance( + first_model.model.layers[0].mlp.down_proj, torch_quantization.QuantWrapper + ) + assert hasattr(first_model.model.embed_tokens, "quantization_scheme") + session = session_manager.active_session() + session.reset() + + oneshot( + model=self.output_first, + dataset=self.dataset, + num_calibration_samples=16, + oneshot_device=self.device, + recipe=self.second_recipe, + clear_sparse_session=False, + output_dir=self.output_second, + ) + + second_model = get_session_model() + # linear and embeddings should be quantized now + assert isinstance( + second_model.model.layers[0].mlp.down_proj, torch_quantization.QuantWrapper + ) + assert hasattr(second_model.model.embed_tokens, "quantization_scheme") + + def tearDown(self): + shutil.rmtree(self.output) diff --git a/tests/sparseml/transformers/obcq/test_sgpt_defaults.py b/tests/sparseml/transformers/obcq/test_sgpt_defaults.py new file mode 100644 index 00000000000..3612e91c69a --- /dev/null +++ b/tests/sparseml/transformers/obcq/test_sgpt_defaults.py @@ -0,0 +1,52 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import pytest + +from tests.testing_utils import requires_torch + + +@pytest.mark.integration +@requires_torch +class TestSGPTDefualts(unittest.TestCase): + def test_sgpt_defaults(self): + from sparseml.core.framework import Framework + from sparseml.core.state import State + from sparseml.modifiers.obcq import SparseGPTModifier + + kwargs = {"sparsity": 0.5} + sparsegpt_modifier_only_sparsity = SparseGPTModifier( + framework=Framework.pytorch, **kwargs + ) + assert not sparsegpt_modifier_only_sparsity.quantize + self.assertEqual(sparsegpt_modifier_only_sparsity.block_size, 128) + self.assertEqual(sparsegpt_modifier_only_sparsity.sparsity, 0.5) + + kwargs = {"quantize": True} + sparsegpt_modifier_only_quant = SparseGPTModifier( + framework=Framework.pytorch, **kwargs + ) + assert sparsegpt_modifier_only_quant.quantize + self.assertEqual(sparsegpt_modifier_only_quant.block_size, 128) + self.assertEqual(sparsegpt_modifier_only_quant.sparsity, 0.0) + + # fail if we don't pass a sparsity or enable quantization + kwargs = {} + sparsegpt_invalid = SparseGPTModifier(framework=Framework.pytorch, **kwargs) + state_test = State(framework=Framework.pytorch) + sparsegpt_invalid.initialized_structure_ = True + with self.assertRaises(ValueError): + sparsegpt_invalid.on_initialize(state=state_test) diff --git a/tests/sparseml/transformers/sparsification/test_compress_tensor_utils.py b/tests/sparseml/transformers/sparsification/test_compress_tensor_utils.py index 38369617ed7..5b1ae509734 100644 --- a/tests/sparseml/transformers/sparsification/test_compress_tensor_utils.py +++ b/tests/sparseml/transformers/sparsification/test_compress_tensor_utils.py @@ -37,7 +37,7 @@ ], ) def test_sparse_model_reload(compressed, config, dtype, tmp_path): - recipe_str = "tests/sparseml/transformers/obcq/test_tiny2.yaml" + recipe_str = "tests/sparseml/transformers/obcq/recipes/test_tiny2.yaml" model_path = "Xenova/llama2.c-stories15M" device = "cuda:0" if not torch.cuda.is_available(): diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 240d8a76da6..18e421fbd5c 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -37,6 +37,9 @@ def is_torch_available(): def is_gpu_available(): + """ + Check for GPU and warn if not found + """ try: import torch # noqa: F401