Split WandaPruningModifier and SparseGPTModifier
Make sparsegpt not inherit from wanda modifier
Decouple SparseGPTModifierPyTorch from WandaPruningModifier
Fix docstrings
rahul-tuli committed May 17, 2024
1 parent d7db1e3 commit 4230fb7
Showing 4 changed files with 287 additions and 14 deletions.
69 changes: 65 additions & 4 deletions src/sparseml/modifiers/obcq/base.py
@@ -15,17 +15,18 @@
import logging
from typing import Any, Dict, List, Optional, Union

from sparseml.core import Modifier
from sparseml.core.factory import ModifierFactory
from sparseml.core.model.base import ModifiableModel
from sparseml.core.state import State
from sparseml.modifiers.pruning.wanda.base import WandaPruningModifier


__all__ = ["SparseGPTModifier"]

_LOGGER = logging.getLogger(__name__)


class SparseGPTModifier(WandaPruningModifier):
class SparseGPTModifier(Modifier):
"""
Modifier for applying the one-shot OBCQ algorithm to a model
@@ -41,19 +42,35 @@ class SparseGPTModifier(WandaPruningModifier):
- on_finalize
- LayerCompressor.revert_layer_wrappers()
:param sparsity: Sparsity to compress model to
:param mask_structure: String to define the structure of the mask to apply.
Must be of the form N:M where N, M are integers that define a custom block
shape. Defaults to 0:0 which represents an unstructured mask.
:param sequential_update: Whether or not to update weights sequentially by layer;
True saves on GPU memory
:param targets: list of layer names to compress during OBCQ, or '__ALL__'
to compress every layer in the model
:param block_size: Used to determine number of columns to compress in one pass
:param quantize: Whether or not to quantize weights during SparseGPT. Set to
True to quantize using an existing quantization modifier, or pass in the
configuration for a quantization modifier if one does not already exist
in the recipe
:param sparsity: Sparsity to compress model to
:param dampening_frac: Amount of dampening to apply to H, as a fraction of the
diagonal norm
"""

sparsity: Union[float, List[float]] = 0.0
sparsity_profile: Optional[str] = None
owl_m: Optional[int] = None
owl_lmbda: Optional[float] = None
mask_structure: str = "0:0"
sequential_update: Optional[bool] = False
targets: Union[str, List[str], None] = None
compressible_layers_: Optional[List] = None
prunen_: Optional[int] = None
prunem_: Optional[int] = None
block_size: int = 128
quantize: Union[bool, Dict] = False
sparsity: Union[float, List[float]] = 0.0
dampening_frac: Optional[float] = 0.01
quantization_modifier_: Any = None

@@ -112,6 +129,39 @@ def on_initialize_structure(self, state: State, **kwargs):
if self.quantization_modifier_:
self.quantization_modifier_.on_initialize_structure(state, **kwargs)

def compressible_layers(self) -> Dict:
"""
Retrieves the modules corresponding to a list of
compressible layer names
:precondition: self.model is set and is a `ModifiableModel`
:precondition: The `ModifiableModel` implements a `get_layers`
method
:return: dictionary of modules to compress
"""
if not isinstance(self.model, ModifiableModel):
raise ValueError(
"`self.model` must be a ModifiableModel to use "
f"the {self.__class__.__qualname__} modifier but got "
f"{type(self.model)} instead"
)

return self.model.get_layers(self.targets)

def _validate_layerwise_sparsity(self):
if isinstance(self.sparsity, float):
# single sparsity will be applied to all layers
return

target_layers = list(self.compressible_layers_.keys())

if len(target_layers) != len(self.sparsity):
raise ValueError(
"Number of layer targets must match the number of "
f"sparsities. Got {len(target_layers)} layers and "
f"{len(self.sparsity)} sparsities"
)
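# Illustrative sketch (hypothetical layer names and values, not from this diff):
# a list-valued `sparsity` is matched positionally against the targeted layers,
# so the two sequences must have the same length.
example_targets = ["model.layers.0", "model.layers.1", "model.layers.2"]
example_sparsity = [0.4, 0.5, 0.6]  # one entry per targeted layer
assert len(example_sparsity) == len(example_targets)  # mirrors _validate_layerwise_sparsity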

def _build_quant_modifier_from_dict(self, quant_config, framework):
modifier_type = list(quant_config.keys())[0]
modifier_args = quant_config[modifier_type]
@@ -122,3 +172,14 @@ def _build_quant_modifier_from_dict(self, quant_config, framework):
allow_experimental=True,
**modifier_args,
)

def on_finalize(self, state: State, **kwargs):
"""
Nothing to do on finalize at this level; the quantization modifier, if any,
is finalized in the subclass
:param state: session state storing input model and calibration data
:param kwargs: additional arguments
:return: True
"""
return True
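
For orientation, the sketch below shows the dictionary shape that `_build_quant_modifier_from_dict` unpacks when `quantize` is supplied as a config rather than a bool: a single-key mapping from a modifier type name to its arguments. The modifier name and argument values here are illustrative assumptions, not taken from this commit.

# Hypothetical quantize config: one top-level key naming the modifier type,
# whose value holds the keyword arguments forwarded to ModifierFactory.create.
quant_config = {
    "QuantizationModifier": {
        "ignore": ["lm_head"],  # illustrative arguments only
    }
}

modifier_type = list(quant_config.keys())[0]  # "QuantizationModifier"
modifier_args = quant_config[modifier_type]   # {"ignore": ["lm_head"]}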
216 changes: 209 additions & 7 deletions src/sparseml/modifiers/obcq/pytorch.py
@@ -13,21 +13,27 @@
# limitations under the License.

import logging
from typing import List, Optional
from typing import Any, Dict, Iterable, List, Optional, Tuple

import numpy as np
import torch
from tqdm import tqdm

from sparseml.core.model import ModifiableModel
from sparseml.core.state import State
from sparseml.modifiers.obcq.base import SparseGPTModifier
from sparseml.modifiers.obcq.utils.sgpt_wrapper import SparseGptWrapper
from sparseml.modifiers.pruning.wanda.pytorch import WandaPruningModifierPyTorch
from sparseml.modifiers.utils.layer_compressor import LayerCompressor
from sparseml.modifiers.utils.pytorch_helpers import run_calibration_forward
from sparseml.utils.pytorch.module import get_prunable_layers


__all__ = ["SparseGPTModifierPyTorch"]

_LOGGER = logging.getLogger(__name__)


class SparseGPTModifierPyTorch(WandaPruningModifierPyTorch, SparseGPTModifier):
class SparseGPTModifierPyTorch(SparseGPTModifier):
"""
Pytorch implementation of SparseGPT
@@ -40,14 +46,25 @@ class SparseGPTModifierPyTorch(WandaPruningModifierPyTorch, SparseGPTModifier):
- run_calibration_forward()
- LayerCompressor.compress()
- LayerCompressor.post_compress()
- on_finalize
- LayerCompressor.revert_layer_wrappers()
- LayerCompressor.revert_layer_wrappers()
| Sample yaml:
| test_stage:
| obcq_modifiers:
| SparseGPTModifier:
| sparsity: 0.5
| mask_structure: "2:4"
| sequential_update: True
| dampening_frac: 0.001
| targets: __ALL__
| block_size: 128
| quantize: False
:param model: Pytorch model to perform OBCQ on, in-place
"""

model: Optional[ModifiableModel] = None
layer_compressors: List = None
layer_compressors_: Optional[List[Any]] = None

def on_initialize(self, state: "State", **kwargs) -> bool:
"""
@@ -65,7 +82,99 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
"quantization must be enabled."
)

return super(SparseGPTModifierPyTorch, self).on_initialize(state, **kwargs)
modifiable_model = state.model
calibration_dataloader = state.data.calib

if self.targets is None:
# if no targets are provided, default to the modules that shouldn't be
# split by FSDP. For Transformers models this is equivalent to the
# decoder layers (i.e. LlamaDecoderLayer)
self.targets = modifiable_model.get_no_split_params()

self.initialize_compression(modifiable_model, calibration_dataloader)
self.apply_compression(calibration_dataloader)

return True

def initialize_compression(
self,
model: ModifiableModel,
dataloader: Optional[Iterable[Tuple[List, Dict[str, Any]]]] = None,
):
"""
Setup for SparseGPT compression: initializes the model and the compressible
layers, infers per-layer sparsities when a sparsity profile (e.g. OWL) is
configured, and prepares a LayerCompressor for each targeted layer
:param model: model to initialize for compression
:param dataloader: calibration data, used when inferring layer-wise sparsities
"""
self.model = model
self.compressible_layers_ = self.compressible_layers()
self.model = self.model.model
self.layer_compressors_ = []
self._infer_mask_block_size()

if self.sparsity_profile is not None and self.sparsity_profile.lower() == "owl":
_LOGGER.info(
"Inferring layer-wise sparsities from "
f"{len(dataloader)} calibration samples..."
)
self.sparsity = self._infer_layer_sparsity(dataloader)
self._validate_layerwise_sparsity()

for idx, (name, layer) in enumerate(self.compressible_layers_.items()):
_LOGGER.info(f"Preparing {name} for compression")
if isinstance(self.sparsity, Dict):
layer_sparsity = self.sparsity[name]
elif isinstance(self.sparsity, List):
layer_sparsity = self.sparsity[idx]
else: # float
layer_sparsity = self.sparsity
args = self._pruning_arguments(layer_sparsity)
comp_cls = self._compression_class()
compressor = LayerCompressor(comp_cls, self.model, layer, idx, name, args)
if not self.sequential_update:
# add all batch processing hooks before the forward pass
compressor.pre_compress()
self.layer_compressors_.append(compressor)

@torch.no_grad()
def apply_compression(
self, dataloader: Optional[Iterable[Tuple[List, Dict[str, Any]]]] = None
) -> Dict:
"""
Run SparseGPT compression on the loaded model, using dataloader as calibration data
:param dataloader: calibration data for compression
"""
class_name = self.__class__.__name__.replace("PyTorch", "")
_LOGGER.info(
f"Running {class_name} calibration with " f"{len(dataloader)} samples..."
)
if not self.sequential_update:
# in non-sequential mode we run one calibration forward pass up front, so
# every wrapped module collects its statistics before any compression
run_calibration_forward(self.model, dataloader, mask_padding=True)

num_layers = len(self.compressible_layers_)
for idx, layer_compressor in enumerate(self.layer_compressors_):
layer_sparsity = layer_compressor.args["sparsity"]
_LOGGER.info(
f"\n===== Compressing layer {idx+1}/{num_layers} "
f"to sparsity {layer_sparsity} ====="
)

# Prune/quantize using SparseGPT
if self.sequential_update:
# in sequential mode we run a calibration forward pass for each module we
# want to compress; this is slower, but lets compression of earlier layers
# affect the inputs seen by later layers
layer_compressor.pre_compress()
_LOGGER.info(f"Calibrating {layer_compressor.name}...")
run_calibration_forward(self.model, dataloader, mask_padding=True)
layer_compressor.compress()
layer_compressor.post_compress()
layer_compressor.revert_layer_wrappers()
torch.cuda.empty_cache()

def on_finalize(self, state: "State", **kwargs) -> bool:
"""
@@ -98,3 +207,96 @@ def _compression_class(self):
:return: wrapper class used for root modules of this compression class
"""
return SparseGptWrapper

def _infer_mask_block_size(self):
"""
Infer the mask block size from the mask structure.
Parses mask_structure of the form N:M where N, M are integers that
define a custom block shape; and sets prunen_ and prunem_ accordingly.
:post-condition: prunen_ and prunem_ are set
"""
if self.mask_structure is None:
raise ValueError("mask_structure must be defined")

self.prunen_, self.prunem_ = list(map(int, self.mask_structure.split(":")))
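
# Minimal standalone sketch of the N:M parsing above: a mask_structure of "2:4"
# yields prunen_ = 2 and prunem_ = 4 (2-in-4 semi-structured sparsity), while the
# default "0:0" leaves both at 0, i.e. an unstructured mask.
example_mask_structure = "2:4"
prunen, prunem = map(int, example_mask_structure.split(":"))
assert (prunen, prunem) == (2, 4)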

def _infer_layer_sparsity(self, calibration_dataloader):
acts = _get_activations(self.model, calibration_dataloader)
sparsegpt_groups = {}
for name, layer in self.compressible_layers_.items():
prunable_layers = get_prunable_layers(layer)
z = [
m.weight.abs() * acts[f"{name}.{n}"].unsqueeze(0)
for n, m in prunable_layers.items()
]
sparsegpt_groups[name] = torch.cat([item.flatten().cpu() for item in z])

acts = None
del acts
torch.cuda.empty_cache()

outlier_ratios = {}
for group in sparsegpt_groups:
threshold = torch.mean(sparsegpt_groups[group]) * self.owl_m
outlier_ratios[group] = (
100
* (sparsegpt_groups[group] > threshold).sum().item()
/ sparsegpt_groups[group].numel()
)
outlier_ratios_arr = np.array([outlier_ratios[k] for k in outlier_ratios])
for k in outlier_ratios:
outlier_ratios[k] = (outlier_ratios[k] - outlier_ratios_arr.min()) * (
1
/ (outlier_ratios_arr.max() - outlier_ratios_arr.min())
* self.owl_lmbda
* 2
)
outlier_ratios_arr = np.array([outlier_ratios[k] for k in outlier_ratios])
sparsities = {
k: 1
- (
outlier_ratios[k]
- np.mean(outlier_ratios_arr)
+ (1 - float(self.sparsity))
)
for k in outlier_ratios
}
_LOGGER.info(f"OWL sparsities for sp={self.sparsity} are:")
for k in sparsities:
_LOGGER.info(f"Sparsity for {k}: {sparsities[k]}")
return sparsities
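
# Standalone sketch of the OWL arithmetic above, with made-up outlier ratios:
# each ratio is min-max scaled into [0, 2 * owl_lmbda], and a layer's sparsity is
# the target sparsity plus the mean of the scaled ratios minus its own scaled
# ratio, so outlier-heavy layers are pruned less while the average stays on target.
import numpy as np

example_ratios = {"layer.0": 5.0, "layer.1": 1.0, "layer.2": 3.0}  # hypothetical
owl_lmbda, target_sparsity = 0.08, 0.5

arr = np.array(list(example_ratios.values()))
scaled = {
    k: (v - arr.min()) / (arr.max() - arr.min()) * owl_lmbda * 2
    for k, v in example_ratios.items()
}
mean_scaled = np.mean(list(scaled.values()))
sparsities = {k: 1 - (v - mean_scaled + (1 - target_sparsity)) for k, v in scaled.items()}
# -> {"layer.0": 0.42, "layer.1": 0.58, "layer.2": 0.50}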


@torch.no_grad()
def _get_activations(model, data_loader, nsamples=128):
import functools

model.eval()
acts = {}

def save_acts(module, input, name):
if isinstance(input, tuple):
input = input[0]
if name not in acts:
acts[name] = 1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt()
else:
acts[name] += 1.0 / nsamples * input.detach().pow(2).sum(dim=(0, 1)).sqrt()

hooks = []
for name, mod in model.named_modules():
if isinstance(mod, torch.nn.Linear) and "lm_head" not in name:
hooks.append(
mod.register_forward_pre_hook(functools.partial(save_acts, name=name))
)
device = next(model.parameters()).device
for batch in tqdm(data_loader):
batch = {k: v.to(device) for k, v in batch.items()}
model(**batch)
batch = None
torch.cuda.empty_cache()

for h in hooks:
h.remove()

return acts
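
As a usage sketch of the helper above (the toy model and random batches are invented for illustration), `_get_activations` registers a pre-forward hook on every non-`lm_head` Linear layer and returns, per layer name, an L2-norm-style statistic of that layer's inputs per input feature, accumulated over the calibration batches, similar in spirit to the activation norms used by Wanda:

import torch

class ToyModel(torch.nn.Module):
    # Tiny stand-in; real usage passes a Transformers model and tokenized batches.
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, input_ids):
        return self.proj(input_ids)

model = ToyModel()
# Each batch is a dict of tensors, mirroring a tokenized calibration dataloader.
data_loader = [{"input_ids": torch.randn(2, 4, 8)} for _ in range(4)]

acts = _get_activations(model, data_loader, nsamples=len(data_loader))
print(acts["proj"].shape)  # torch.Size([8]): one statistic per input feature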
2 changes: 1 addition & 1 deletion src/sparseml/modifiers/pruning/wanda/base.py
@@ -37,8 +37,8 @@ class WandaPruningModifier(Modifier):
- run_calibration_forward()
- LayerCompressor.compress()
- LayerCompressor.post_compress()
- LayerCompressor.revert_layer_wrappers()
- on_finalize
- LayerCompressor.revert_layer_wrappers()
:param sparsity: Sparsity to compress model to
:param mask_structure: String to define the structure of the mask to apply.