Merge branch 'main' into attach-layer-prefix-to-the-model

rahul-tuli authored Oct 27, 2023
2 parents 2c8f8a8 + e20927f commit 19188fb
Showing 10 changed files with 344 additions and 70 deletions.
10 changes: 10 additions & 0 deletions src/sparseml/core/model/base.py
@@ -125,6 +125,7 @@ def set_param(self, target: str, param: PT):
"""
raise NotImplementedError()


@property
def layer_prefix(self) -> Optional[str]:
"""
@@ -140,3 +141,12 @@ def layer_prefix(self, value: Optional[str]):
model.decoder for OPT or just model for Llama
"""
self._layer_prefix = value

def qat_active(self) -> bool:
"""
Checks if quantization aware training is set up in the model
:return: True if QAT is active in any layer, False otherwise
"""
raise NotImplementedError()

9 changes: 9 additions & 0 deletions src/sparseml/core/model/pytorch.py
@@ -24,6 +24,7 @@
get_layers_params,
get_param,
get_params,
qat_active,
set_layer,
set_param,
)
@@ -99,3 +100,11 @@ def set_param(self, target: str, param: Parameter):
:param param: the parameter to set
"""
return set_param(target, param, self.model)

def qat_active(self) -> bool:
"""
Checks if quantization aware training is set up in the model
:return: True if QAT is active in any layer, False otherwise
"""
return qat_active(self.model)
112 changes: 105 additions & 7 deletions src/sparseml/modifiers/obcq/base.py
@@ -12,16 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import List, Optional, Union
import logging
from typing import Any, Dict, List, Optional, Union

from sparseml.core import Modifier
from sparseml.core.factory import ModifierFactory
from sparseml.core.state import State
from sparseml.utils import ALL_TOKEN


__all__ = ["SparseGPTModifier"]

_LOGGER = logging.getLogger(__name__)


class SparseGPTModifier(Modifier):
"""
@@ -34,7 +37,9 @@ class SparseGPTModifier(Modifier):
:param sparsity: Sparsity to compress model to
:param block_size: Used to determine number of columns to compress in one pass
:param quantize: Whether or not model is quantized (affects layer names)
:param quantize: Whether or not to quantize weights during SparseGPT. Set to True
to quantize using an existing quantization modifier, or pass in the configuration
for a quantization modifier if one does not already exist in the recipe
:param dampening_frac: Amount of dampening to apply to H, as a fraction of the
diagonal norm
:param sequential_update: Whether or not to update weights sequentially by layer,
@@ -46,15 +51,108 @@ class SparseGPTModifier(Modifier):
:param target_ids: list of keys in model output to cache
"""

sparsity: float
sparsity: Union[float, List[float]]
block_size: int
quantize: bool
quantize: Union[bool, Dict]
dampening_frac: Optional[float] = 0.01
sequential_update: Optional[bool] = True
prunen: Optional[int] = 0
prunem: Optional[int] = 0
targets: Union[str, List[str], None] = ALL_TOKEN
target_ids: Optional[List[str]] = None
layer_prefix: Optional[str] = None
compressible_layers_: List = None
quantization_modifier_: Any = None

def compressible_layers(self) -> List:
"""
Retrieves the modules corresponding to a list of compressible layer names
:return: list of Pytorch modules to compress
"""
compressible_dict = self.model.get_layers(self.targets)
return [v for _, v in compressible_dict.items()]

def on_initialize_structure(self, state: State, **kwargs):
quantization_already_active = state.model.qat_active()
if isinstance(self.quantize, bool):
if not self.quantize and quantization_already_active:
_LOGGER.warning(
"SparseGPT quantization is set to False, but a "
"quantization modifier is already active on the model "
"resetting quantize to True"
)
self.quantize = True
elif self.quantize and not quantization_already_active:
_LOGGER.warning(
"SparseGPT quantization is set to True without an "
"active quantization modifier. Creating a default "
"8-bit quantization modifier"
)
default_quant_config = {"QuantizationModifier": {}}
self._build_quant_modifier_from_dict(
default_quant_config, state.framework
)
return # use existing quantization modifier if there is one
else:
if not isinstance(self.quantize, Dict):
raise ValueError(
"SparseGPTModifier.quantize accepts only a single "
"quantization modifier or a boolean. Found "
f"type {type(self.quantize)}"
)
if len(self.quantize) != 1:
raise ValueError(
"SparseGPTModifier.quantize accepts only a single "
"quantization modifier or a boolean. Found "
f"{len(self.quantize)} modifiers"
)
if quantization_already_active:
_LOGGER.warning(
"Attempting to initialize quantization for SparseGPT "
"but a quantization modifier has already been applied. "
"The quantization configuration defined under the "
"SparseGPT modifier will be ignored."
)
self.quantize = True
return
self._build_quant_modifier_from_dict(self.quantize, state.framework)
self.quantize = True

if self.quantization_modifier_:
self.quantization_modifier_.on_initialize_structure(state, **kwargs)

def _build_quant_modifier_from_dict(self, quant_config, framework):
modifier_type = list(quant_config.keys())[0]
modifier_args = quant_config[modifier_type]
self.quantization_modifier_ = ModifierFactory.create(
modifier_type,
framework=framework,
allow_registered=True,
allow_experimental=True,
**modifier_args,
)

def _validate_layerwise_sparisity(self):
if isinstance(self.sparsity, float):
return # single sparsity will be applied to all layers

if not isinstance(self.targets, List):
raise ValueError(
"Layer targets must be a list when specifying layer-wise"
f" sparsity. Got {self.targets}"
)

if len(self.targets) != len(self.sparsity):
raise ValueError(
"Number of layer targets must match the number of "
f"sparsities. Got {len(self.targets)} layers and "
f"{len(self.sparsity)} sparsities"
)

def on_initialize_structure(self, state: "State", **kwargs):
pass # nothing needed for this modifier
for layer_name in self.targets:
if layer_name.startswith("re:"):
raise ValueError(
"Using regular expressions for layer-wise sparsity "
f"profiles is not permitted. Found {layer_name}"
)
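
The branching in on_initialize_structure can be summarized compactly. Below is an illustrative, simplified restatement of how the quantize field is resolved against the model's current QAT state; it is a sketch only, not the actual implementation (which also constructs and stores the quantization modifier):

    # Simplified restatement of the quantize-resolution logic above.
    from typing import Dict, Union

    def resolve_quantize(quantize: Union[bool, Dict], qat_already_active: bool) -> str:
        if isinstance(quantize, bool):
            if not quantize and qat_already_active:
                return "force quantize=True and reuse the active quantization modifier"
            if quantize and not qat_already_active:
                return "build a default 8-bit QuantizationModifier"
            return "use the existing quantization modifier, if any"
        if not isinstance(quantize, dict) or len(quantize) != 1:
            raise ValueError("quantize must be a bool or a single-modifier config dict")
        if qat_already_active:
            return "ignore the provided config and reuse the active quantization modifier"
        return "build the quantization modifier from the provided config"

    print(resolve_quantize(True, qat_already_active=False))
    # -> build a default 8-bit QuantizationModifier
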
32 changes: 18 additions & 14 deletions src/sparseml/modifiers/obcq/pytorch.py
@@ -17,7 +17,6 @@
from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
from torch.nn import Module

from sparseml.core.model import ModifiableModel
from sparseml.core.state import State
@@ -46,26 +45,22 @@ class SparseGPTModifierPyTorch(SparseGPTModifier):
"""

model: Any = None
compressible_layers_: List = None
device_: str = "cuda:0"
finalization_kwargs_: Dict = None
layer_prefix_: Optional[str] = None

def compressible_layers(self) -> List[Module]:
"""
Retrieves the modules corresponding to a list of compressible layer names
:return: list of Pytorch modules to compress
"""
compressible_dict = self.model.get_layers(self.targets)
return [v for _, v in compressible_dict.items()]

def on_initialize(self, state: "State", **kwargs) -> bool:
"""
Initialize and run the OBCQ algorithm on the current state
:param state: session state storing input model and calibration data
"""
self._validate_layerwise_sparisity()

if not self.initialized_structure_:
self.on_initialize_structure(state, **kwargs)
if self.quantization_modifier_:
self.quantization_modifier_.initialize(state, **kwargs)
self.finalization_kwargs_ = {}
modifiable_model = state.model
calibration_dataloader = state.data.calib
@@ -127,10 +122,17 @@ def apply_obcq(
"The 'outputs' key is expected but not found from the "
"return of the bottom compressor"
)

inputs = accum_kwargs["outputs"]
_LOGGER.info(f"\n===== Compressing layer {idx}/{num_layers-1} =====")
layer_sparsity = (
self.sparsity[idx] if isinstance(self.sparsity, List) else self.sparsity
)
_LOGGER.info(
f"\n===== Compressing layer {idx+1}/{num_layers} "
f"to sparsity {layer_sparsity} ====="
)
args = {
"sparsity": self.sparsity,
"sparsity": layer_sparsity,
"prunen": self.prunen,
"prunem": self.prunem,
"blocksize": self.block_size,
@@ -153,9 +155,11 @@ def on_finalize(self, state: "State", **kwargs) -> bool:
:param state: un-used, for matching spec of Modifier base class
"""
use_cache = self.finalization_kwargs_.get("use_cache", False)
self.model.apply(torch.quantization.disable_observer)
self.model.config.use_cache = use_cache

if self.quantization_modifier_:
self.quantization_modifier_.finalize(state, **kwargs)

return True

def compress_bottom(
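
As a small illustration of the per-layer sparsity selection introduced in apply_obcq above (hypothetical values; sparsity may be a single float or one value per compressed layer):

    # Mirrors the selection logic above: index into the list when layer-wise
    # sparsities are given, otherwise reuse the single global value.
    from typing import List, Union

    def pick_layer_sparsity(sparsity: Union[float, List[float]], idx: int) -> float:
        return sparsity[idx] if isinstance(sparsity, list) else sparsity

    num_layers = 3
    sparsities = [0.3, 0.5, 0.7]
    for idx in range(num_layers):
        print(f"===== Compressing layer {idx + 1}/{num_layers} "
              f"to sparsity {pick_layer_sparsity(sparsities, idx)} =====")
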
39 changes: 3 additions & 36 deletions src/sparseml/modifiers/obcq/utils/layer_compressor.py
@@ -17,11 +17,11 @@
from typing import Dict, List

import torch
import torch.nn as nn
from torch.nn import Module

from sparseml.modifiers.obcq.utils.sparsegpt import SparseGPT
from sparseml.pytorch.utils.helpers import get_dependency_order
from sparseml.utils.pytorch.module import get_prunable_layers


_LOGGER = logging.getLogger(__name__)
@@ -60,14 +60,8 @@ def compressible_modules(self) -> Dict:
:return: dictionary of compressible modules
"""
quantize = self.args.get("quantize", False)
if quantize:
# The layer names are changed due to quantization modifiers, therefore
# we need a slightly different func to retrieve layers
modules = _find_quant_layers(self.layer)
else:
modules = _find_layers(self.layer)
return modules
compressible_layers = get_prunable_layers(self.layer)
return compressible_layers

def pre_compress_parallel(self, **kwargs) -> Dict:
"""
@@ -217,30 +211,3 @@ def tmp(_, inp, out):
blocksize=self.args["blocksize"],
)
gpts.free()


def _find_quant_layers(
module, layers=[torch.nn.qat.Conv2d, torch.nn.qat.Linear], name=""
):
res = {}
# search for QAT versions of layers
for name1, child in module.named_children():
res.update(
_find_layers(
child, layers=layers, name=name + "." + name1 if name != "" else name1
)
)
return res


def _find_layers(module, layers=[nn.Conv2d, nn.Linear], name=""):
if type(module) in layers:
return {name: module}
res = {}
for name1, child in module.named_children():
res.update(
_find_layers(
child, layers=layers, name=name + "." + name1 if name != "" else name1
)
)
return res
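
The removed _find_layers / _find_quant_layers helpers are replaced by the shared get_prunable_layers utility. A minimal usage sketch, assuming it returns the same kind of name-to-module mapping the removed helpers produced (covering both float and QAT Linear/Conv variants, so the quantize flag no longer needs a separate lookup path):

    import torch.nn as nn

    from sparseml.utils.pytorch.module import get_prunable_layers

    layer = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 16))
    compressible = get_prunable_layers(layer)
    for name, module in compressible.items():
        print(name, type(module).__name__)  # e.g. "0 Linear", "2 Linear"
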
22 changes: 11 additions & 11 deletions src/sparseml/transformers/sparsification/obcq/example.yaml
@@ -5,21 +5,21 @@ metadata:

test_stage:
obcq_modifiers:
QuantizationModifier:
ignore: ["lm_head", "Embedding", "OPTLearnedPositionalEmbedding", "QuantizableBatchMatMul", "BMMLeftInput_QK", "BMMRightInput_QK", "BMMOutput_QK", "BMMLeftInput_PV", "BMMRightInput_PV", "BMMOutput_PV"]
post_oneshot_calibration: True
scheme_overrides:
ReLU:
input_activations: null
output_activations: null
LayerNorm:
input_activations: null
output_activations: null
SparseGPTModifier:
sparsity: 0.5
block_size: 128
sequential_update: False
quantize: True
quantize:
QuantizationModifier:
ignore: ["lm_head", "Embedding", "OPTLearnedPositionalEmbedding", "QuantizableBatchMatMul", "BMMLeftInput_QK", "BMMRightInput_QK", "BMMOutput_QK", "BMMLeftInput_PV", "BMMRightInput_PV", "BMMOutput_PV"]
post_oneshot_calibration: True
scheme_overrides:
ReLU:
input_activations: null
output_activations: null
LayerNorm:
input_activations: null
output_activations: null
percdamp: 0.01
prunen: 0
prunem: 0
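
For reference, this is roughly the shape of the dict that reaches _build_quant_modifier_from_dict once the recipe above is parsed; a hedged sketch (recipe parsing details are assumed, and the ignore list is truncated for brevity):

    # Assumed shape of SparseGPTModifier.quantize after loading the recipe above:
    quantize_config = {
        "QuantizationModifier": {
            "ignore": ["lm_head", "Embedding"],  # truncated from the YAML above
            "post_oneshot_calibration": True,
            "scheme_overrides": {
                "ReLU": {"input_activations": None, "output_activations": None},
                "LayerNorm": {"input_activations": None, "output_activations": None},
            },
        }
    }

    # _build_quant_modifier_from_dict then splits this into a type and its args:
    modifier_type = list(quantize_config.keys())[0]  # "QuantizationModifier"
    modifier_args = quantize_config[modifier_type]   # the nested settings
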
16 changes: 16 additions & 0 deletions src/sparseml/utils/pytorch/module.py
@@ -65,6 +65,7 @@
"get_terminal_layers",
"get_prunable_layers",
"get_quantizable_layers",
"qat_active",
"get_layers_params",
]

@@ -241,6 +242,21 @@ def get_quantizable_layers(module: Module) -> Dict[str, Module]:
return quantizable


def qat_active(module: Module) -> bool:
"""
Determines if any layers in the model have quantization enabled by checking for
weight_fake_quant attributes
:param module: PyTorch model to check for quantization
:return: True if quantization is active anywhere in the model, False otherwise
"""
for _, layer in module.named_modules():
if isinstance(layer, torch.quantization.FakeQuantize):
return True

return False


def get_layers_params(
targets: Union[str, List[str]], module: Module
) -> Dict[str, ModelParameterizedLayer[Parameter, Module]]:
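
A minimal sketch of the new qat_active helper in use, assuming a torch version that exposes the legacy torch.quantization namespace referenced elsewhere in this diff:

    import torch
    import torch.nn as nn

    from sparseml.utils.pytorch.module import qat_active

    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
    assert not qat_active(model)  # no FakeQuantize modules attached yet

    model.train()
    model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
    torch.quantization.prepare_qat(model, inplace=True)
    assert qat_active(model)  # weight/activation fake-quant observers now present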