[Update] layer prefix to be set at model level #1778

Merged
merged 16 commits on Oct 31, 2023
Changes from all commits
14 changes: 13 additions & 1 deletion src/sparseml/core/lifecycle/session.py
@@ -42,7 +42,7 @@ class SparsificationLifecycle:

def reset(self):
for mod in self.modifiers:
if not mod.initialized_ or mod.finalized:
if not mod.initialized or mod.finalized:
continue

try:
@@ -87,6 +87,7 @@ def initialize(self, framework: Framework = None, **kwargs) -> List[Any]:
extras = self.recipe_container.update(**extras)

self._check_compile_recipe()
self._set_model_layer_prefix()
mod_data = []
for mod in self.modifiers:
data = mod.initialize(state=self.state, **extras)
@@ -208,3 +209,14 @@ def _check_setup_event_lifecycle(self, event_type: EventType):
)
else:
raise ValueError(f"invalid event type {event_type}")

def _set_model_layer_prefix(self):
if (
(compiled_recipe := self.recipe_container.compiled_recipe) is None
or (metadata := compiled_recipe.metadata) is None
or (model_metadata := metadata.target_model) is None
):
return False

self.state.model.layer_prefix = model_metadata.layer_prefix
return True
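
Usage sketch (mirroring the test added in tests/sparseml/core/lifecycle/test_session.py below): a recipe that carries `metadata.target_model.layer_prefix` now propagates that value onto the session's model during `initialize()`.

```python
# Sketch based on the new test in this PR; assumes a standard sparseml install.
import sparseml.core.session as sml
from sparseml.core import Framework

recipe = """
metadata:
  target_model:
    layer_prefix: "decoder"
    architecture: "opt"

test_stage:
  pruning_modifiers:
    ConstantPruningModifier:
      targets: __ALL_PRUNABLE__
      start: 0
      end: 5
"""

session = sml.active_session()
# identity stand-in model, as in the test; a real model is wrapped the same way
session.initialize(framework=Framework.general, model=lambda x: x, recipe=recipe)
assert session.state.model.layer_prefix == "decoder"
```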
26 changes: 25 additions & 1 deletion src/sparseml/core/model/base.py
@@ -57,13 +57,21 @@ class ModifiableModel(Generic[MT, LT, PT], MultiFrameworkObject):
to be searchable by the MultiFrameworkObject factory method.

:param framework: the framework the model is in
:param layer_prefix: name of model attribute that contains the list of layers, i.e.
model.decoder for OPT or just model for Llama
:param model: the model object
"""

model: MT = None

def __init__(self, framework: Optional[Framework] = None, model=None):
def __init__(
self,
framework: Optional[Framework] = None,
model=None,
layer_prefix: Optional[str] = None,
):
self.model = model
self._layer_prefix = layer_prefix

def get_layers_params(
self, targets: Union[str, List[str]]
@@ -117,6 +125,22 @@ def set_param(self, target: str, param: PT):
"""
raise NotImplementedError()

@property
def layer_prefix(self) -> Optional[str]:
"""
:return: the name of model attribute that contains the list of layers, i.e.
model.decoder for OPT or just model for Llama
"""
return self._layer_prefix

@layer_prefix.setter
def layer_prefix(self, value: Optional[str]):
"""
:param value: the name of model attribute that contains the list of layers, i.e.
model.decoder for OPT or just model for Llama
"""
self._layer_prefix = value

def get_matching_layer(
self, target: str, name_to_match: str, model: LT
) -> Optional[Tuple[str, LT]]:
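
As a hedged illustration of what the prefix means (the helper below is hypothetical and not part of this diff or of sparseml's API), a consumer could resolve the layer container roughly like this:

```python
from typing import Any, Optional


def resolve_layer_container(model: Any, layer_prefix: Optional[str]) -> Any:
    """Hypothetical helper: walk from an HF-style wrapper to the module holding the layers."""
    # e.g. OPTForCausalLM.model.decoder when layer_prefix == "decoder",
    # LlamaForCausalLM.model when layer_prefix is None
    root = getattr(model, "model", model)
    return getattr(root, layer_prefix) if layer_prefix else root
```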
9 changes: 7 additions & 2 deletions src/sparseml/core/model/pytorch.py
@@ -40,12 +40,17 @@ class ModifiableModelPyTorch(ModifiableModel[Module, Module, Parameter]):

:param framework: the framework the model is in
:param model: the model object
:param layer_prefix: name of model attribute that contains the list of layers, i.e.
model.decoder for OPT or just model for Llama
"""

def __init__(
self, framework: Optional[Framework] = None, model: Optional[Module] = None
self,
framework: Optional[Framework] = None,
model: Optional[Module] = None,
layer_prefix: Optional[str] = None,
):
super().__init__(framework=framework, model=model)
super().__init__(framework=framework, model=model, layer_prefix=layer_prefix)

def get_layers_params(
self, targets: Union[str, List[str]]
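
A brief usage sketch of the extended PyTorch wrapper constructor (direct construction and the `torch.nn.Linear` stand-in module are illustrative assumptions, not taken from this diff):

```python
import torch
from sparseml.core import Framework
from sparseml.core.model.pytorch import ModifiableModelPyTorch

wrapped = ModifiableModelPyTorch(
    framework=Framework.pytorch,
    model=torch.nn.Linear(8, 8),  # stand-in module; a real LLM would be used in practice
    layer_prefix="decoder",
)
assert wrapped.layer_prefix == "decoder"
```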
18 changes: 17 additions & 1 deletion src/sparseml/core/recipe/metadata.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field

@@ -69,6 +69,7 @@ class ModelMetaData(BaseModel):
input_shapes: List[List[int]] = None
output_shapes: List[List[int]] = None
layers: List[LayerMetaData] = Field(default_factory=list)
layer_prefix: Optional[str] = None


class RecipeMetaData(BaseModel):
@@ -79,3 +80,18 @@ class RecipeMetaData(BaseModel):
tags: List[str] = None
target_dataset: DatasetMetaData = None
target_model: ModelMetaData = None

def update_missing_metadata(self, other: "RecipeMetaData"):
"""
Update recipe metadata with missing values from another
recipe metadata instance

:param other: the recipe metadata to update with
"""
self.domain = self.domain or other.domain
self.task = self.task or other.task
self.versions = self.versions or other.versions
self.requirements = self.requirements or other.requirements
self.tags = self.tags or other.tags
self.target_dataset = self.target_dataset or other.target_dataset
self.target_model = self.target_model or other.target_model
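
A short usage sketch of `update_missing_metadata`, consistent with the new test in tests/sparseml/core/recipe/test_metadata.py: fields already set on the receiving metadata are kept, and only missing ones are filled from the other instance.

```python
from sparseml.core.recipe.metadata import ModelMetaData, RecipeMetaData

base = RecipeMetaData(domain="cv", task="classification")
other = RecipeMetaData(
    task="segmentation",
    target_model=ModelMetaData(layer_prefix="decoder"),
)

base.update_missing_metadata(other)

assert base.task == "classification"                # existing value is kept
assert base.target_model.layer_prefix == "decoder"  # missing value is filled in
```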
20 changes: 20 additions & 0 deletions src/sparseml/core/recipe/recipe.py
@@ -135,6 +135,9 @@ def simplify_recipe(
simplified.args = RecipeArgs(args)
simplified.stages = stages
simplified.evaluate(args=args, shift=shift)
simplified.metadata = (
recipe.metadata if isinstance(recipe, Recipe) else recipe.recipe.metadata
)

return simplified

@@ -185,6 +188,7 @@ def simplify_combine_recipes(
combined.version = simplified.version
combined.stages.extend(simplified.stages)
combined.args.update(simplified.args)
combined.combine_metadata(simplified.metadata)

return combined

@@ -388,6 +392,22 @@ def extract_dict_stages(values: Dict[str, Any]) -> List[Dict[str, Any]]:

return stages

def combine_metadata(self, metadata: Optional[RecipeMetaData]):
"""
Combines the metadata of the recipe with the supplied metadata
If the recipe already has metadata, the supplied metadata will
be used to update missing metadata

:param metadata: The metadata to combine with the recipe
"""
if metadata is None:
return

if self.metadata is None:
self.metadata = metadata
else:
self.metadata.update_missing_metadata(metadata)

def dict(self, *args, **kwargs) -> Dict[str, Any]:
"""
>>> recipe_str = '''
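
A hedged sketch of `combine_metadata` (assuming `Recipe()` constructs with default fields, as suggested by the pydantic defaults in this module): metadata already present on the recipe wins, and later recipes only fill in what is still missing.

```python
from sparseml.core.recipe.metadata import ModelMetaData, RecipeMetaData
from sparseml.core.recipe.recipe import Recipe

recipe = Recipe()              # assumption: default-constructible with metadata=None
recipe.combine_metadata(None)  # no-op when there is nothing to merge

recipe.combine_metadata(RecipeMetaData(target_model=ModelMetaData(layer_prefix="decoder")))
assert recipe.metadata.target_model.layer_prefix == "decoder"

# a later combine only fills fields that are still missing
recipe.combine_metadata(RecipeMetaData(target_model=ModelMetaData(layer_prefix="model")))
assert recipe.metadata.target_model.layer_prefix == "decoder"
```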
2 changes: 0 additions & 2 deletions src/sparseml/modifiers/obcq/base.py
@@ -49,8 +49,6 @@ class SparseGPTModifier(Modifier):
:param targets: list of layer names to compress during OBCQ, or '__ALL__'
to compress every layer in the model
:param target_ids: list of keys in model output to cache
:param layer_prefix: name of model attribute that contains the list of layers, i.e.
model.decoder for OPT or just model for Llama
"""

sparsity: Union[float, List[float]]
13 changes: 9 additions & 4 deletions src/sparseml/modifiers/obcq/pytorch.py
@@ -47,6 +47,7 @@ class SparseGPTModifierPyTorch(SparseGPTModifier):
model: Any = None
device_: str = "cuda:0"
finalization_kwargs_: Dict = None
layer_prefix_: Optional[str] = None

def on_initialize(self, state: "State", **kwargs) -> bool:
"""
@@ -85,6 +86,7 @@ def initialize_obcq(
"""
self.model = model
self.compressible_layers_ = self.compressible_layers()
self.layer_prefix_ = model.layer_prefix
self.model = self.model.model
self._set_device(device)

@@ -106,7 +108,7 @@ def apply_obcq(
extras = self.compress_bottom(
dev=self.device_,
target_ids=self.target_ids,
layer_prefix=self.layer_prefix,
layer_prefix=self.layer_prefix_,
**accum_kwargs,
)
accum_kwargs.update(extras)
@@ -166,17 +168,20 @@ def compress_bottom(
nsamples: int = None,
dev: str = "cuda:0",
target_ids: List[str] = None,
layer_prefix: str = None,
layer_prefix: Optional[str] = None,
) -> Dict:
"""
Runs calibration data through the bottom part of the network (everything up
to the first decoder layer) and returns the captured outputs

:param dataloader: calibration data to pass through the model
:nsamples: number of samples to use for calibration, or None to use it all
:dev: device to use
:param nsamples: number of samples to use for calibration, or None to use it all
:param dev: device to use
:param layer_prefix: name of model attribute that contains the list of layers,
i.e. model.decoder for OPT or just model for Llama
:return: outputs from bottom part of network, attention mask, and kv-cache state
"""
layer_prefix = layer_prefix or self.layer_prefix_
cached_inputs = cache_attention_inputs(
self.model, dataloader, dev, nsamples, target_ids, layer_prefix
)
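
The fallback added in `compress_bottom` means an explicit `layer_prefix` argument still takes precedence, with the model-level value (captured as `layer_prefix_` in `initialize_obcq`) used otherwise; a minimal, self-contained illustration of that precedence:

```python
from typing import Optional


def resolve_prefix(explicit: Optional[str], model_level: Optional[str]) -> Optional[str]:
    # mirrors `layer_prefix = layer_prefix or self.layer_prefix_` in compress_bottom
    return explicit or model_level


assert resolve_prefix("decoder", None) == "decoder"  # caller override wins
assert resolve_prefix(None, "decoder") == "decoder"  # falls back to the model-level prefix
assert resolve_prefix(None, None) is None            # no prefix at all (Llama-style models)
```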
6 changes: 5 additions & 1 deletion src/sparseml/transformers/sparsification/obcq/example.yaml
@@ -1,3 +1,8 @@
metadata:
target_model:
layer_prefix: "decoder"
architecture: "opt"

test_stage:
obcq_modifiers:
SmoothQuantModifier:
@@ -52,4 +57,3 @@ test_stage:
"model.decoder.layers.23"
]
target_ids: ["attention_mask"]
layer_prefix: "decoder"
53 changes: 53 additions & 0 deletions tests/sparseml/core/lifecycle/test_session.py
@@ -17,6 +17,7 @@

import pytest

import sparseml.core.session as sml
from sparseml.core import Framework
from sparseml.core.event import Event, EventType
from sparseml.core.lifecycle.event import CallbacksEventLifecycle
@@ -25,6 +26,58 @@
from sparseml.core.state import State


def recipe_with_layer_prefix():
layer_prefix = "decoder"
recipe = f"""
metadata:
target_model:
layer_prefix: {layer_prefix}
architecture: "opt"

test_stage:
pruning_modifiers:
ConstantPruningModifier:
targets: __ALL_PRUNABLE__
start: 0
end: 5
"""
return recipe, layer_prefix


def recipe_without_layer_prefix():
recipe = """
test_stage:
pruning_modifiers:
ConstantPruningModifier:
targets: __ALL_PRUNABLE__
start: 0
end: 5
"""
return recipe, None


@pytest.fixture
def model():
# identity model
return lambda x: x


@pytest.mark.parametrize(
"recipe, expected_layer_prefix",
[
recipe_without_layer_prefix(),
recipe_with_layer_prefix(),
],
)
def test_session_initialize_propagates_layer_prefix_to_model(
recipe, expected_layer_prefix, model
):
session = sml.active_session()
session.initialize(framework=Framework.general, model=model, recipe=recipe)
print(f"{session.state.model.layer_prefix=}, {expected_layer_prefix=}")
assert session.state.model.layer_prefix == expected_layer_prefix


class ModifierMock(ModifierInterface):
initialized_ = False

55 changes: 55 additions & 0 deletions tests/sparseml/core/recipe/test_metadata.py
@@ -0,0 +1,55 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import pytest

from sparseml.core.recipe.metadata import ModelMetaData, RecipeMetaData


class TestRecipeMetaData:
@pytest.mark.parametrize(
"self_metadata",
[
dict(domain="cv", task="classification"),
dict(),
],
)
@pytest.mark.parametrize(
"other_metadata",
[
dict(domain="domain", task="segmentation", requirements=["torch>=1.6.0"]),
dict(
domain="cv",
task="task",
target_model=ModelMetaData(layer_prefix="something"),
),
],
)
def test_update_missing_metadata(self, self_metadata, other_metadata):

metadata_a = RecipeMetaData(**self_metadata)
metadata_b = RecipeMetaData(**other_metadata)

metadata_a.update_missing_metadata(metadata_b)

all_keys = set(self_metadata.keys()).union(other_metadata.keys())

# keys should not be overwritten
# if they already exist
for key in all_keys:
if key in self_metadata:
assert getattr(metadata_a, key) == self_metadata[key]
elif key in other_metadata:
assert getattr(metadata_a, key) == other_metadata[key]