From ce265a49ac43aafb599b1302397a3531ee590ec6 Mon Sep 17 00:00:00 2001 From: rahul-tuli Date: Thu, 19 Oct 2023 10:01:47 -0400 Subject: [PATCH 01/12] - Update `src/sparseml/modifiers/obcq/pytorch.py` to use layer prefix for from model - Remove `layer_prefix` from `SparseGPTModifier` base - Update ModelMetaData to include layer_prefix - Added a convenience function to update missing values in RecipeMetaData instance from another RecipeMetaData instance - Update simplify recipe to also include metadata - Update simplify_combine_recipes to include metadata - Add layer_prefix property to `ModifiableModel` - propagate `layer_prefix` to superclass - update session.py to set_layer_prefix on the model before initializing modifiers - Update example recipe to include layer_prefix in metadata --- src/sparseml/core/lifecycle/session.py | 12 +++++++++ src/sparseml/core/model/base.py | 26 ++++++++++++++++++- src/sparseml/core/model/pytorch.py | 7 +++-- src/sparseml/core/recipe/metadata.py | 18 ++++++++++++- src/sparseml/core/recipe/recipe.py | 13 ++++++++++ src/sparseml/modifiers/obcq/base.py | 3 --- src/sparseml/modifiers/obcq/pytorch.py | 5 +++- .../sparsification/obcq/example.yaml | 8 ++++-- 8 files changed, 82 insertions(+), 10 deletions(-) diff --git a/src/sparseml/core/lifecycle/session.py b/src/sparseml/core/lifecycle/session.py index 62b065ed603..d784efe979a 100644 --- a/src/sparseml/core/lifecycle/session.py +++ b/src/sparseml/core/lifecycle/session.py @@ -84,6 +84,7 @@ def initialize(self, framework: Framework = None, **kwargs) -> List[Any]: extras = self.recipe_container.update(**extras) self._check_compile_recipe() + self._set_model_layer_prefix() mod_data = [] for mod in self.modifiers: data = mod.initialize(state=self.state, **extras) @@ -205,3 +206,14 @@ def _check_setup_event_lifecycle(self, event_type: EventType): ) else: raise ValueError(f"invalid event type {event_type}") + + def _set_model_layer_prefix(self): + if ( + (compiled_recipe := self.recipe_container.compiled_recipe) is None + or (metadata := compiled_recipe.metadata) is None + or (model_metadata := metadata.target_model) is None + ): + return False + + self.state.model.layer_prefix = model_metadata.layer_prefix + return True diff --git a/src/sparseml/core/model/base.py b/src/sparseml/core/model/base.py index 602682004a2..c53a686afa5 100644 --- a/src/sparseml/core/model/base.py +++ b/src/sparseml/core/model/base.py @@ -57,13 +57,21 @@ class ModifiableModel(Generic[MT, LT, PT], MultiFrameworkObject): to be searchable by the MultiFrameworkObject factory method. :param framework: the framework the model is in + :param layer_prefix: name of model attribute that contains the list of layers, i.e. + model.decoder for OPT or just model for Llama :param model: the model object """ model: MT = None - def __init__(self, framework: Optional[Framework] = None, model=None): + def __init__( + self, + framework: Optional[Framework] = None, + model=None, + layer_prefix: Optional[str] = None, + ): self.model = model + self._layer_prefix = layer_prefix def get_layers_params( self, targets: Union[str, List[str]] @@ -116,3 +124,19 @@ def set_param(self, target: str, param: PT): :param param: the param instance to set """ raise NotImplementedError() + + @property + def layer_prefix(self) -> Optional[str]: + """ + :return: the name of model attribute that contains the list of layers, i.e. + model.decoder for OPT or just model for Llama + """ + return self._layer_prefix + + @layer_prefix.setter + def layer_prefix(self, value: Optional[str]): + """ + :param value: the name of model attribute that contains the list of layers, i.e. + model.decoder for OPT or just model for Llama + """ + self._layer_prefix = value diff --git a/src/sparseml/core/model/pytorch.py b/src/sparseml/core/model/pytorch.py index 670d164900c..7d4f9e90fcb 100644 --- a/src/sparseml/core/model/pytorch.py +++ b/src/sparseml/core/model/pytorch.py @@ -41,9 +41,12 @@ class ModifiableModelPyTorch(ModifiableModel[Module, Module, Parameter]): """ def __init__( - self, framework: Optional[Framework] = None, model: Optional[Module] = None + self, + framework: Optional[Framework] = None, + model: Optional[Module] = None, + layer_prefix: str = "", ): - super().__init__(framework=framework, model=model) + super().__init__(framework=framework, model=model, layer_prefix=layer_prefix) def get_layers_params( self, targets: Union[str, List[str]] diff --git a/src/sparseml/core/recipe/metadata.py b/src/sparseml/core/recipe/metadata.py index 65fc907e967..c1a7ef3f991 100644 --- a/src/sparseml/core/recipe/metadata.py +++ b/src/sparseml/core/recipe/metadata.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from pydantic import BaseModel, Field @@ -69,6 +69,7 @@ class ModelMetaData(BaseModel): input_shapes: List[List[int]] = None output_shapes: List[List[int]] = None layers: List[LayerMetaData] = Field(default_factory=list) + layer_prefix: Optional[str] = None class RecipeMetaData(BaseModel): @@ -79,3 +80,18 @@ class RecipeMetaData(BaseModel): tags: List[str] = None target_dataset: DatasetMetaData = None target_model: ModelMetaData = None + + def update_missing_metadata(self, other: "RecipeMetaData"): + """ + Update recipe metadata with missing values from another + recipe metadata instance + + :param other: the recipe metadata to update with + """ + self.domain = self.domain or other.domain + self.task = self.task or other.task + self.versions = self.versions or other.versions + self.requirements = self.requirements or other.requirements + self.tags = self.tags or other.tags + self.target_dataset = self.target_dataset or other.target_dataset + self.target_model = self.target_model or other.target_model diff --git a/src/sparseml/core/recipe/recipe.py b/src/sparseml/core/recipe/recipe.py index c46d5a0febb..11c1d244fbc 100644 --- a/src/sparseml/core/recipe/recipe.py +++ b/src/sparseml/core/recipe/recipe.py @@ -135,6 +135,9 @@ def simplify_recipe( simplified.args = RecipeArgs(args) simplified.stages = stages simplified.evaluate(args=args, shift=shift) + simplified.metadata = ( + recipe.metadata if isinstance(recipe, Recipe) else recipe.recipe.metadata + ) return simplified @@ -185,6 +188,7 @@ def simplify_combine_recipes( combined.version = simplified.version combined.stages.extend(simplified.stages) combined.args.update(simplified.args) + combined.combine_metadata(simplified.metadata) return combined @@ -388,6 +392,15 @@ def extract_dict_stages(values: Dict[str, Any]) -> List[Dict[str, Any]]: return stages + def combine_metadata(self, metadata: Optional[RecipeMetaData]): + if metadata is None: + return + + if self.metadata is None: + self.metadata = metadata + else: + self.metadata.update_missing_metadata(metadata) + def dict(self, *args, **kwargs) -> Dict[str, Any]: """ >>> recipe_str = ''' diff --git a/src/sparseml/modifiers/obcq/base.py b/src/sparseml/modifiers/obcq/base.py index 64c640cad61..571a94a028d 100644 --- a/src/sparseml/modifiers/obcq/base.py +++ b/src/sparseml/modifiers/obcq/base.py @@ -44,8 +44,6 @@ class SparseGPTModifier(Modifier): :param targets: list of layer names to compress during OBCQ, or '__ALL__' to compress every layer in the model :param target_ids: list of keys in model output to cache - :param layer_prefix: name of model attribute that contains the list of layers, i.e. - model.decoder for OPT or just model for Llama """ sparsity: float @@ -57,7 +55,6 @@ class SparseGPTModifier(Modifier): prunem: Optional[int] = 0 targets: Union[str, List[str], None] = ALL_TOKEN target_ids: Optional[List[str]] = None - layer_prefix: Optional[str] = None def on_initialize_structure(self, state: "State", **kwargs): pass # nothing needed for this modifier diff --git a/src/sparseml/modifiers/obcq/pytorch.py b/src/sparseml/modifiers/obcq/pytorch.py index a57d1484fe6..3e849eedd9a 100644 --- a/src/sparseml/modifiers/obcq/pytorch.py +++ b/src/sparseml/modifiers/obcq/pytorch.py @@ -49,6 +49,7 @@ class SparseGPTModifierPyTorch(SparseGPTModifier): compressible_layers_: List = None device_: str = "cuda:0" finalization_kwargs_: Dict = None + layer_prefix_: Optional[str] = None def compressible_layers(self) -> List[Module]: """ @@ -90,6 +91,7 @@ def initialize_obcq( """ self.model = model self.compressible_layers_ = self.compressible_layers() + self.layer_prefix_ = model.layer_prefix self.model = self.model.model self._set_device(device) @@ -111,7 +113,7 @@ def apply_obcq( extras = self.compress_bottom( dev=self.device_, target_ids=self.target_ids, - layer_prefix=self.layer_prefix, + layer_prefix=self.layer_prefix_, **accum_kwargs, ) accum_kwargs.update(extras) @@ -173,6 +175,7 @@ def compress_bottom( :dev: device to use :return: outputs from bottom part of network, attention mask, and kv-cache state """ + layer_prefix = layer_prefix or self.layer_prefix_ cached_inputs = cache_attention_inputs( self.model, dataloader, dev, nsamples, target_ids, layer_prefix ) diff --git a/src/sparseml/transformers/sparsification/obcq/example.yaml b/src/sparseml/transformers/sparsification/obcq/example.yaml index 619b76d7122..d2b2cc29447 100644 --- a/src/sparseml/transformers/sparsification/obcq/example.yaml +++ b/src/sparseml/transformers/sparsification/obcq/example.yaml @@ -1,3 +1,8 @@ +metadata: + target_model: + layer_prefix: "decoder" + architecture: "llama" + test_stage: obcq_modifiers: QuantizationModifier: @@ -44,5 +49,4 @@ test_stage: "model.decoder.layers.22", "model.decoder.layers.23" ] - target_ids: ["attention_mask"] - layer_prefix: "decoder" \ No newline at end of file + target_ids: ["attention_mask"] \ No newline at end of file From 9fa03749c78f218ea31a25adaee225969ab46b46 Mon Sep 17 00:00:00 2001 From: rahul-tuli Date: Thu, 19 Oct 2023 10:12:59 -0400 Subject: [PATCH 02/12] Add missing docstring --- src/sparseml/core/recipe/recipe.py | 7 +++++++ src/sparseml/transformers/sparsification/obcq/example.yaml | 4 ++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/sparseml/core/recipe/recipe.py b/src/sparseml/core/recipe/recipe.py index 11c1d244fbc..06fca4c085c 100644 --- a/src/sparseml/core/recipe/recipe.py +++ b/src/sparseml/core/recipe/recipe.py @@ -393,6 +393,13 @@ def extract_dict_stages(values: Dict[str, Any]) -> List[Dict[str, Any]]: return stages def combine_metadata(self, metadata: Optional[RecipeMetaData]): + """ + Combines the metadata of the recipe with the supplied metadata + If the recipe already has metadata, the supplied metadata will + be used to update missing metadata + + :param metadata: The metadata to combine with the recipe + """ if metadata is None: return diff --git a/src/sparseml/transformers/sparsification/obcq/example.yaml b/src/sparseml/transformers/sparsification/obcq/example.yaml index d2b2cc29447..7ae666135dc 100644 --- a/src/sparseml/transformers/sparsification/obcq/example.yaml +++ b/src/sparseml/transformers/sparsification/obcq/example.yaml @@ -1,7 +1,7 @@ metadata: target_model: layer_prefix: "decoder" - architecture: "llama" + architecture: "opt" test_stage: obcq_modifiers: @@ -49,4 +49,4 @@ test_stage: "model.decoder.layers.22", "model.decoder.layers.23" ] - target_ids: ["attention_mask"] \ No newline at end of file + target_ids: ["attention_mask"] From 046b85d9b3d29eeb07bc8857bff17222173eafbb Mon Sep 17 00:00:00 2001 From: rahul-tuli Date: Wed, 25 Oct 2023 09:37:15 -0400 Subject: [PATCH 03/12] - address review comment - update docstring - add test for `update_missing_metadata` --- src/sparseml/core/model/pytorch.py | 4 +- src/sparseml/modifiers/obcq/pytorch.py | 8 +-- tests/sparseml/core/recipe/test_metadata.py | 55 +++++++++++++++++++++ 3 files changed, 63 insertions(+), 4 deletions(-) create mode 100644 tests/sparseml/core/recipe/test_metadata.py diff --git a/src/sparseml/core/model/pytorch.py b/src/sparseml/core/model/pytorch.py index 7d4f9e90fcb..e36a9c11f87 100644 --- a/src/sparseml/core/model/pytorch.py +++ b/src/sparseml/core/model/pytorch.py @@ -38,13 +38,15 @@ class ModifiableModelPyTorch(ModifiableModel[Module, Module, Parameter]): :param framework: the framework the model is in :param model: the model object + :param layer_prefix: name of model attribute that contains the list of layers, i.e. + model.decoder for OPT or just model for Llama """ def __init__( self, framework: Optional[Framework] = None, model: Optional[Module] = None, - layer_prefix: str = "", + layer_prefix: Optional[str] = None, ): super().__init__(framework=framework, model=model, layer_prefix=layer_prefix) diff --git a/src/sparseml/modifiers/obcq/pytorch.py b/src/sparseml/modifiers/obcq/pytorch.py index 3e849eedd9a..ff7ec24c7b9 100644 --- a/src/sparseml/modifiers/obcq/pytorch.py +++ b/src/sparseml/modifiers/obcq/pytorch.py @@ -164,15 +164,17 @@ def compress_bottom( nsamples: int = None, dev: str = "cuda:0", target_ids: List[str] = None, - layer_prefix: str = None, + layer_prefix: Optional[str] = None, ) -> Dict: """ Runs calibration data through the bottom part of the network (everything up to the first decoder layer) and return the captured outputs :param dataloader: calibration data to pass through the model - :nsamples: number of samples to use for calibration, or None to use it all - :dev: device to use + :param nsamples: number of samples to use for calibration, or None to use it all + :param dev: device to use + :param layer_prefix: name of model attribute that contains the list of layers, + i.e. model.decoder for OPT or just model for Llama :return: outputs from bottom part of network, attention mask, and kv-cache state """ layer_prefix = layer_prefix or self.layer_prefix_ diff --git a/tests/sparseml/core/recipe/test_metadata.py b/tests/sparseml/core/recipe/test_metadata.py new file mode 100644 index 00000000000..60753f1db8e --- /dev/null +++ b/tests/sparseml/core/recipe/test_metadata.py @@ -0,0 +1,55 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from sparseml.core.recipe.metadata import ModelMetaData, RecipeMetaData + + +class TestRecipeMetaData: + @pytest.mark.parametrize( + "self_metadata", + [ + dict(domain="cv", task="classification"), + dict(), + ], + ) + @pytest.mark.parametrize( + "other_metadata", + [ + dict(domain="domain", task="segmentation", requirements=["torch>=1.6.0"]), + dict( + domain="cv", + task="task", + target_model=ModelMetaData(layer_prefix="something"), + ), + ], + ) + def test_update_missing_metadata(self, self_metadata, other_metadata): + + metadata_a = RecipeMetaData(**self_metadata) + metadata_b = RecipeMetaData(**other_metadata) + + metadata_a.update_missing_metadata(metadata_b) + + all_keys = set(self_metadata.keys()).union(other_metadata.keys()) + + # keys should not be overwritten + # if they already exist + for key in all_keys: + if key in self_metadata: + assert getattr(metadata_a, key) == self_metadata[key] + elif key in other_metadata: + assert getattr(metadata_a, key) == other_metadata[key] From 1bcf3a07787f401313225705d402078ad046df93 Mon Sep 17 00:00:00 2001 From: rahul-tuli Date: Wed, 25 Oct 2023 10:54:55 -0400 Subject: [PATCH 04/12] Add test --- tests/sparseml/core/lifecycle/test_session.py | 53 +++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 tests/sparseml/core/lifecycle/test_session.py diff --git a/tests/sparseml/core/lifecycle/test_session.py b/tests/sparseml/core/lifecycle/test_session.py new file mode 100644 index 00000000000..b1775dbc98c --- /dev/null +++ b/tests/sparseml/core/lifecycle/test_session.py @@ -0,0 +1,53 @@ +import pytest +from sparseml.core.framework import Framework +import sparseml.core.session as sml + +def recipe_with_layer_prefix(): + layer_prefix = "decoder" + recipe = f""" + metadata: + target_model: + layer_prefix: {layer_prefix} + architecture: "opt" + + test_stage: + pruning_modifiers: + ConstantPruningModifier: + targets: __ALL_PRUNABLE__ + start: 0 + end: 5 + """ + return recipe, layer_prefix + +def recipe_without_layer_prefix(): + recipe = """ + metadata: + target_model: + architecture: "opt" + + test_stage: + pruning_modifiers: + ConstantPruningModifier: + targets: __ALL_PRUNABLE__ + start: 0 + end: 5 + """ + return recipe, None + +@pytest.fixture +def model(): + # identity model + return lambda x: x + + +@pytest.mark.parametrize( + "recipe, expected_layer_prefix", + [ + recipe_with_layer_prefix(), + recipe_without_layer_prefix(), # layer prefix should be none + ], +) +def test_session_initialize_propagates_layer_prefix_to_model(recipe, expected_layer_prefix, model): + session = sml.active_session() + session.initialize(framework=Framework.general ,model=model, recipe=recipe) + assert session.state.model.layer_prefix == expected_layer_prefix \ No newline at end of file From fe367d3ab4aa133e834336584d3c2e5a4e6612c9 Mon Sep 17 00:00:00 2001 From: rahul-tuli Date: Wed, 25 Oct 2023 10:55:21 -0400 Subject: [PATCH 05/12] Style --- tests/sparseml/core/lifecycle/test_session.py | 30 +++++++++++++++---- tests/sparseml/core/recipe/test_metadata.py | 2 +- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/tests/sparseml/core/lifecycle/test_session.py b/tests/sparseml/core/lifecycle/test_session.py index b1775dbc98c..bbee5727caf 100644 --- a/tests/sparseml/core/lifecycle/test_session.py +++ b/tests/sparseml/core/lifecycle/test_session.py @@ -1,6 +1,22 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import pytest -from sparseml.core.framework import Framework + import sparseml.core.session as sml +from sparseml.core.framework import Framework + def recipe_with_layer_prefix(): layer_prefix = "decoder" @@ -19,6 +35,7 @@ def recipe_with_layer_prefix(): """ return recipe, layer_prefix + def recipe_without_layer_prefix(): recipe = """ metadata: @@ -34,6 +51,7 @@ def recipe_without_layer_prefix(): """ return recipe, None + @pytest.fixture def model(): # identity model @@ -44,10 +62,12 @@ def model(): "recipe, expected_layer_prefix", [ recipe_with_layer_prefix(), - recipe_without_layer_prefix(), # layer prefix should be none + recipe_without_layer_prefix(), # layer prefix should be none ], ) -def test_session_initialize_propagates_layer_prefix_to_model(recipe, expected_layer_prefix, model): +def test_session_initialize_propagates_layer_prefix_to_model( + recipe, expected_layer_prefix, model +): session = sml.active_session() - session.initialize(framework=Framework.general ,model=model, recipe=recipe) - assert session.state.model.layer_prefix == expected_layer_prefix \ No newline at end of file + session.initialize(framework=Framework.general, model=model, recipe=recipe) + assert session.state.model.layer_prefix == expected_layer_prefix diff --git a/tests/sparseml/core/recipe/test_metadata.py b/tests/sparseml/core/recipe/test_metadata.py index 60753f1db8e..bfc410d8467 100644 --- a/tests/sparseml/core/recipe/test_metadata.py +++ b/tests/sparseml/core/recipe/test_metadata.py @@ -46,7 +46,7 @@ def test_update_missing_metadata(self, self_metadata, other_metadata): all_keys = set(self_metadata.keys()).union(other_metadata.keys()) - # keys should not be overwritten + # keys should not be overwritten # if they already exist for key in all_keys: if key in self_metadata: From c45c745d1c2cdffc1ff76cf29bcda70eaceaa6bf Mon Sep 17 00:00:00 2001 From: rahul-tuli Date: Wed, 25 Oct 2023 11:22:27 -0400 Subject: [PATCH 06/12] Fix tests --- tests/sparseml/core/lifecycle/test_session.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/sparseml/core/lifecycle/test_session.py b/tests/sparseml/core/lifecycle/test_session.py index bbee5727caf..d9e409b94d4 100644 --- a/tests/sparseml/core/lifecycle/test_session.py +++ b/tests/sparseml/core/lifecycle/test_session.py @@ -38,10 +38,6 @@ def recipe_with_layer_prefix(): def recipe_without_layer_prefix(): recipe = """ - metadata: - target_model: - architecture: "opt" - test_stage: pruning_modifiers: ConstantPruningModifier: @@ -61,13 +57,14 @@ def model(): @pytest.mark.parametrize( "recipe, expected_layer_prefix", [ + recipe_without_layer_prefix(), recipe_with_layer_prefix(), - recipe_without_layer_prefix(), # layer prefix should be none ], ) def test_session_initialize_propagates_layer_prefix_to_model( recipe, expected_layer_prefix, model -): +): session = sml.active_session() session.initialize(framework=Framework.general, model=model, recipe=recipe) + print(f"{session.state.model.layer_prefix=}, {expected_layer_prefix=}") assert session.state.model.layer_prefix == expected_layer_prefix From 2c8f8a87d4139e456fd5ab94ec5429987e93ba80 Mon Sep 17 00:00:00 2001 From: rahul-tuli Date: Wed, 25 Oct 2023 19:35:10 -0400 Subject: [PATCH 07/12] Style --- tests/sparseml/core/lifecycle/test_session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/sparseml/core/lifecycle/test_session.py b/tests/sparseml/core/lifecycle/test_session.py index d9e409b94d4..b0fa3687c64 100644 --- a/tests/sparseml/core/lifecycle/test_session.py +++ b/tests/sparseml/core/lifecycle/test_session.py @@ -63,7 +63,7 @@ def model(): ) def test_session_initialize_propagates_layer_prefix_to_model( recipe, expected_layer_prefix, model -): +): session = sml.active_session() session.initialize(framework=Framework.general, model=model, recipe=recipe) print(f"{session.state.model.layer_prefix=}, {expected_layer_prefix=}") From 04d05c47b1d2becd886db5ca9772fb4f38a03b89 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Fri, 27 Oct 2023 12:32:22 -0400 Subject: [PATCH 08/12] [modifier refactor] Add constant pruning tests (#1752) * Initial commit * Add end to end tests * Add e2e tests for constant pruning modifier * Move imports inside the test fuctions so that torch isn't imported unless running the tests * Update setup.py to not run modifier tests unless pytorch is specified * [Bugfix] .dict() method on Recipe (#1753) * Bugfix .dict() method on Recipe * Remove extraneous local test, [faulty commit] * [modifier refactor] Add serialization tests (#1755) * Add serialization tests * Clean up * Keep original stage and group names Clean up _get_yaml_dict * fix comment * Typo * [Unit Tests][Modifier Refactor] (#1756) * Move valid recipes to a helper file Add tests for session.py * Increase test coverage of src/sparseml/core/session.py to 100% Run Style Add logs to .gitignore * Increase coverage of tests/sparseml/core/test_state.py to 100% * add tests for lifecycle/event.py * Increase code coverage of lifecycle/event to 100% * increase lifecycle/session.py code coverage to 93% * Address review comments from @Satrat * Address review comments on 1752 (#1772) Update makefile to only ignore *pytorch.py files in modifier dir Fix order in test Add regex to makefile Add helper function to determine if torch tests should be run Check masks Make transformers import optional in sparsegpt.py * Fix merge conflict * Add more tests to check valid modifiers are created (#1774) * [Bug][ConstantPruningModifier] Fix mask de register bug (#1773) * Fix mask de-register logic * forgot to remove commented out line * Move tests inside pytorch directory as requested * Fix session reset (#1790) --- .gitignore | 3 + src/sparseml/core/lifecycle/event.py | 2 +- src/sparseml/core/lifecycle/session.py | 3 + src/sparseml/core/model/base.py | 2 - src/sparseml/core/recipe/modifier.py | 2 +- src/sparseml/core/recipe/recipe.py | 64 +- src/sparseml/core/state.py | 8 + .../modifiers/obcq/utils/sparsegpt.py | 14 +- .../pruning/utils/pytorch/layer_mask.py | 7 +- tests/sparseml/conftest.py | 16 + tests/sparseml/core/lifecycle/__init__.py | 13 + tests/sparseml/core/lifecycle/test_event.py | 624 ++++++++++++++++++ tests/sparseml/core/lifecycle/test_session.py | 264 +++++++- tests/sparseml/core/recipe/test_recipe.py | 68 +- tests/sparseml/core/test_session.py | 319 +++++++++ tests/sparseml/core/test_state.py | 122 ++++ tests/sparseml/helpers.py | 82 +++ .../pytorch/modifiers/pruning/__init__.py | 13 + .../modifiers/pruning/constant/__init__.py | 13 + .../pruning/constant/test_pytorch.py | 184 ++++++ 20 files changed, 1790 insertions(+), 33 deletions(-) create mode 100644 tests/sparseml/core/lifecycle/__init__.py create mode 100644 tests/sparseml/core/lifecycle/test_event.py create mode 100644 tests/sparseml/core/test_session.py create mode 100644 tests/sparseml/core/test_state.py create mode 100644 tests/sparseml/helpers.py create mode 100644 tests/sparseml/pytorch/modifiers/pruning/__init__.py create mode 100644 tests/sparseml/pytorch/modifiers/pruning/constant/__init__.py create mode 100644 tests/sparseml/pytorch/modifiers/pruning/constant/test_pytorch.py diff --git a/.gitignore b/.gitignore index 6ef46899c87..2fc41d97b0c 100644 --- a/.gitignore +++ b/.gitignore @@ -795,3 +795,6 @@ fabric.properties *.resources test-results/ integrations/pytorch/pytorch_vision* + +# local log files +nm_temp_test_logs/* diff --git a/src/sparseml/core/lifecycle/event.py b/src/sparseml/core/lifecycle/event.py index 5c931a679a6..71b535691c1 100644 --- a/src/sparseml/core/lifecycle/event.py +++ b/src/sparseml/core/lifecycle/event.py @@ -198,7 +198,7 @@ def optim_pre_step_events(self) -> List[Event]: and self.type_ is not None and self.type_ != EventType.OPTIM_POST_STEP ): - raise ValueError("optim pre step must be called after optim post step") + raise ValueError("optim pre step must be called before optim post step") if ( self.type_first == EventType.LOSS_CALCULATED diff --git a/src/sparseml/core/lifecycle/session.py b/src/sparseml/core/lifecycle/session.py index d784efe979a..25a665637f6 100644 --- a/src/sparseml/core/lifecycle/session.py +++ b/src/sparseml/core/lifecycle/session.py @@ -50,6 +50,9 @@ def reset(self): except Exception: pass + if self.state and self.state.data: + # reset data if it exists + self.state.data.reset() self.state = None self.recipe_container = RecipeContainer() self.modifiers = [] diff --git a/src/sparseml/core/model/base.py b/src/sparseml/core/model/base.py index 9f8f370cab5..7d83ea2aab6 100644 --- a/src/sparseml/core/model/base.py +++ b/src/sparseml/core/model/base.py @@ -125,7 +125,6 @@ def set_param(self, target: str, param: PT): """ raise NotImplementedError() - @property def layer_prefix(self) -> Optional[str]: """ @@ -149,4 +148,3 @@ def qat_active(self) -> bool: :return: True if QAT is active in any layer, False otherwise """ raise NotImplementedError() - diff --git a/src/sparseml/core/recipe/modifier.py b/src/sparseml/core/recipe/modifier.py index aee85e2c8c3..40b49022bed 100644 --- a/src/sparseml/core/recipe/modifier.py +++ b/src/sparseml/core/recipe/modifier.py @@ -113,4 +113,4 @@ def dict(self, *args, **kwargs) -> Dict[str, Any]: """ :return: the dictionary representation of the modifier """ - return {self.type: self.args} + return {self.type: self.args, "group": f"{self.group}_modifiers"} diff --git a/src/sparseml/core/recipe/recipe.py b/src/sparseml/core/recipe/recipe.py index 06fca4c085c..c0b1dcaa240 100644 --- a/src/sparseml/core/recipe/recipe.py +++ b/src/sparseml/core/recipe/recipe.py @@ -419,10 +419,12 @@ def dict(self, *args, **kwargs) -> Dict[str, Any]: ... targets: ['re:.*weight'] ... ''' >>> recipe = Recipe.create_instance(recipe_str) - >>> recipe.dict() - Traceback (most recent call last): - ... - KeyError: 'group' + >>> recipe_dict = recipe.dict() + >>> stage = recipe_dict["stages"]["test"] + >>> pruning_mods = stage[0]['modifiers']['pruning'] + >>> modifier_args = pruning_mods[0]['ConstantPruningModifier'] + >>> modifier_args == {'start': 0.0, 'end': 2.0, 'targets': ['re:.*weight']} + True :return: A dictionary representation of the recipe """ @@ -430,7 +432,7 @@ def dict(self, *args, **kwargs) -> Dict[str, Any]: stages = {} for stage in dict_["stages"]: - name = stage["group"] + name = f"{stage['group']}_stage" del stage["group"] if name not in stages: @@ -442,6 +444,58 @@ def dict(self, *args, **kwargs) -> Dict[str, Any]: return dict_ + def yaml(self, file_path: Optional[str] = None) -> str: + """ + Return a yaml string representation of the recipe. + + :param file_path: optional file path to save yaml to + :return: The yaml string representation of the recipe + """ + file_stream = None if file_path is None else open(file_path, "w") + yaml_dict = self._get_yaml_dict() + + ret = yaml.dump( + yaml_dict, stream=file_stream, allow_unicode=True, sort_keys=False + ) + + if file_stream is not None: + file_stream.close() + + return ret + + def _get_yaml_dict(self) -> Dict[str, Any]: + """ + Get a dictionary representation of the recipe for yaml serialization + The returned dict will only contain information necessary for yaml + serialization (ignores metadata, version, etc), and must not be used + in place of the dict method + + :return: A dictionary representation of the recipe for yaml serialization + """ + + def _modifier_group_to_dict(modifier_group: List[Dict[str, Any]]): + # convert a list of modifiers to a dict of modifiers + return { + key: value + for modifier in modifier_group + for key, value in modifier.items() + } + + def _stage_to_dict(stage: List[Dict[str, Any]]): + # convert a list of stages to a dict of modifiers + return { + modifier_group_name: _modifier_group_to_dict(modifier_group) + for stage_modifiers in stage + for modifier_group_name, modifier_group in stage_modifiers[ + "modifiers" + ].items() + } + + return { + stage_name: _stage_to_dict(stage=stage) + for stage_name, stage in self.dict()["stages"].items() + } + @dataclass class RecipeTuple: diff --git a/src/sparseml/core/state.py b/src/sparseml/core/state.py index b5eb89060a6..db371759f24 100644 --- a/src/sparseml/core/state.py +++ b/src/sparseml/core/state.py @@ -45,6 +45,14 @@ class Data: test: Optional[ModifiableData] = None calib: Optional[ModifiableData] = None + def reset(self): + """ + Reset self to initial state + """ + attribs = Data().__dict__ + for attrib_name, attrib_value in attribs.items(): + setattr(self, attrib_name, attrib_value) + @dataclass class Hardware: diff --git a/src/sparseml/modifiers/obcq/utils/sparsegpt.py b/src/sparseml/modifiers/obcq/utils/sparsegpt.py index c1ee71ddd00..033443ca694 100644 --- a/src/sparseml/modifiers/obcq/utils/sparsegpt.py +++ b/src/sparseml/modifiers/obcq/utils/sparsegpt.py @@ -20,6 +20,13 @@ import torch.nn as nn +try: + import transformers +except ImportError as err: + transformers = None + transformers_err = err + + DEBUG = False _LOGGER = logging.getLogger(__name__) @@ -41,7 +48,8 @@ class SparseGPT: """ def __init__(self, layer): - import transformers + if transformers is None: + raise transformers_err self.layer = layer self.dev = self.layer.weight.device @@ -62,8 +70,6 @@ def add_batch(self, inp: torch.Tensor, out: torch.Tensor): :param inp: tensor containing layer input :param out: tensor containing layer our """ - import transformers - if DEBUG: self._inp1 = inp self.out1 = out @@ -100,8 +106,6 @@ def fasterprune( :param percdamp: Amount of dampening to apply to H, as a fraction of the diagonal norm """ - import transformers - W = self.layer.weight.data.clone() if isinstance(self.layer, nn.Conv2d): W = W.flatten(1) diff --git a/src/sparseml/modifiers/pruning/utils/pytorch/layer_mask.py b/src/sparseml/modifiers/pruning/utils/pytorch/layer_mask.py index b68a2494a4e..4dd945d2177 100644 --- a/src/sparseml/modifiers/pruning/utils/pytorch/layer_mask.py +++ b/src/sparseml/modifiers/pruning/utils/pytorch/layer_mask.py @@ -134,9 +134,10 @@ def remove_mask(self, layer_param_name: str): mask_settings = self._mask_settings[layer_param_name] parameterized_layer = self._masked_layer_params[layer_param_name] - if mask_settings.persistent: - parameterized_layer.layer.unregister_buffer( - param_mask_name(parameterized_layer.param_name) + if not mask_settings.persistent: + delattr( + parameterized_layer.layer, + param_mask_name(parameterized_layer.param_name), ) del self._masked_layer_params[layer_param_name] diff --git a/tests/sparseml/conftest.py b/tests/sparseml/conftest.py index 4173d0ce591..d917109c494 100644 --- a/tests/sparseml/conftest.py +++ b/tests/sparseml/conftest.py @@ -92,3 +92,19 @@ def check_for_created_files(): f"megabytes of temp files created in temp directory during pytest run. " f"Created files: {set(end_files_temp) - set(start_files_temp)}" ) + + +@pytest.fixture(autouse=True, scope="function") +def setup_fresh_session(): + """ + setup any state tied to the execution of the given method in a + class. setup_method is invoked for every test method of a class. + """ + import sparseml.core.session as sml + + active_session = sml.active_session() + # start with a clean session for each test + active_session.reset() + yield + # reset the session after each test + active_session.reset() diff --git a/tests/sparseml/core/lifecycle/__init__.py b/tests/sparseml/core/lifecycle/__init__.py new file mode 100644 index 00000000000..0c44f887a47 --- /dev/null +++ b/tests/sparseml/core/lifecycle/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/sparseml/core/lifecycle/test_event.py b/tests/sparseml/core/lifecycle/test_event.py new file mode 100644 index 00000000000..0f32754229d --- /dev/null +++ b/tests/sparseml/core/lifecycle/test_event.py @@ -0,0 +1,624 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from functools import partial + +import pytest + +from sparseml.core.event import Event, EventType +from sparseml.core.lifecycle.event import ( + CallbacksEventLifecycle, + EventLifecycle, + WrappedOptimEventLifecycle, +) + + +def test_event_lifecycle_abstract_class_can_not_be_instantiated(): + # tests event lifecycle abstract class can not be instantiated + # directly, without implementing the abstract methods + + with pytest.raises(TypeError): + EventLifecycle(type_first=EventType.BATCH_START, start=Event()) + + +class EventLifecycleDummyChild(EventLifecycle): + def batch_start_events(self): + return [], "batch_start_events" + + def loss_calculated_events(self): + return [], "loss_calculated_events" + + def optim_pre_step_events(self): + return [], "optim_pre_step_events" + + def optim_post_step_events(self): + return [], "optim_post_step_events" + + def batch_end_events(self): + return [], "batch_end_events" + + +def _get_event_lifecycle( + start=None, type_first=None, lifecycle_class=EventLifecycleDummyChild +): + start = start or Event() + type_first = type_first or EventType.BATCH_START + lifecycle = lifecycle_class(type_first=type_first, start=start) + return lifecycle + + +class TestEventLifecycle: + @pytest.mark.parametrize("type_first", [EventType.BATCH_START, EventType.BATCH_END]) + @pytest.mark.parametrize( + "start", + [ + Event(), + Event( + global_step=1, + global_batch=1, + steps_per_epoch=1, + batches_per_step=1, + invocations_per_step=1, + ), + ], + ) + def test_init(self, type_first, start): + lifecycle = _get_event_lifecycle(type_first=type_first, start=start) + assert lifecycle.type_first == type_first + assert lifecycle.steps_per_epoch == start.steps_per_epoch + assert lifecycle.batches_per_step == start.batches_per_step + assert lifecycle.invocations_per_step == start.invocations_per_step + assert lifecycle.global_step == start.global_step + assert lifecycle.global_batch == start.global_batch + + @pytest.mark.parametrize( + "type_, expected_func_name", + [ + (EventType.BATCH_START, "batch_start_events"), + (EventType.LOSS_CALCULATED, "loss_calculated_events"), + (EventType.OPTIM_PRE_STEP, "optim_pre_step_events"), + (EventType.OPTIM_POST_STEP, "optim_post_step_events"), + (EventType.BATCH_END, "batch_end_events"), + ], + ) + def test_events_from_type_valid(self, type_, expected_func_name): + lifecycle = _get_event_lifecycle() + events, func_name = lifecycle.events_from_type(type_) + + assert events == [] + assert func_name == expected_func_name + + def test_events_from_type_raises_value_error(self): + lifecycle = _get_event_lifecycle() + with pytest.raises(ValueError): + lifecycle.events_from_type("invalid") + + @pytest.mark.parametrize( + "kwargs, increment ,expected", + [ + ({"batches_per_step": None}, False, True), + ({"batches_per_step": 1}, False, True), + ({"batches_per_step": 3, "batch_count": 5}, False, True), + ({"batches_per_step": 4, "batch_count": 7}, False, True), + ({"batches_per_step": 4, "batch_count": 9}, False, False), + ({"batches_per_step": 4, "batch_count": 9}, True, False), + ({"batches_per_step": 4, "batch_count": 11}, True, True), + ], + ) + def test_check_step_batches_count(self, kwargs, increment, expected): + lifecycle = _get_event_lifecycle() + + for key, value in kwargs.items(): + setattr(lifecycle, key, value) + + actual = lifecycle.check_step_batches_count(increment=increment) + + if increment: + if not expected: + assert lifecycle.batch_count == kwargs["batch_count"] + 1 + else: + assert lifecycle.batch_count == 0 + + assert actual == expected + + @pytest.mark.parametrize( + "kwargs, increment ,expected", + [ + ({"invocations_per_step": None}, False, True), + ({"invocations_per_step": 1}, False, True), + ({"invocations_per_step": 3, "step_count": 5}, False, True), + ({"invocations_per_step": 4, "step_count": 7}, False, True), + ({"invocations_per_step": 4, "step_count": 9}, False, False), + ({"invocations_per_step": 4, "step_count": 9}, True, False), + ({"invocations_per_step": 4, "step_count": 11}, True, True), + ], + ) + def test_check_step_invocations_count(self, kwargs, increment, expected): + lifecycle = _get_event_lifecycle() + + for key, value in kwargs.items(): + setattr(lifecycle, key, value) + + actual = lifecycle.check_step_invocations_count(increment=increment) + + if increment: + if not expected: + assert lifecycle.step_count == kwargs["step_count"] + 1 + else: + assert lifecycle.step_count == 0 + assert actual == expected + + +class TestWrappedOptimEventLifecycle: + @pytest.mark.parametrize( + "method_name", + [ + "batch_start_events", + "batch_end_events", + ], + ) + def test_batch_start_and_batch_end_events_are_invalid(self, method_name): + # batch_start_events and batch_end_events must not be + # called on WrappedOptimEventLifecycle explicitly + # since they are auto-triggered when optim is wrapped + + lifecycle = _get_event_lifecycle(lifecycle_class=WrappedOptimEventLifecycle) + with pytest.raises(ValueError, match="batch"): + method = getattr(lifecycle, method_name) + method() + + @pytest.mark.parametrize( + "type_first", + [ + EventType.BATCH_START, + EventType.BATCH_END, + EventType.OPTIM_PRE_STEP, + EventType.OPTIM_POST_STEP, + ], + ) + def test_loss_calculated_events_with_invalid_first_event_type(self, type_first): + # type_first must be EventType.LOSS_CALCULATED to get + # loss_calculated_events on an + # WrappedOptimEventLifecycle instance + + lifecycle = _get_event_lifecycle( + type_first=type_first, lifecycle_class=WrappedOptimEventLifecycle + ) + with pytest.raises(ValueError, match="loss calculated must"): + lifecycle.loss_calculated_events() + + @pytest.mark.parametrize( + "type_", + [ + EventType.BATCH_START, + EventType.OPTIM_PRE_STEP, + EventType.BATCH_END, + ], + ) + def test_loss_calculated_events_with_invalid_event_type(self, type_): + # type_ must be EventType.LOSS_CALCULATED or + # EventType.OPITM_POST_STEP to get loss_calculated_events + # on an WrappedOptimEventLifecycle instance + + lifecycle = _get_event_lifecycle( + lifecycle_class=WrappedOptimEventLifecycle, + type_first=EventType.LOSS_CALCULATED, + ) + lifecycle.type_ = type_ + with pytest.raises(ValueError, match="loss calculated must"): + lifecycle.loss_calculated_events() + + @pytest.mark.parametrize("check_step_batches_count_return", [True, False]) + def test_loss_calculated_events(self, monkeypatch, check_step_batches_count_return): + lifecycle = _get_event_lifecycle( + lifecycle_class=WrappedOptimEventLifecycle, + type_first=EventType.LOSS_CALCULATED, + ) + lifecycle.type_ = EventType.LOSS_CALCULATED + + def mock_check_step_batches_count(ret=True, *args, **kwargs): + return ret + + monkeypatch.setattr( + lifecycle, + "check_step_batches_count", + partial(mock_check_step_batches_count, ret=check_step_batches_count_return), + ) + + results = lifecycle.loss_calculated_events() + + assert isinstance(results, list) and len(results) >= 2 + assert results[0].type_ == EventType.BATCH_START + assert results[1].type_ == EventType.LOSS_CALCULATED + + if not check_step_batches_count_return: + assert len(results) == 3 + assert results[2].type_ == EventType.BATCH_END + + @pytest.mark.parametrize( + "type_first, type_", + [ + (EventType.OPTIM_PRE_STEP, EventType.BATCH_START), + (EventType.OPTIM_PRE_STEP, EventType.LOSS_CALCULATED), + (EventType.LOSS_CALCULATED, EventType.BATCH_START), + (EventType.LOSS_CALCULATED, EventType.OPTIM_PRE_STEP), + (EventType.LOSS_CALCULATED, EventType.OPTIM_POST_STEP), + ], + ) + def test_optim_pre_step_events_raises_value_error_with_invalid_event_invocation( + self, type_first, type_ + ): + # optim pre step must be called before optim post step + # and loss calculated must be called after loss calculation + + lifecycle = _get_event_lifecycle( + lifecycle_class=WrappedOptimEventLifecycle, type_first=type_first + ) + lifecycle.type_ = type_ + + with pytest.raises(ValueError, match="optim pre step must"): + lifecycle.optim_pre_step_events() + + @pytest.mark.parametrize( + "type_first, type_, check_step_invocations_count_return", + [ + (EventType.OPTIM_PRE_STEP, EventType.OPTIM_POST_STEP, False), + (EventType.OPTIM_PRE_STEP, EventType.OPTIM_POST_STEP, True), + (EventType.OPTIM_POST_STEP, EventType.OPTIM_POST_STEP, False), + ], + ) + def test_optim_pre_step_events( + self, type_first, type_, check_step_invocations_count_return, monkeypatch + ): + lifecycle = _get_event_lifecycle( + lifecycle_class=WrappedOptimEventLifecycle, type_first=type_first + ) + lifecycle.type_ = type_ + + def mock_check_step_invocations_count(ret=True, *args, **kwargs): + return ret + + monkeypatch.setattr( + lifecycle, + "check_step_invocations_count", + partial( + mock_check_step_invocations_count, + ret=check_step_invocations_count_return, + ), + ) + + results = lifecycle.optim_pre_step_events() + if type_first == EventType.OPTIM_PRE_STEP: + assert len(results) >= 1 + assert results[0].type_ == EventType.BATCH_START + + if check_step_invocations_count_return: + assert results[-1].type_ == EventType.OPTIM_PRE_STEP + + @pytest.mark.parametrize( + "type_", + [ + EventType.BATCH_START, + EventType.BATCH_END, + EventType.PRE_INIT, + ], + ) + def test_optim_post_step_events_raises_value_error_with_invalid_event_type( + self, type_ + ): + # optim post step must be called after optim pre step + + lifecycle = _get_event_lifecycle(lifecycle_class=WrappedOptimEventLifecycle) + lifecycle.type_ = type_ + + with pytest.raises(ValueError, match="optim post step must"): + lifecycle.optim_post_step_events() + + @pytest.mark.parametrize( + "type_, check_step_invocations_count_return", + [ + (EventType.OPTIM_PRE_STEP, False), + (EventType.OPTIM_PRE_STEP, True), + ], + ) + def test_optim_post_step_events( + self, type_, monkeypatch, check_step_invocations_count_return + ): + lifecycle = _get_event_lifecycle(lifecycle_class=WrappedOptimEventLifecycle) + lifecycle.type_ = type_ + + def mock_check_step_invocations_count(ret=True, *args, **kwargs): + return ret + + monkeypatch.setattr( + lifecycle, + "check_step_invocations_count", + partial( + mock_check_step_invocations_count, + ret=check_step_invocations_count_return, + ), + ) + original_global_step = lifecycle.global_step + + results = lifecycle.optim_post_step_events() + + # type_ should be EventType.OPTIM_POST_STEP after + # optim_post_step_events is called + + assert lifecycle.type_ == EventType.OPTIM_POST_STEP + + # check results + + if not check_step_invocations_count_return: + assert lifecycle.global_step == original_global_step + assert len(results) == 1 + assert results[0].type_ == EventType.BATCH_END + else: + assert lifecycle.global_step == original_global_step + 1 + assert len(results) == 2 + assert results[0].type_ == EventType.OPTIM_POST_STEP + assert results[1].type_ == EventType.BATCH_END + + +class TestCallbackEventLifecycle: + @pytest.mark.parametrize( + "type_first, type_", + [ + (EventType.BATCH_END, EventType.BATCH_START), + (EventType.BATCH_START, EventType.BATCH_START), + (EventType.BATCH_START, EventType.OPTIM_POST_STEP), + ], + ) + def test_batch_start_events_raises_value_error_with_invalid_event_invocation( + self, type_first, type_ + ): + # batch start must be called first for CallbacksEventLifecycle + + # batch start must be called after batch end for + # CallbacksEventLifecycle + + lifecycle = _get_event_lifecycle( + lifecycle_class=CallbacksEventLifecycle, type_first=type_first + ) + lifecycle.type_ = type_ + + with pytest.raises(ValueError, match="batch start must"): + lifecycle.batch_start_events() + + @pytest.mark.parametrize( + "type_first, type_", + [ + (EventType.BATCH_START, EventType.BATCH_END), + ], + ) + def test_batch_start_events(self, type_first, type_): + lifecycle = _get_event_lifecycle( + lifecycle_class=CallbacksEventLifecycle, type_first=type_first + ) + lifecycle.type_ = type_ + original_global_batch = lifecycle.global_batch + results = lifecycle.batch_start_events() + + # type_ should be EventType.BATCH_START after + # batch_start_events is called + assert lifecycle.type_ == EventType.BATCH_START + + # global_batch should be incremented by 1 + assert lifecycle.global_batch == original_global_batch + 1 + + assert len(results) == 1 + assert results[0].type_ == EventType.BATCH_START + + @pytest.mark.parametrize( + "type_", + [ + EventType.BATCH_END, + EventType.OPTIM_PRE_STEP, + ], + ) + def test_loss_calculated_event_raises_value_error_with_invalid_event_type( + self, type_ + ): + # loss calculated must be called after batch start + + lifecycle = _get_event_lifecycle(lifecycle_class=CallbacksEventLifecycle) + lifecycle.type_ = type_ + + with pytest.raises(ValueError, match="loss calculated must"): + lifecycle.loss_calculated_events() + + @pytest.mark.parametrize( + "type_", + [ + EventType.BATCH_START, + ], + ) + def test_loss_calculated_events(self, type_): + lifecycle = _get_event_lifecycle(lifecycle_class=CallbacksEventLifecycle) + lifecycle.type_ = type_ + + results = lifecycle.loss_calculated_events() + + # type_ should be EventType.LOSS_CALCULATED after + # loss_calculated_events is called + assert lifecycle.type_ == EventType.LOSS_CALCULATED + + # check results + assert len(results) == 1 + assert results[0].type_ == EventType.LOSS_CALCULATED + + @pytest.mark.parametrize( + "type_", + [ + EventType.BATCH_END, + EventType.OPTIM_PRE_STEP, + EventType.OPTIM_POST_STEP, + ], + ) + def test_optim_pre_step_events_raises_value_error_with_invalid_event_type( + self, type_ + ): + # optim pre step must be called after batch start or + # loss calculation for CallbacksEventLifecycle + + lifecycle = _get_event_lifecycle( + lifecycle_class=CallbacksEventLifecycle, + ) + lifecycle.type_ = type_ + + with pytest.raises(ValueError, match="optim pre step must"): + lifecycle.optim_pre_step_events() + + @pytest.mark.parametrize( + "type_, check_step_invocations_count_return", + [ + (EventType.BATCH_START, False), + (EventType.BATCH_START, True), + (EventType.LOSS_CALCULATED, False), + (EventType.LOSS_CALCULATED, True), + ], + ) + def test_optim_pre_step_events( + self, type_, check_step_invocations_count_return, monkeypatch + ): + lifecycle = _get_event_lifecycle( + lifecycle_class=CallbacksEventLifecycle, + ) + lifecycle.type_ = type_ + + def mock_check_step_invocations_count(ret=True, *args, **kwargs): + return ret + + monkeypatch.setattr( + lifecycle, + "check_step_invocations_count", + partial( + mock_check_step_invocations_count, + ret=check_step_invocations_count_return, + ), + ) + + results = lifecycle.optim_pre_step_events() + assert lifecycle.type_ == EventType.OPTIM_PRE_STEP + + if not check_step_invocations_count_return: + assert len(results) == 0 + else: + assert len(results) == 1 + assert results[0].type_ == EventType.OPTIM_PRE_STEP + + @pytest.mark.parametrize( + "type_", + [ + EventType.BATCH_START, + EventType.BATCH_END, + EventType.PRE_INIT, + EventType.LOSS_CALCULATED, + EventType.OPTIM_POST_STEP, + ], + ) + def test_optim_post_step_events_raises_value_error_with_invalid_event_type( + self, type_ + ): + # optim post step must be called after optim pre step + + lifecycle = _get_event_lifecycle(lifecycle_class=CallbacksEventLifecycle) + lifecycle.type_ = type_ + + with pytest.raises(ValueError, match="optim post step must"): + lifecycle.optim_post_step_events() + + @pytest.mark.parametrize( + "type_, check_step_invocations_count_return", + [ + (EventType.OPTIM_PRE_STEP, False), + (EventType.OPTIM_PRE_STEP, True), + ], + ) + def test_optim_post_step_events( + self, type_, monkeypatch, check_step_invocations_count_return + ): + lifecycle = _get_event_lifecycle(lifecycle_class=CallbacksEventLifecycle) + lifecycle.type_ = type_ + + def mock_check_step_invocations_count(ret=True, *args, **kwargs): + return ret + + monkeypatch.setattr( + lifecycle, + "check_step_invocations_count", + partial( + mock_check_step_invocations_count, + ret=check_step_invocations_count_return, + ), + ) + original_global_step = lifecycle.global_step + + results = lifecycle.optim_post_step_events() + + # type_ should be EventType.OPTIM_POST_STEP after + # optim_post_step_events is called + + assert lifecycle.type_ == EventType.OPTIM_POST_STEP + + # check results + + if not check_step_invocations_count_return: + assert len(results) == 0 + assert lifecycle.global_batch == original_global_step + else: + assert lifecycle.global_step == original_global_step + 1 + assert len(results) == 1 + assert results[0].type_ == EventType.OPTIM_POST_STEP + + @pytest.mark.parametrize( + "type_", + [ + EventType.BATCH_END, + EventType.OPTIM_PRE_STEP, + EventType.PRE_INIT, + ], + ) + def test_batch_end_events_raises_value_error_with_invalid_event_type(self, type_): + # batch end must be called after batch start or optim post step + # or loss calculation for CallbacksEventLifecycle + + lifecycle = _get_event_lifecycle(lifecycle_class=CallbacksEventLifecycle) + lifecycle.type_ = type_ + + with pytest.raises(ValueError, match="batch end must"): + lifecycle.batch_end_events() + + @pytest.mark.parametrize( + "type_", + [ + EventType.OPTIM_POST_STEP, + EventType.LOSS_CALCULATED, + EventType.BATCH_START, + ], + ) + def test_batch_end_events(self, type_): + lifecycle = _get_event_lifecycle(lifecycle_class=CallbacksEventLifecycle) + lifecycle.type_ = type_ + + results = lifecycle.batch_end_events() + + # type_ should be EventType.BATCH_END after + # batch_end_events is called + assert lifecycle.type_ == EventType.BATCH_END + + # check results + assert len(results) == 1 + assert results[0].type_ == EventType.BATCH_END diff --git a/tests/sparseml/core/lifecycle/test_session.py b/tests/sparseml/core/lifecycle/test_session.py index b0fa3687c64..344a299c834 100644 --- a/tests/sparseml/core/lifecycle/test_session.py +++ b/tests/sparseml/core/lifecycle/test_session.py @@ -12,10 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections import defaultdict +from types import SimpleNamespace + import pytest import sparseml.core.session as sml -from sparseml.core.framework import Framework +from sparseml.core import Framework +from sparseml.core.event import Event, EventType +from sparseml.core.lifecycle.event import CallbacksEventLifecycle +from sparseml.core.lifecycle.session import SparsificationLifecycle +from sparseml.core.modifier.base import ModifierInterface +from sparseml.core.state import State def recipe_with_layer_prefix(): @@ -68,3 +76,257 @@ def test_session_initialize_propagates_layer_prefix_to_model( session.initialize(framework=Framework.general, model=model, recipe=recipe) print(f"{session.state.model.layer_prefix=}, {expected_layer_prefix=}") assert session.state.model.layer_prefix == expected_layer_prefix + + +class ModifierMock(ModifierInterface): + initialized_ = False + + def __init__(self, *args, **kwargs) -> None: + super().__init__() + self._hit_count = defaultdict(int) + + def initialized_structure(self) -> bool: + self._hit_count["initialized_structure"] += 1 + pass + + def initialized(self) -> bool: + self._hit_count["initialized"] += 1 + pass + + def finalized(self) -> bool: + self._hit_count["finalized"] += 1 + pass + + def check_initialized(self): + self._hit_count["check_initialized"] += 1 + pass + + def calculate_start(self) -> float: + self._hit_count["calculate_start"] += 1 + pass + + def calculate_end(self) -> float: + self._hit_count["calculate_end"] += 1 + pass + + def pre_initialize_structure(self, state: State, **kwargs): + self._hit_count["pre_initialize_structure"] += 1 + return "pre_initialize_structure" + + def initialize(self, state: State, **kwargs): + self._hit_count["initialize"] += 1 + return "initialize" + + def finalize(self, state: State, **kwargs): + self._hit_count["finalize"] += 1 + return "finalize" + + def update_event(self, state: State, event: Event, **kwargs): + self._hit_count["update_event"] += 1 + return "update_event" + + +class StateMock: + def update(self, *args, **kwargs): + return {"dummy": "dummy"} + + +def _empty_mock(*args, **kwargs): + pass + + +class TestSparsificationLifecycle: + @pytest.mark.parametrize( + "lifecycle", + [ + SparsificationLifecycle(state=State(framework=Framework.pytorch)), + ], + ) + @pytest.mark.parametrize("modifier_initialized", [True, False]) + @pytest.mark.parametrize("modifier_finalized", [True, False]) + def test_reset( + self, lifecycle, modifier_initialized, modifier_finalized, monkeypatch + ): + monkeypatch.setattr(lifecycle, "modifiers", [ModifierMock()]) + monkeypatch.setattr(ModifierMock, "initialized_", modifier_initialized) + monkeypatch.setattr(ModifierMock, "finalized", modifier_finalized) + + lifecycle.reset() + + empty_lifecycle = SparsificationLifecycle() + assert lifecycle == empty_lifecycle + + @pytest.mark.parametrize( + "lifecycle", + [ + SparsificationLifecycle(), + ], + ) + @pytest.mark.parametrize( + "method_name", + [ + "pre_initialize_structure", + "initialize", + ], + ) + def test_lifecycle_methods_call_modifier_methods( + self, lifecycle, monkeypatch, method_name + ): + monkeypatch.setattr(lifecycle, "modifiers", [modifier_mock := ModifierMock()]) + monkeypatch.setattr(lifecycle, "_check_create_state", _empty_mock) + monkeypatch.setattr(lifecycle, "_check_compile_recipe", _empty_mock) + monkeypatch.setattr(lifecycle, "state", StateMock()) + + method = getattr(lifecycle, method_name) + results = method() + + assert modifier_mock._hit_count[method_name] == 1 + assert results == [method_name] + + if method_name == "pre_initialize_structure": + assert lifecycle.initialized_structure + else: + assert lifecycle.initialized_ + + @pytest.mark.parametrize( + "initialized_, finalized", + [ + (False, False), + (False, True), + (True, True), + ], + ) + def test_finalize_raises_value_error_if_not_initialized( + self, initialized_, finalized, monkeypatch + ): + lifecycle = SparsificationLifecycle() + lifecycle.initialized_ = initialized_ + + monkeypatch.setattr(lifecycle, "finalized", finalized) + + with pytest.raises(ValueError, match="Cannot finalize"): + lifecycle.finalize() + + def test_finalize_calls_modifier_finalize(self, monkeypatch): + lifecycle = SparsificationLifecycle() + lifecycle.initialized_ = True + lifecycle.finalized = False + + monkeypatch.setattr(lifecycle, "modifiers", [modifier_mock := ModifierMock()]) + results = lifecycle.finalize() + + # assert lifecycle is finalized + assert lifecycle.finalized + + assert modifier_mock._hit_count["finalize"] == 1 + assert results == ["finalize"] + + @pytest.mark.parametrize( + "initialized_, finalized, event_type, kwargs", + [ + (False, False, EventType.BATCH_START, {}), + (False, True, EventType.BATCH_START, {}), + (True, True, EventType.BATCH_START, {}), + (True, False, EventType.PRE_INIT, {}), + (True, False, EventType.INITIALIZE, {}), + (True, False, EventType.FINALIZE, {}), + (True, False, EventType.FINALIZE, {}), + (True, False, EventType.LOSS_CALCULATED, {}), + (True, False, EventType.LOSS_CALCULATED, {"loss": None}), + ], + ) + def test_event_raises_value_error( + self, initialized_, finalized, monkeypatch, event_type, kwargs + ): + lifecycle = SparsificationLifecycle() + lifecycle.initialized_ = initialized_ + + monkeypatch.setattr(lifecycle, "finalized", finalized) + + with pytest.raises(ValueError): + lifecycle.event(event_type=event_type, **kwargs) + + def test_event_sets_state_start_event(self, monkeypatch): + + lifecycle = SparsificationLifecycle( + state=State(framework=Framework.pytorch), + event_lifecycle=CallbacksEventLifecycle( + type_first=EventType.BATCH_START, start=Event() + ), + ) + lifecycle.initialized_ = True + lifecycle.finalized = False + + event_type = EventType.BATCH_START + event = Event() + + def events_from_type_mock(*args, **kwargs): + return [event] + + monkeypatch.setattr(lifecycle, "_check_setup_event_lifecycle", _empty_mock) + monkeypatch.setattr( + lifecycle.event_lifecycle, "events_from_type", events_from_type_mock + ) + + results = lifecycle.event(event_type=event_type) + assert lifecycle.state.start_event == event + assert lifecycle.state.last_event == event + assert lifecycle.event_called + assert results == [] + + def test_event_calls_modifier_update_event(self, monkeypatch): + lifecycle = SparsificationLifecycle( + state=State(framework=Framework.pytorch), + event_lifecycle=CallbacksEventLifecycle( + type_first=EventType.BATCH_START, start=Event() + ), + ) + lifecycle.initialized_ = True + lifecycle.finalized = False + + event_type = EventType.BATCH_START + event = Event() + + def events_from_type_mock(*args, **kwargs): + return [event] + + monkeypatch.setattr(lifecycle, "_check_setup_event_lifecycle", _empty_mock) + monkeypatch.setattr(lifecycle, "modifiers", [modifier_mock := ModifierMock()]) + monkeypatch.setattr( + lifecycle.event_lifecycle, "events_from_type", events_from_type_mock + ) + + results = lifecycle.event(event_type=event_type) + assert modifier_mock._hit_count["update_event"] == 1 + assert results == ["update_event"] + + @pytest.mark.parametrize( + "event_type", + [ + EventType.BATCH_START, + EventType.LOSS_CALCULATED, + EventType.OPTIM_PRE_STEP, + EventType.OPTIM_POST_STEP, + ], + ) + def test__check_setup_event_lifecycle(self, event_type, monkeypatch): + lifecycle = SparsificationLifecycle() + event = Event() + + class StateMock: + model = 1 + start_event = 1 + sparsification_ready = 1 + start_event = event + + recipe_container_mock = SimpleNamespace(compiled_recipe=1) + + monkeypatch.setattr(lifecycle, "state", StateMock()) + monkeypatch.setattr(lifecycle, "recipe_container", recipe_container_mock) + monkeypatch.setattr(lifecycle, "modifiers", [modifier_mock := ModifierMock()]) + + lifecycle._check_setup_event_lifecycle(event_type=event_type) + + assert modifier_mock._hit_count["check_initialized"] == 1 + assert isinstance(lifecycle.event_lifecycle, CallbacksEventLifecycle) + assert lifecycle.event_lifecycle.type_first == event_type diff --git a/tests/sparseml/core/recipe/test_recipe.py b/tests/sparseml/core/recipe/test_recipe.py index 2e3a9f5c8de..94c069a2344 100644 --- a/tests/sparseml/core/recipe/test_recipe.py +++ b/tests/sparseml/core/recipe/test_recipe.py @@ -14,30 +14,68 @@ import tempfile +import pytest import yaml +from sparseml.core.framework import Framework from sparseml.core.recipe import Recipe +from tests.sparseml.helpers import should_skip_pytorch_tests, valid_recipe_strings -def _valid_recipe(): - return """ - test_stage: - pruning_modifiers: - ConstantPruningModifier: - start: 0 - end: 5 - """ - - -def test_recipe_create_instance_accepts_valid_recipe_string(): - test_recipe = _valid_recipe() - recipe = Recipe.create_instance(test_recipe) +@pytest.mark.parametrize("recipe_str", valid_recipe_strings()) +def test_recipe_create_instance_accepts_valid_recipe_string(recipe_str): + recipe = Recipe.create_instance(recipe_str) assert recipe is not None, "Recipe could not be created from string" -def test_recipe_create_instance_accepts_valid_recipe_file(): - content = yaml.safe_load(_valid_recipe()) +@pytest.mark.parametrize("recipe_str", valid_recipe_strings()) +def test_recipe_create_instance_accepts_valid_recipe_file(recipe_str): + content = yaml.safe_load(recipe_str) with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml") as f: yaml.dump(content, f) recipe = Recipe.create_instance(f.name) assert recipe is not None, "Recipe could not be created from file" + + +@pytest.mark.parametrize("recipe_str", valid_recipe_strings()) +def test_serialization(recipe_str): + recipe_instance = Recipe.create_instance(recipe_str) + recipe_from_serialized = Recipe.create_instance(recipe_instance.yaml()) + + expected_dict = recipe_instance.dict() + actual_dict = recipe_from_serialized.dict() + + assert expected_dict == actual_dict + + +@pytest.mark.skipif( + should_skip_pytorch_tests(), + reason="Skipping pytorch tests either torch is not installed or " + "NM_ML_SKIP_PYTORCH_TESTS is set", +) +def test_recipe_creates_correct_modifier(): + start = 1 + end = 10 + targets = "__ALL_PRUNABLE__" + + yaml_str = f""" + test_stage: + pruning_modifiers: + ConstantPruningModifier: + start: {start} + end: {end} + targets: {targets} + """ + + recipe_instance = Recipe.create_instance(yaml_str) + + stage_modifiers = recipe_instance.create_modifier(framework=Framework.pytorch) + assert len(stage_modifiers) == 1 + assert len(modifiers := stage_modifiers[0].modifiers) == 1 + from sparseml.modifiers.pruning.constant.pytorch import ( + ConstantPruningModifierPyTorch, + ) + + assert isinstance(modifier := modifiers[0], ConstantPruningModifierPyTorch) + assert modifier.start == start + assert modifier.end == end diff --git a/tests/sparseml/core/test_session.py b/tests/sparseml/core/test_session.py new file mode 100644 index 00000000000..a70cf4520eb --- /dev/null +++ b/tests/sparseml/core/test_session.py @@ -0,0 +1,319 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from collections import defaultdict +from functools import partial +from types import SimpleNamespace + +import pytest + +import sparseml.core.session as session_module +from sparseml.core.event import EventType +from sparseml.core.framework import Framework +from tests.sparseml.helpers import should_skip_pytorch_tests + + +class LifeCycleMock: + """ + Mock class to track lifecycle method calls + """ + + def __init__(self, model=None, optimizer=None, loss=None): + self._state = SimpleNamespace( + model=SimpleNamespace(model=model), + optimizer=SimpleNamespace(optimizer=optimizer), + loss=SimpleNamespace(loss=loss), + ) + self._hit_count = defaultdict(int) + + def _increase_hit_count(self, method_name): + self._hit_count[method_name] += 1 + + def pre_initialize_structure(self, *args, **kwargs): + self._increase_hit_count("pre_initialize_structure") + return "pre_initialize_structure" + + def initialize(self, *args, **kwargs): + self._increase_hit_count("initialize") + return "initialize" + + def finalize(self, *args, **kwargs): + self._increase_hit_count("finalize") + return "finalize" + + def event(self, *args, **kwargs): + self._increase_hit_count("event") + return "event" + + def reset(self, *args, **kwargs): + self._increase_hit_count("reset") + + @property + def state(self): + return self._state + + +def get_linear_net(): + from tests.sparseml.pytorch.helpers import LinearNet + + return LinearNet() + + +class TestSparseSession: + def test_session_has_a_sparsification_lifecycle(self, setup_active_session): + assert hasattr( + setup_active_session, "lifecycle" + ), "SparseSession does not have a lifecyle" + + lifecyle = setup_active_session.lifecycle + assert isinstance( + lifecyle, session_module.SparsificationLifecycle + ), "SparseSession.lifecycle is not a SparsificationLifecycle" + + @pytest.mark.skipif( + should_skip_pytorch_tests(), + reason="Skipping pytorch tests either torch is not installed or " + "NM_ML_SKIP_PYTORCH_TESTS is set", + ) + def test_initialize_can_be_called_multiple_times_to_set_state(self, setup_session): + session_module.initialize(framework=Framework.pytorch) + state = session_module.active_session().lifecycle.state + + assert state.model is None + model = get_linear_net() + session_module.initialize(model=model) + + import torch + + optimizer = torch.optim.Adam(model.parameters(), lr=0.1) + session_module.initialize(optimizer=optimizer) + + # assert model was not overwritten + assert state.model.model is model + + @pytest.mark.skipif( + should_skip_pytorch_tests(), + reason="Skipping pytorch tests either torch is not installed or " + "NM_ML_SKIP_PYTORCH_TESTS is set", + ) + @pytest.mark.parametrize( + "method_name, kwargs", + [ + ( + "pre_initialize_structure", + {"model": get_linear_net, "framework": Framework.pytorch}, + ), + ("initialize", {"framework": Framework.pytorch}), + ("finalize", {}), + ("event", {"event_type": "test"}), + ("reset", {}), + ], + ) + def test_session_methods_invoke_lifecycle_methods( + self, method_name, kwargs, monkeypatch, setup_active_session + ): + if "model" in kwargs: + kwargs["model"] = kwargs["model"]() + + monkeypatch.setattr( + setup_active_session, + "_lifecycle", + lifecycle_mock := LifeCycleMock(model=kwargs.get("model")), + ) + method = getattr(setup_active_session, method_name) + + result = method(**kwargs) + if method_name != "reset": + assert ( + result.modifier_data == method_name + ), f"{method_name} did not invoke the lifecycle method" + else: + assert ( + lifecycle_mock._hit_count[method_name] == 1 + ), f"{method_name} did not invoke the lifecycle method" + + def test_apply_calls_lifecycle_initialize_and_finalize( + self, setup_active_session, monkeypatch + ): + monkeypatch.setattr( + setup_active_session, "_lifecycle", lifecycle_mock := LifeCycleMock() + ) + setup_active_session.apply() + + # check initialize was called once + assert ( + lifecycle_mock._hit_count["initialize"] == 1 + ), "apply did not invoke the lifecycle initialize method" + + # check finalize was called once + assert ( + lifecycle_mock._hit_count["finalize"] == 1 + ), "apply did not invoke the lifecycle finalize method" + + +@pytest.mark.parametrize( + "attribute_name", + [ + "create_session", + "active_session", + "pre_initialize_structure", + "initialize", + "finalize", + "apply", + ], +) +def test_import(attribute_name): + # this test will fail if the attribute is not found + # and will serve as a reminder to update the usages + # if the attribute is renamed or removed + + assert hasattr( + session_module, attribute_name + ), f"{attribute_name} not found in sparseml.core.session" + + +@pytest.fixture +def setup_session(): + # fixture to set up a session for each test + # that uses this fixture + + session_module.create_session() + yield + + +@pytest.fixture +def setup_active_session(setup_session): + # fixture to set up an active session for each test + # that uses this fixture + yield session_module.active_session() + + +def test_active_session_returns_sparse_session(setup_active_session): + assert isinstance( + setup_active_session, session_module.SparseSession + ), "create_session did not return a SparseSession" + + +def test_active_session_without_create_session(): + actual_session = session_module.active_session() + assert actual_session + + +def test_active_session_returns_same_session_on_subsequent_calls(setup_session): + actual_session = session_module.active_session() + assert ( + actual_session is session_module.active_session() + ), "active_session did not return the same session" + + +def test_active_session_returns_created_session(setup_session): + actual_session = session_module.active_session() + assert ( + actual_session is session_module._global_session + ), "active_session did not return the created session" + + +def test_create_session_yields_new_sessions(setup_active_session): + session_a = setup_active_session + with session_module.create_session() as session_b: + assert isinstance( + session_b, type(session_a) + ), "create_session did not return the same type of session" + assert session_a is not session_b, "create_session did not return a new session" + + +@pytest.mark.parametrize("framework", [framework for framework in Framework]) +def test_initialize_returns_modified_state(framework): + result = session_module.initialize(framework=framework) + assert isinstance( + result, session_module.ModifiedState + ), "initialize did not return a ModifiedState" + + +@pytest.mark.parametrize( + "method_name", ["pre_initialize_structure", "initialize", "finalize", "apply"] +) +def test_module_methods_call_session_methods(method_name, monkeypatch): + session_mock = LifeCycleMock() + + def active_session_mock(): + return session_mock + + monkeypatch.setattr(session_module, "active_session", active_session_mock) + + method = getattr(session_module, method_name) + if method_name == "apply": + + def apply_mock(self, *args, **kwargs): + self._increase_hit_count("apply") + return "apply" + + session_mock.apply = partial(apply_mock, self=session_mock) + + result = method() + assert ( + session_mock._hit_count[method_name] == 1 + ), f"{method_name} did not invoke equivalent session method" + if result is not None: + assert ( + result == method_name + ), f"{method_name} did not return the result of the equivalent session method" + + +def active_session_event_mock(event_type, *args, **kwargs): + return event_type + + +class TestLifecycleCallbacks: + def test_callbacks(self): + assert session_module.callbacks == session_module.LifecycleCallbacks + + @pytest.mark.parametrize( + "event_type", [EventType.PRE_INIT, EventType.INITIALIZE, EventType.FINALIZE] + ) + def test_value_eror_for_non_invokable_events(self, event_type): + with pytest.raises(ValueError): + session_module.LifecycleCallbacks.event(event_type=event_type) + + @pytest.mark.parametrize( + "event_type", + [EventType.BATCH_START, EventType.BATCH_END, EventType.LOSS_CALCULATED], + ) + def test_valid_event_calls_session_event( + self, event_type, monkeypatch, setup_active_session + ): + monkeypatch.setattr(setup_active_session, "event", active_session_event_mock) + result = session_module.LifecycleCallbacks.event(event_type=event_type) + assert result == event_type, f"{event_type} did not invoke session event" + + @pytest.mark.parametrize( + "method_name, expected_event_type", + [ + ("batch_start", EventType.BATCH_START), + ("optim_pre_step", EventType.OPTIM_PRE_STEP), + ("optim_post_step", EventType.OPTIM_POST_STEP), + ("batch_end", EventType.BATCH_END), + ("loss_calculated", EventType.LOSS_CALCULATED), + ], + ) + def test_method_call_with_right_event_type( + self, method_name, expected_event_type, monkeypatch, setup_active_session + ): + monkeypatch.setattr(setup_active_session, "event", active_session_event_mock) + method = getattr(session_module.LifecycleCallbacks, method_name) + result = method() + assert ( + result == expected_event_type + ), f"{method_name} did not invoke session event" diff --git a/tests/sparseml/core/test_state.py b/tests/sparseml/core/test_state.py new file mode 100644 index 00000000000..54619389a16 --- /dev/null +++ b/tests/sparseml/core/test_state.py @@ -0,0 +1,122 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest + +from sparseml.core.framework import Framework +from sparseml.core.model.base import ModifiableModel +from sparseml.core.state import State +from tests.sparseml.helpers import should_skip_pytorch_tests + + +def get_linear_net_with_device(device="cpu"): + from tests.sparseml.pytorch.helpers import LinearNet + + class LinearNetWithMockDevice(LinearNet): + def __init__(self): + super().__init__() + self.device = device + + def to(self, device): + # Do not need to actually move + # the model to the device + + # uncomment next line to actually move + # super().to(device) + self.device = device + return self + + return LinearNetWithMockDevice() + + +class TestState: + @pytest.mark.parametrize( + "kwargs, expected", + [ + ({"framework": Framework.pytorch}, False), + ({"framework": Framework.pytorch, "model": 1}, False), + ({"framework": Framework.pytorch, "model": 1, "optimizer": 1}, True), + ], + ) + def test_sparsification_ready(self, kwargs, expected): + state = State(**kwargs) + assert state.sparsification_ready == expected + + @pytest.mark.parametrize("start", [1, None]) + def test_update_with_start_sets_start_event(self, start): + state = State(framework=Framework.pytorch) + state.update(start=start) + if start is not None: + assert state.start_event is not None + assert state.start_event.current_index == start + else: + assert state.start_event is None + + @pytest.mark.parametrize( + "arg_name, arg_value", + [ + ("batches_per_step", 1), + ("batches_per_step", None), + ("steps_per_epoch", 1), + ("steps_per_epoch", None), + ], + ) + def test_update_sets_start_event(self, arg_name, arg_value): + state = State(framework=Framework.pytorch) + state.update(**{arg_name: arg_value}) + if arg_value is not None: + assert state.start_event is not None + assert getattr(state.start_event, arg_name) == arg_value + else: + assert state.start_event is None + + @pytest.mark.parametrize( + "data_arg, data_value", + [ + ("train_data", 1), + ("test_data", 1), + ("val_data", 1), + ("calib_data", 1), + ], + ) + def test_update_sets_data(self, data_arg, data_value): + state = State(framework=Framework.pytorch) + state.update(**{data_arg: data_value}) + + # remove _data suffix + data_arg_key = data_arg[:-5] + if data_value is not None: + assert getattr(state.data, data_arg_key) == data_value + + def test_update_can_set_teacher_model(self): + state = State(framework=Framework.pytorch) + state.update(teacher_model=1) + assert state.teacher_model is not None + assert isinstance(state.teacher_model, ModifiableModel) + assert state.teacher_model.model == 1 + + @pytest.mark.skipif( + should_skip_pytorch_tests(), + reason="Skipping pytorch tests either torch is not installed or " + "NM_ML_SKIP_PYTORCH_TESTS is set", + ) + def test_update_auto_moves_model_to_device(self): + state = State(framework=Framework.pytorch) + model = get_linear_net_with_device(device="cuda") + assert model.device == "cuda" + + state.update(model=model, device="cpu") + assert state.model.model == model + assert state.model.model.device == "cpu" diff --git a/tests/sparseml/helpers.py b/tests/sparseml/helpers.py new file mode 100644 index 00000000000..5ffa186f1c3 --- /dev/null +++ b/tests/sparseml/helpers.py @@ -0,0 +1,82 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os + + +def valid_recipe_strings(): + return [ + """ + test_stage: + pruning_modifiers: + ConstantPruningModifier: + start: 0 + end: 5 + targets: __ALL_PRUNABLE__ + """, + """ + test_stage: + pruning_modifiers: + ConstantPruningModifier: + start: 0 + end: 5 + targets: __ALL_PRUNABLE__ + MagnitudePruningModifier: + start: 5 + end: 10 + init_sparsity: 0.1 + final_sparsity: 0.5 + targets: __ALL_PRUNABLE__ + """, + """ + test1_stage: + pruning_modifiers: + ConstantPruningModifier: + start: 0 + end: 5 + targets: __ALL_PRUNABLE__ + test2_stage: + MagnitudePruningModifier: + start: 5 + end: 10 + init_sparsity: 0.1 + final_sparsity: 0.5 + targets: __ALL_PRUNABLE__ + """, + """ + test1_stage: + constant_modifiers: + ConstantPruningModifier: + start: 0 + end: 5 + targets: __ALL_PRUNABLE__ + magnitude_modifiers: + MagnitudePruningModifier: + start: 5 + end: 10 + init_sparsity: 0.1 + final_sparsity: 0.5 + targets: __ALL_PRUNABLE__ + """, + ] + + +def should_skip_pytorch_tests(): + try: + import torch # noqa: F401 + except ImportError: + return True + + return os.getenv("NM_ML_SKIP_PYTORCH_TESTS", False) diff --git a/tests/sparseml/pytorch/modifiers/pruning/__init__.py b/tests/sparseml/pytorch/modifiers/pruning/__init__.py new file mode 100644 index 00000000000..0c44f887a47 --- /dev/null +++ b/tests/sparseml/pytorch/modifiers/pruning/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/sparseml/pytorch/modifiers/pruning/constant/__init__.py b/tests/sparseml/pytorch/modifiers/pruning/constant/__init__.py new file mode 100644 index 00000000000..0c44f887a47 --- /dev/null +++ b/tests/sparseml/pytorch/modifiers/pruning/constant/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/sparseml/pytorch/modifiers/pruning/constant/test_pytorch.py b/tests/sparseml/pytorch/modifiers/pruning/constant/test_pytorch.py new file mode 100644 index 00000000000..c2183a4158a --- /dev/null +++ b/tests/sparseml/pytorch/modifiers/pruning/constant/test_pytorch.py @@ -0,0 +1,184 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import pytest +import torch + +from sparseml.core import State +from sparseml.core.event import Event, EventType +from sparseml.core.framework import Framework +from sparseml.modifiers.pruning.constant.pytorch import ConstantPruningModifierPyTorch +from sparseml.modifiers.pruning.utils.pytorch.layer_mask import param_mask_name +from sparseml.pytorch.utils import tensor_sparsity +from tests.sparseml.modifiers.conf import setup_modifier_factory +from tests.sparseml.pytorch.helpers import ConvNet, LinearNet + + +def _induce_sparsity(model, sparsity=0.5): + """ + Introduces sparsity to the given model by zeroing out weights + with a probability of sparsity + + :param model: the model to introduce sparsity to + :param sparsity: the probability of zeroing out a weight + :return: the model with sparsity introduced + """ + with torch.no_grad(): + for name, param in model.named_parameters(): + if "weight" in name: + param.data = param.mul_(torch.rand_like(param) > sparsity).float() + return model + + +def _make_dense(model): + """ + Makes a model dense by setting all weights to 1 + + :param model: the model to make dense + :return: the model with all dense params + """ + with torch.no_grad(): + for name, param in model.named_parameters(): + if "weight" in name: + param.data = torch.ones_like(param.data).float() + return model + + +def _test_models(): + + return [ + _induce_sparsity(LinearNet()), + _induce_sparsity(ConvNet()), + ] + + +def _test_optims(): + return [ + torch.optim.Adam, + torch.optim.SGD, + ] + + +@pytest.mark.skipif( + os.getenv("NM_ML_SKIP_PYTORCH_TESTS", False), + reason="Skipping pytorch tests", +) +@pytest.mark.parametrize("model", _test_models()) +@pytest.mark.parametrize("optimizer", _test_optims()) +def test_constant_pruning_modifier_e2e(model, optimizer): + expected_sparsities = { + name: tensor_sparsity(param.data) + for name, param in model.named_parameters() + if "weight" in name + } + + # init modifier with model + + state = State(framework=Framework.pytorch) + state.update( + model=model, + optimizer=optimizer(model.parameters(), lr=0.1), + start=0, + ) + modifier = ConstantPruningModifierPyTorch( + targets="__ALL_PRUNABLE__", + start=0, + end=1, + update=0.5, + ) + modifier.initialize(state) + + # check mask is added and has correct sparsity + + for _, parameterized_layer in modifier.parameterized_layers_.items(): + mask_name = param_mask_name(parameterized_layer.param_name) + mask_tensor = parameterized_layer.layer.get_buffer(mask_name) + data_tensor = parameterized_layer.param.data + # check mask and data tensors have 0 in the same places + assert torch.all(mask_tensor == (data_tensor != 0)) + + # mess up model sparsity + + model = _make_dense(model) + manipulated_sparsities = { + name: tensor_sparsity(param.data) + for name, param in model.named_parameters() + if "weight" in name + } + assert manipulated_sparsities != expected_sparsities, "Sparsity manipulation failed" + + # apply modifier + + modifier.on_update(state, event=Event(type_=EventType.OPTIM_PRE_STEP)) + modifier.on_update(state, event=Event(type_=EventType.OPTIM_POST_STEP)) + modifier.on_end(state, None) + + # copy old mask settings as finalize will remove them + # this is needed to check if a mask was persistent + + old_mask_settings = modifier._mask_settings.copy() + modifier.finalize(state) + + # check mask is removed + for layer_param_name, parameterized_layer in modifier.parameterized_layers_.items(): + mask_name = param_mask_name(parameterized_layer.param_name) + + if not old_mask_settings[layer_param_name].persistent: + assert not hasattr(parameterized_layer.layer, mask_name) + + # mask name should not be in _mask_settings or + # _masked_layer_params + assert layer_param_name not in modifier._mask_settings + assert layer_param_name not in modifier._masked_layer_params + + # sparsity should restored by ConstantPruningModifierPyTorch + + actual_sparsities = { + name: tensor_sparsity(param.data) + for name, param in model.named_parameters() + if "weight" in name + } + assert actual_sparsities == expected_sparsities, "Sparsity was not constant" + + +@pytest.mark.skipif( + os.getenv("NM_ML_SKIP_PYTORCH_TESTS", False), + reason="Skipping pytorch tests", +) +def test_constant_pruning_pytorch_is_registered(): + from sparseml.core.factory import ModifierFactory + from sparseml.core.framework import Framework + from sparseml.modifiers.pruning.constant.pytorch import ( + ConstantPruningModifierPyTorch, + ) + + kwargs = dict( + start_epoch=5.0, + end_epoch=15.0, + targets="__ALL_PRUNABLE__", + ) + setup_modifier_factory() + type_ = ModifierFactory.create( + type_="ConstantPruningModifier", + framework=Framework.pytorch, + allow_experimental=False, + allow_registered=True, + **kwargs, + ) + + assert isinstance( + type_, ConstantPruningModifierPyTorch + ), "PyTorch ConstantPruningModifier not registered" From da27e82cbc66b54dacc819afc7719ecfc0232fbe Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Fri, 27 Oct 2023 14:40:45 -0400 Subject: [PATCH 09/12] fix datasets version to be compatible with fsspec (#1797) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 047ae7f0faa..24ac227701d 100644 --- a/setup.py +++ b/setup.py @@ -79,7 +79,7 @@ _transformers_deps = _pytorch_deps + [ f"{'nm-transformers' if is_release else 'nm-transformers-nightly'}" f"~={version_nm_deps}", - "datasets<=2.11", + "datasets<=2.14.6", "scikit-learn", "seqeval", "einops", From 037e302d46021f3574c53137738a6572f75a7364 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Sat, 28 Oct 2023 16:56:03 -0600 Subject: [PATCH 10/12] Add kvcache config for Mistral (#1766) * Add kvcache config for Mistral * Update configs.py * Update configs.py --- .../exporters/transforms/kv_cache/configs.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/sparseml/exporters/transforms/kv_cache/configs.py b/src/sparseml/exporters/transforms/kv_cache/configs.py index ff9189b1c41..d617075e7cb 100644 --- a/src/sparseml/exporters/transforms/kv_cache/configs.py +++ b/src/sparseml/exporters/transforms/kv_cache/configs.py @@ -138,6 +138,21 @@ class Config: multiply_batch_by_num_att_heads=False, ) +# Mistral has a config/model definition "MistralForCausalLM" but is based off Llama2. +# It contains these additions to Llama2-7b: +# * Sliding Window Attention +# * GQA (Grouped Query Attention) +# * Byte-fallback BPE tokenizer +MISTRAL_CONFIG = KeyValueCacheConfig( + model_name="mistral", + additional_transforms=AdditionalTransformsLLAMA, + key_num_attention_heads="num_attention_heads", + key_num_embedding_hidden_size="hidden_size", + transpose_value_input=None, + transpose_key_input=None, + multiply_batch_by_num_att_heads=False, +) + # Reusing the CodeGen transforms because it happens to match what we need for GPTNeo additional_transforms_gpt_neo = AdditionalTransformsCodeGen @@ -160,6 +175,7 @@ def get_kv_cache_config( BLOOM_CONFIG, MPT_CONFIG, LLAMA_CONFIG, + MISTRAL_CONFIG, GPT_NEO_CONFIG, ], ) -> KeyValueCacheConfig: From 3b0b31937c3d6b1a4fef10436f93c463a7b36ed4 Mon Sep 17 00:00:00 2001 From: rahul-tuli Date: Mon, 30 Oct 2023 09:19:25 -0400 Subject: [PATCH 11/12] Fix reset logic --- src/sparseml/core/lifecycle/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sparseml/core/lifecycle/session.py b/src/sparseml/core/lifecycle/session.py index 25a665637f6..7408dff1f45 100644 --- a/src/sparseml/core/lifecycle/session.py +++ b/src/sparseml/core/lifecycle/session.py @@ -42,7 +42,7 @@ class SparsificationLifecycle: def reset(self): for mod in self.modifiers: - if not mod.initialized_ or mod.finalized: + if not mod.initialized or mod.finalized: continue try: From 9734ef6fb666a6523eb1a3ee7df58c5d076e5849 Mon Sep 17 00:00:00 2001 From: rahul-tuli Date: Tue, 31 Oct 2023 08:47:42 -0400 Subject: [PATCH 12/12] Style after resolving merge conflicts --- src/sparseml/core/model/base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/sparseml/core/model/base.py b/src/sparseml/core/model/base.py index 71fbb3d3e24..bee11706ade 100644 --- a/src/sparseml/core/model/base.py +++ b/src/sparseml/core/model/base.py @@ -125,7 +125,6 @@ def set_param(self, target: str, param: PT): """ raise NotImplementedError() - @property def layer_prefix(self) -> Optional[str]: """