diff --git a/src/sparseml/core/lifecycle/session.py b/src/sparseml/core/lifecycle/session.py
index 451cbd5bfde..7408dff1f45 100644
--- a/src/sparseml/core/lifecycle/session.py
+++ b/src/sparseml/core/lifecycle/session.py
@@ -42,7 +42,7 @@ class SparsificationLifecycle:
 
     def reset(self):
         for mod in self.modifiers:
-            if not mod.initialized_ or mod.finalized:
+            if not mod.initialized or mod.finalized:
                 continue
 
             try:
@@ -87,6 +87,7 @@ def initialize(self, framework: Framework = None, **kwargs) -> List[Any]:
         extras = self.recipe_container.update(**extras)
 
         self._check_compile_recipe()
+        self._set_model_layer_prefix()
         mod_data = []
         for mod in self.modifiers:
             data = mod.initialize(state=self.state, **extras)
@@ -208,3 +209,14 @@ def _check_setup_event_lifecycle(self, event_type: EventType):
             )
         else:
             raise ValueError(f"invalid event type {event_type}")
+
+    def _set_model_layer_prefix(self):
+        if (
+            (compiled_recipe := self.recipe_container.compiled_recipe) is None
+            or (metadata := compiled_recipe.metadata) is None
+            or (model_metadata := metadata.target_model) is None
+        ):
+            return False
+
+        self.state.model.layer_prefix = model_metadata.layer_prefix
+        return True
diff --git a/src/sparseml/core/model/base.py b/src/sparseml/core/model/base.py
index d5647c2c22a..bee11706ade 100644
--- a/src/sparseml/core/model/base.py
+++ b/src/sparseml/core/model/base.py
@@ -57,13 +57,21 @@ class ModifiableModel(Generic[MT, LT, PT], MultiFrameworkObject):
     to be searchable by the MultiFrameworkObject factory method.
 
     :param framework: the framework the model is in
+    :param layer_prefix: name of the model attribute that contains the list of layers,
+        e.g. model.decoder for OPT or just model for Llama
    :param model: the model object
     """
 
     model: MT = None
 
-    def __init__(self, framework: Optional[Framework] = None, model=None):
+    def __init__(
+        self,
+        framework: Optional[Framework] = None,
+        model=None,
+        layer_prefix: Optional[str] = None,
+    ):
         self.model = model
+        self._layer_prefix = layer_prefix
 
     def get_layers_params(
         self, targets: Union[str, List[str]]
@@ -117,6 +125,22 @@ def set_param(self, target: str, param: PT):
         """
         raise NotImplementedError()
 
+    @property
+    def layer_prefix(self) -> Optional[str]:
+        """
+        :return: the name of the model attribute that contains the list of layers,
+            e.g. model.decoder for OPT or just model for Llama
+        """
+        return self._layer_prefix
+
+    @layer_prefix.setter
+    def layer_prefix(self, value: Optional[str]):
+        """
+        :param value: the name of the model attribute that contains the list of layers,
+            e.g. model.decoder for OPT or just model for Llama
+        """
+        self._layer_prefix = value
+
     def get_matching_layer(
         self, target: str, name_to_match: str, model: LT
     ) -> Optional[Tuple[str, LT]]:
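Note: `_set_model_layer_prefix` above is a null-safe walk of `recipe_container.compiled_recipe.metadata.target_model` using walrus assignments inside an `or` chain. A minimal standalone sketch of the same pattern — the dataclasses below are illustrative stand-ins, not the real sparseml types:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelMeta:
    layer_prefix: Optional[str] = None


@dataclass
class Meta:
    target_model: Optional[ModelMeta] = None


@dataclass
class CompiledRecipe:
    metadata: Optional[Meta] = None


def get_layer_prefix(recipe: Optional[CompiledRecipe]) -> Optional[str]:
    # Each walrus names an intermediate link; the first None short-circuits
    # the chain, mirroring _set_model_layer_prefix's early return.
    if (
        recipe is None
        or (metadata := recipe.metadata) is None
        or (model_metadata := metadata.target_model) is None
    ):
        return None
    return model_metadata.layer_prefix


assert get_layer_prefix(CompiledRecipe()) is None
assert get_layer_prefix(CompiledRecipe(Meta(ModelMeta("decoder")))) == "decoder"
```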
diff --git a/src/sparseml/core/model/pytorch.py b/src/sparseml/core/model/pytorch.py
index 41e4da600d3..f3a5701a3fe 100644
--- a/src/sparseml/core/model/pytorch.py
+++ b/src/sparseml/core/model/pytorch.py
@@ -40,12 +40,17 @@ class ModifiableModelPyTorch(ModifiableModel[Module, Module, Parameter]):
 
     :param framework: the framework the model is in
     :param model: the model object
+    :param layer_prefix: name of the model attribute that contains the list of layers,
+        e.g. model.decoder for OPT or just model for Llama
     """
 
     def __init__(
-        self, framework: Optional[Framework] = None, model: Optional[Module] = None
+        self,
+        framework: Optional[Framework] = None,
+        model: Optional[Module] = None,
+        layer_prefix: Optional[str] = None,
     ):
-        super().__init__(framework=framework, model=model)
+        super().__init__(framework=framework, model=model, layer_prefix=layer_prefix)
 
     def get_layers_params(
         self, targets: Union[str, List[str]]
diff --git a/src/sparseml/core/recipe/metadata.py b/src/sparseml/core/recipe/metadata.py
index 65fc907e967..c1a7ef3f991 100644
--- a/src/sparseml/core/recipe/metadata.py
+++ b/src/sparseml/core/recipe/metadata.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 
 from pydantic import BaseModel, Field
 
@@ -69,6 +69,7 @@ class ModelMetaData(BaseModel):
     input_shapes: List[List[int]] = None
     output_shapes: List[List[int]] = None
     layers: List[LayerMetaData] = Field(default_factory=list)
+    layer_prefix: Optional[str] = None
 
 
 class RecipeMetaData(BaseModel):
@@ -79,3 +80,18 @@ class RecipeMetaData(BaseModel):
     tags: List[str] = None
     target_dataset: DatasetMetaData = None
     target_model: ModelMetaData = None
+
+    def update_missing_metadata(self, other: "RecipeMetaData"):
+        """
+        Update this recipe metadata, filling in missing values from
+        another recipe metadata instance
+
+        :param other: the recipe metadata to take missing values from
+        """
+        self.domain = self.domain or other.domain
+        self.task = self.task or other.task
+        self.versions = self.versions or other.versions
+        self.requirements = self.requirements or other.requirements
+        self.tags = self.tags or other.tags
+        self.target_dataset = self.target_dataset or other.target_dataset
+        self.target_model = self.target_model or other.target_model
diff --git a/src/sparseml/core/recipe/recipe.py b/src/sparseml/core/recipe/recipe.py
index af40fa5b931..c0b1dcaa240 100644
--- a/src/sparseml/core/recipe/recipe.py
+++ b/src/sparseml/core/recipe/recipe.py
@@ -135,6 +135,9 @@ def simplify_recipe(
         simplified.args = RecipeArgs(args)
         simplified.stages = stages
         simplified.evaluate(args=args, shift=shift)
+        simplified.metadata = (
+            recipe.metadata if isinstance(recipe, Recipe) else recipe.recipe.metadata
+        )
 
         return simplified
 
@@ -185,6 +188,7 @@ def simplify_combine_recipes(
             combined.version = simplified.version
             combined.stages.extend(simplified.stages)
             combined.args.update(simplified.args)
+            combined.combine_metadata(simplified.metadata)
 
         return combined
 
@@ -388,6 +392,22 @@ def extract_dict_stages(values: Dict[str, Any]) -> List[Dict[str, Any]]:
 
         return stages
 
+    def combine_metadata(self, metadata: Optional[RecipeMetaData]):
+        """
+        Combines the recipe's existing metadata with the supplied metadata.
+        If the recipe already has metadata, the supplied metadata is only
+        used to fill in missing values.
+
+        :param metadata: the metadata to combine with the recipe
+        """
+        if metadata is None:
+            return
+
+        if self.metadata is None:
+            self.metadata = metadata
+        else:
+            self.metadata.update_missing_metadata(metadata)
+
     def dict(self, *args, **kwargs) -> Dict[str, Any]:
         """
        >>> recipe_str = '''
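Note: because the merge uses `or`, any field already set (truthy) on `self` wins and only unset fields are pulled from `other` — the precedence `combine_metadata` relies on when stacking recipes. A quick illustration, run against this branch's `RecipeMetaData`:

```python
from sparseml.core.recipe.metadata import RecipeMetaData

base = RecipeMetaData(domain="cv")              # task, tags, ... left unset
other = RecipeMetaData(domain="nlp", task="qa")

base.update_missing_metadata(other)

assert base.domain == "cv"  # already set on `base`, so not overwritten
assert base.task == "qa"    # unset on `base`, filled in from `other`
```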
diff --git a/src/sparseml/modifiers/obcq/base.py b/src/sparseml/modifiers/obcq/base.py
index f9e7a1c2955..cb87af2be07 100644
--- a/src/sparseml/modifiers/obcq/base.py
+++ b/src/sparseml/modifiers/obcq/base.py
@@ -49,8 +49,6 @@ class SparseGPTModifier(Modifier):
     :param targets: list of layer names to compress
        during OBCQ, or '__ALL__' to compress every layer in the model
     :param target_ids: list of keys in model output to cache
-    :param layer_prefix: name of model attribute that contains the list of layers, i.e.
-        model.decoder for OPT or just model for Llama
     """
 
     sparsity: Union[float, List[float]]
diff --git a/src/sparseml/modifiers/obcq/pytorch.py b/src/sparseml/modifiers/obcq/pytorch.py
index 70530f29c8d..c3a138bc6ce 100644
--- a/src/sparseml/modifiers/obcq/pytorch.py
+++ b/src/sparseml/modifiers/obcq/pytorch.py
@@ -47,6 +47,7 @@ class SparseGPTModifierPyTorch(SparseGPTModifier):
     model: Any = None
     device_: str = "cuda:0"
     finalization_kwargs_: Dict = None
+    layer_prefix_: Optional[str] = None
 
     def on_initialize(self, state: "State", **kwargs) -> bool:
         """
@@ -85,6 +86,7 @@ def initialize_obcq(
         """
         self.model = model
         self.compressible_layers_ = self.compressible_layers()
+        self.layer_prefix_ = model.layer_prefix
         self.model = self.model.model
         self._set_device(device)
 
@@ -106,7 +108,7 @@ def apply_obcq(
         extras = self.compress_bottom(
             dev=self.device_,
             target_ids=self.target_ids,
-            layer_prefix=self.layer_prefix,
+            layer_prefix=self.layer_prefix_,
             **accum_kwargs,
         )
         accum_kwargs.update(extras)
@@ -166,17 +168,20 @@ def compress_bottom(
         nsamples: int = None,
         dev: str = "cuda:0",
         target_ids: List[str] = None,
-        layer_prefix: str = None,
+        layer_prefix: Optional[str] = None,
     ) -> Dict:
         """
         Runs calibration data through the bottom part of the network (everything
         up to the first decoder layer) and returns the captured outputs
 
         :param dataloader: calibration data to pass through the model
-        :nsamples: number of samples to use for calibration, or None to use it all
-        :dev: device to use
+        :param nsamples: number of samples to use for calibration, or None to use it all
+        :param dev: device to use
+        :param layer_prefix: name of the model attribute that contains the list of
+            layers, e.g. model.decoder for OPT or just model for Llama
         :return: outputs from bottom part of network, attention mask, and kv-cache state
         """
+        layer_prefix = layer_prefix or self.layer_prefix_
         cached_inputs = cache_attention_inputs(
             self.model, dataloader, dev, nsamples, target_ids, layer_prefix
         )
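Note: `layer_prefix` names the submodule that owns the decoder layers, and `compress_bottom` falls back to the value captured from the model when the caller passes none. The actual attribute traversal happens inside `cache_attention_inputs`, which this diff does not show; the helper below is a hypothetical sketch of how such a prefix could be resolved, not the sparseml implementation:

```python
from typing import Optional

import torch.nn as nn


def resolve_layer_container(model: nn.Module, layer_prefix: Optional[str]) -> nn.Module:
    # Follow a dotted attribute path ("decoder", "transformer.h", ...) down from
    # `model`; with no prefix the model itself is assumed to own the layers.
    module = model
    if layer_prefix:
        for attr in layer_prefix.split("."):
            module = getattr(module, attr)
    return module
```

For OPT (`layer_prefix="decoder"`) this yields `model.decoder`, whose `.layers` holds the decoder blocks; for Llama a `None` prefix leaves you on the model itself.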
diff --git a/src/sparseml/transformers/sparsification/obcq/example.yaml b/src/sparseml/transformers/sparsification/obcq/example.yaml
index 6594bf39547..ddf9d6ee3ce 100644
--- a/src/sparseml/transformers/sparsification/obcq/example.yaml
+++ b/src/sparseml/transformers/sparsification/obcq/example.yaml
@@ -1,3 +1,8 @@
+metadata:
+  target_model:
+    layer_prefix: "decoder"
+    architecture: "opt"
+
 test_stage:
   obcq_modifiers:
     SmoothQuantModifier:
@@ -52,4 +57,3 @@ test_stage:
         "model.decoder.layers.23"
       ]
       target_ids: ["attention_mask"]
-      layer_prefix: "decoder"
\ No newline at end of file
diff --git a/tests/sparseml/core/lifecycle/test_session.py b/tests/sparseml/core/lifecycle/test_session.py
index 332a7e610ec..344a299c834 100644
--- a/tests/sparseml/core/lifecycle/test_session.py
+++ b/tests/sparseml/core/lifecycle/test_session.py
@@ -17,6 +17,7 @@
 
 import pytest
 
+import sparseml.core.session as sml
 from sparseml.core import Framework
 from sparseml.core.event import Event, EventType
 from sparseml.core.lifecycle.event import CallbacksEventLifecycle
@@ -25,6 +26,58 @@
 from sparseml.core.state import State
 
 
+def recipe_with_layer_prefix():
+    layer_prefix = "decoder"
+    recipe = f"""
+    metadata:
+        target_model:
+            layer_prefix: {layer_prefix}
+            architecture: "opt"
+
+    test_stage:
+        pruning_modifiers:
+            ConstantPruningModifier:
+                targets: __ALL_PRUNABLE__
+                start: 0
+                end: 5
+    """
+    return recipe, layer_prefix
+
+
+def recipe_without_layer_prefix():
+    recipe = """
+    test_stage:
+        pruning_modifiers:
+            ConstantPruningModifier:
+                targets: __ALL_PRUNABLE__
+                start: 0
+                end: 5
+    """
+    return recipe, None
+
+
+@pytest.fixture
+def model():
+    # identity model
+    return lambda x: x
+
+
+@pytest.mark.parametrize(
+    "recipe, expected_layer_prefix",
+    [
+        recipe_without_layer_prefix(),
+        recipe_with_layer_prefix(),
+    ],
+)
+def test_session_initialize_propagates_layer_prefix_to_model(
+    recipe, expected_layer_prefix, model
+):
+    session = sml.active_session()
+    session.initialize(framework=Framework.general, model=model, recipe=recipe)
+    print(f"{session.state.model.layer_prefix=}, {expected_layer_prefix=}")
+    assert session.state.model.layer_prefix == expected_layer_prefix
+
+
 class ModifierMock(ModifierInterface):
     initialized_ = False
diff --git a/tests/sparseml/core/recipe/test_metadata.py b/tests/sparseml/core/recipe/test_metadata.py
new file mode 100644
index 00000000000..bfc410d8467
--- /dev/null
+++ b/tests/sparseml/core/recipe/test_metadata.py
@@ -0,0 +1,55 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import pytest
+
+from sparseml.core.recipe.metadata import ModelMetaData, RecipeMetaData
+
+
+class TestRecipeMetaData:
+    @pytest.mark.parametrize(
+        "self_metadata",
+        [
+            dict(domain="cv", task="classification"),
+            dict(),
+        ],
+    )
+    @pytest.mark.parametrize(
+        "other_metadata",
+        [
+            dict(domain="domain", task="segmentation", requirements=["torch>=1.6.0"]),
+            dict(
+                domain="cv",
+                task="task",
+                target_model=ModelMetaData(layer_prefix="something"),
+            ),
+        ],
+    )
+    def test_update_missing_metadata(self, self_metadata, other_metadata):
+
+        metadata_a = RecipeMetaData(**self_metadata)
+        metadata_b = RecipeMetaData(**other_metadata)
+
+        metadata_a.update_missing_metadata(metadata_b)
+
+        all_keys = set(self_metadata.keys()).union(other_metadata.keys())
+
+        # keys already set on `self` must not be overwritten;
+        # unset keys are filled in from `other`
+        for key in all_keys:
+            if key in self_metadata:
+                assert getattr(metadata_a, key) == self_metadata[key]
+            elif key in other_metadata:
+                assert getattr(metadata_a, key) == other_metadata[key]
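Note: taken together, these changes let a recipe declare the layer prefix declaratively, with the session copying it onto the model at initialize time. A minimal end-to-end sketch against this branch, mirroring the new `test_session_initialize_propagates_layer_prefix_to_model` above:

```python
import sparseml.core.session as sml
from sparseml.core import Framework

recipe = """
metadata:
    target_model:
        layer_prefix: decoder
        architecture: "opt"

test_stage:
    pruning_modifiers:
        ConstantPruningModifier:
            targets: __ALL_PRUNABLE__
            start: 0
            end: 5
"""

session = sml.active_session()
session.initialize(framework=Framework.general, model=lambda x: x, recipe=recipe)

# the prefix declared in recipe metadata is now visible on the wrapped model
assert session.state.model.layer_prefix == "decoder"
```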