diff --git a/README.md b/README.md index b381598c3fa..69c5caf8c3c 100644 --- a/README.md +++ b/README.md @@ -128,7 +128,7 @@ More information on installation such as optional dependencies and requirements ### Recipes -To enable flexibility, ease of use, and repeatability, SparseML uses a declarative interface called `recipes` for specifying the sparsity-related algorithms and hyperparamters that should be applied by SparseML. +To enable flexibility, ease of use, and repeatability, SparseML uses a declarative interface called `recipes` for specifying the sparsity-related algorithms and hyperparameters that should be applied by SparseML. `Recipes` are YAML-files formatted as a list of `modifiers`, which encode the instructions for SparseML. Example `modifiers` can be anything from setting the learning rate to encoding the hyperparameters of the gradual magnitude pruning algorithm. The SparseML system parses the `recipes` into a native format for each framework and applies the modifications to the model and training pipeline. diff --git a/src/sparseml/core/model/base.py b/src/sparseml/core/model/base.py index 7d83ea2aab6..71fbb3d3e24 100644 --- a/src/sparseml/core/model/base.py +++ b/src/sparseml/core/model/base.py @@ -13,7 +13,7 @@ # limitations under the License. from dataclasses import dataclass -from typing import Dict, Generic, List, Optional, TypeVar, Union +from typing import Dict, Generic, List, Optional, Tuple, TypeVar, Union from sparseml.core.framework import Framework from sparseml.core.framework_object import MultiFrameworkObject @@ -125,6 +125,7 @@ def set_param(self, target: str, param: PT): """ raise NotImplementedError() + @property def layer_prefix(self) -> Optional[str]: """ @@ -141,6 +142,16 @@ def layer_prefix(self, value: Optional[str]): """ self._layer_prefix = value + def get_matching_layer( + self, target: str, name_to_match: str, model: LT + ) -> Optional[Tuple[str, LT]]: + """ + :param target: regex layer name to target when searching model + :param name_to_match: name to match targets to + :param model: model to search for targets + """ + raise NotImplementedError() + def qat_active(self) -> bool: """ Checks if quantization aware training is set up in the model diff --git a/src/sparseml/core/model/pytorch.py b/src/sparseml/core/model/pytorch.py index e3bce044748..f3a5701a3fe 100644 --- a/src/sparseml/core/model/pytorch.py +++ b/src/sparseml/core/model/pytorch.py @@ -22,6 +22,7 @@ get_layer, get_layers, get_layers_params, + get_matching_layer, get_param, get_params, qat_active, @@ -101,6 +102,16 @@ def set_param(self, target: str, param: Parameter): """ return set_param(target, param, self.model) + def get_matching_layer( + self, target: str, name_to_match: str, model: Module + ) -> Optional[Tuple[str, Module]]: + """ + :param target: regex layer name to target when searching model + :param name_to_match: name to match targets to + :param model: model to search for targets + """ + return get_matching_layer(target, name_to_match, model) + def qat_active(self) -> bool: """ Checks if quantization aware training is set up in the model diff --git a/src/sparseml/modifiers/__init__.py b/src/sparseml/modifiers/__init__.py index adf250cf344..8fbf0828e64 100644 --- a/src/sparseml/modifiers/__init__.py +++ b/src/sparseml/modifiers/__init__.py @@ -18,3 +18,4 @@ from .obcq import * from .pruning import * from .quantization import * +from .smoothquant import * diff --git a/src/sparseml/modifiers/quantization/pytorch.py b/src/sparseml/modifiers/quantization/pytorch.py index 
dfa352e1378..1219a73156b 100644 --- a/src/sparseml/modifiers/quantization/pytorch.py +++ b/src/sparseml/modifiers/quantization/pytorch.py @@ -13,8 +13,7 @@ # limitations under the License. import logging -from itertools import cycle -from typing import Any, Callable, Dict, Optional +from typing import Any, Dict, Optional import torch from torch.nn import Module @@ -35,7 +34,7 @@ raise_if_torch_quantization_not_available, set_quantization_schemes, ) -from sparseml.pytorch.utils import tensors_module_forward, tensors_to_device +from sparseml.modifiers.utils.pytorch_helpers import run_calibration_forward _LOGGER = logging.getLogger(__name__) @@ -191,26 +190,13 @@ def _calibrate(self, module: Module): module_training = module.training module.eval() - forward_fn: Callable = ( - self.calibration_function_ - if self.calibration_function_ - else tensors_module_forward + run_calibration_forward( + module, + self.calibration_dataloader_, + self.num_calibration_steps, + self.calibration_function_, ) - model_device = next(module.parameters()).device - _dataloader = ( - self.calibration_dataloader_ - if self.num_calibration_steps is None - else cycle(self.calibration_dataloader_) - ) - - for batch_idx, batch in enumerate(_dataloader): - if self.num_calibration_steps and batch_idx >= self.num_calibration_steps: - break - batch = tensors_to_device(batch, model_device) - with torch.no_grad(): - forward_fn(batch, module=module) - if module_training: module.train() else: diff --git a/src/sparseml/modifiers/smoothquant/__init__.py b/src/sparseml/modifiers/smoothquant/__init__.py new file mode 100644 index 00000000000..19953714b09 --- /dev/null +++ b/src/sparseml/modifiers/smoothquant/__init__.py @@ -0,0 +1,17 @@ +# flake8: noqa + +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .base import * diff --git a/src/sparseml/modifiers/smoothquant/base.py b/src/sparseml/modifiers/smoothquant/base.py new file mode 100644 index 00000000000..b5d8adcb65f --- /dev/null +++ b/src/sparseml/modifiers/smoothquant/base.py @@ -0,0 +1,183 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
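+
+# Illustrative note on how a recipe `mappings` entry resolves (the module names
+# below follow the OPT-style example recipe and are hypothetical):
+#   [["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*self_attn_layer_norm"]
+# produces one SmoothQuantMapping per decoder block, e.g. smooth_layer =
+# model.decoder.layers.0.self_attn_layer_norm balanced against that block's
+# q_proj, k_proj, and v_proj projections.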
+
+
+from dataclasses import dataclass
+from typing import Dict, Generic, List, Optional, Tuple, TypeVar
+
+from pydantic import Field
+
+from sparseml.core import Modifier
+from sparseml.core.model import ModifiableModel
+from sparseml.core.model.base import LT
+from sparseml.core.state import Event, State
+
+
+VT = TypeVar("VT")  # represents a generic vector
+
+__all__ = ["SmoothQuantScale", "SmoothQuantMapping", "SmoothQuantModifier"]
+
+
+@dataclass
+class SmoothQuantScale(Generic[VT]):
+    """
+    Dataclass for storing the channel-wise minimum and maximum values for a layer. This
+    is updated each forward pass during calibration
+
+    :param min_channel_vals: minimum output value seen so far, per channel
+    :param max_channel_vals: maximum output value seen so far, per channel
+    """
+
+    min_channel_vals: VT
+    max_channel_vals: VT
+
+
+@dataclass
+class SmoothQuantMapping(Generic[LT]):
+    """
+    Dataclass for storing the mapping between an activation layer and the following
+    weights that must be balanced during smoothing
+
+    :param smooth_name: name of the activation layer
+    :param smooth_layer: PyTorch module storing the activation layer
+    :param balance_layers: list of PyTorch modules that smooth_layer feeds into, must be
+        balanced to offset the smoothing of smooth_layer
+    """
+
+    smooth_name: str
+    smooth_layer: LT
+    balance_layers: List[LT]
+
+
+class SmoothQuantModifier(Modifier):
+    """
+    Implements the SmoothQuant algorithm from https://arxiv.org/abs/2211.10438. This
+    modifier performs a channel-wise smoothing of outliers in activations, making them
+    easier to quantize by reducing the dynamic range. The smoothing is offset by
+    applying the inverse operation to the next layer of weights, making the weights
+    slightly more difficult to quantize.
+
+    Because this modifier manipulates the weights of the model, it can only be used
+    in one-shot and not during training. Activation ranges are determined by running a
+    small set of calibration data through the model.
+
+    example recipe:
+    ```yaml
+    SmoothQuantModifier:
+      smoothing_strength: 0.5
+      mappings: [
+        [["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*self_attn_layer_norm"],
+        [["re:.*fc1"], "re:.*final_layer_norm"]
+      ]
+      ignore: ["model.decoder.final_layer_norm"]
+    ```
+
+    :param smoothing_strength: alpha, intensity of smoothing to perform (0-1 range)
+    :param mappings: list of activation layers to smooth, and which layers to offset
+        the smoothing to for each activation
+    :param ignore: list of layers to ignore, even if they match a regex in mappings
+    :param num_calibration_steps: number of samples to use for calibration, or None to
+        use the whole dataset
+    """
+
+    smoothing_strength: float = Field(validation_alias="alpha")
+    mappings: List[Tuple]
+    ignore: Optional[List[str]] = None
+    num_calibration_steps: Optional[int] = None
+
+    resolved_mappings_: Optional[List] = None
+    scales_: Optional[Dict] = None
+
+    def on_initialize_structure(self, state: State, **kwargs):
+        pass  # nothing needed for this modifier
+
+    def on_initialize(self, state: State, **kwargs) -> bool:
+        """
+        Initialize and run SmoothQuant on the given state
+
+        :param state: state to run SmoothQuant on
+        :return: True on a successful run, False otherwise
+        """
+        if self.end and self.end != -1:
+            raise ValueError(
+                "SmoothQuantModifier can only be applied during one-shot. Expected end"
+                " to be None or -1, got {}".format(self.end)
+            )
+        if self.start and self.start != -1:
+            raise ValueError(
+                "SmoothQuantModifier can only be applied during one-shot. Expected "
+                "start to be None or -1, got {}".format(self.start)
+            )
+
+        self.ignore = [] if not self.ignore else self.ignore
+        self.resolved_mappings_ = self._resolve_mappings(state.model)
+        self.scales_ = {}
+
+    def _resolve_mappings(self, model: ModifiableModel) -> List:
+        """
+        Transforms the list of activations to smooth and their corresponding weights
+        into SmoothQuantMapping objects, resolving regular expressions.
+
+        For each activation in the mapping list, we find the corresponding weight to
+        balance by searching for the longest common substring between layer names. For
+        instance, if our balance weight is "re:.*q_proj" and the activation is
+        "re:.*self_attn_layer_norm" we would match model.layer.0.q_proj to
+        model.layer.0.self_attn_layer_norm and repeat for model.layer.1 and so on
+        """
+        resolved_mappings = []
+        for to_balance, to_smooth in self.mappings:
+            to_smooth_layers = model.get_layers(to_smooth)
+            for layer_name, smooth_layer in to_smooth_layers.items():
+                if layer_name not in self.ignore:
+                    balance_layers = []
+                    for balance_suffix in to_balance:
+                        # find the submodule that matches the activation layer
+                        _, balance_layer = model.get_matching_layer(
+                            balance_suffix, layer_name, model.model
+                        )
+                        if balance_layer:
+                            balance_layers.append(balance_layer)
+                    # each mapping can contain multiple layers to balance, but only
+                    # one layer to smooth
+                    mapping = SmoothQuantMapping(
+                        layer_name, smooth_layer, balance_layers
+                    )
+                    resolved_mappings.append(mapping)
+        return resolved_mappings
+
+    def on_start(self, state: State, event: Event, **kwargs):
+        pass
+
+    def on_update(self, state: State, event: Event, **kwargs):
+        pass
+
+    def on_end(self, state: State, event: Event, **kwargs):
+        pass
+
+    def on_event(self, state: State, event: Event, **kwargs):
+        pass
+
+    def on_finalize(self, state: State, **kwargs) -> bool:
+        """
+        Clean up by clearing the scale and mapping data
+
+        :param state: unused
+        :return: True
+        """
+        if self.scales_ is not None:
+            self.scales_.clear()
+        if self.resolved_mappings_ is not None:
+            self.resolved_mappings_.clear()
+
+        return True
diff --git a/src/sparseml/modifiers/smoothquant/pytorch.py b/src/sparseml/modifiers/smoothquant/pytorch.py
new file mode 100644
index 00000000000..7baecf3975b
--- /dev/null
+++ b/src/sparseml/modifiers/smoothquant/pytorch.py
@@ -0,0 +1,192 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import logging +from typing import Callable, List, Optional + +import torch +from torch.nn import Module + +from sparseml.core import State +from sparseml.core.model.pytorch import ModifiableModelPyTorch +from sparseml.modifiers.smoothquant.base import SmoothQuantModifier, SmoothQuantScale +from sparseml.modifiers.utils.pytorch_helpers import run_calibration_forward + + +_LOGGER = logging.getLogger(__name__) + +__all__ = ["SmoothQuantModifierPyTorch"] + + +class SmoothQuantModifierPyTorch(SmoothQuantModifier): + """ + PyTorch implementation of the SmoothQuant algorithm + + :param calibration_function: optional function to use for the forward pass, or None + to use the default tensor_module_forward + """ + + calibration_function: Optional[Callable] = None + hooks_: List = None + + def on_initialize(self, state: State, **kwargs) -> bool: + """ + Initialize and run SmoothQuant on the given state + + :param state: state to run SmoothQuant on + :return: True on a successful run, False otherwise + """ + super(SmoothQuantModifierPyTorch, self).on_initialize(state, **kwargs) + + calibration_dataloader = state.data.calib + self.hooks_ = [] + + self._setup_scale_hooks() + self._calibrate(state.model, calibration_dataloader) + self._apply_smoothing() + + return True + + def on_finalize(self, state: State, **kwargs) -> bool: + """ + Clean up by clearing the CUDA cache + + :param state: unused + :return: True + """ + super(SmoothQuantModifierPyTorch, self).on_finalize(state, **kwargs) + torch.cuda.empty_cache() + + return True + + def _setup_scale_hooks(self): + """ + Attach a forward hook to each activation we want to smooth. This allows us to + calculate the dynamic range during calibration + """ + + def create_hook_fn(layer_name): + def hook_fn(module, inp, out): + # update the per-channel min/max output values seen during calibration + if isinstance(out, tuple): + out = out[0] + + hidden_dim = out.shape[-1] + out = out.view(-1, hidden_dim).abs() + latest_mins = torch.min(out, dim=0)[0] + latest_maxes = torch.max(out, dim=0)[0] + + if layer_name in self.scales_: + self.scales_[layer_name].min_channel_vals = torch.minimum( + self.scales_[layer_name].min_channel_vals, latest_mins + ) + self.scales_[layer_name].max_channel_vals = torch.maximum( + self.scales_[layer_name].max_channel_vals, latest_maxes + ) + else: + self.scales_[layer_name] = SmoothQuantScale( + min_channel_vals=latest_mins, max_channel_vals=latest_maxes + ) + + return hook_fn + + for mapping in self.resolved_mappings_: + name = mapping.smooth_name + layer = mapping.smooth_layer + self.hooks_.append(layer.register_forward_hook(create_hook_fn(name))) + + @torch.no_grad() + def _calibrate(self, model: ModifiableModelPyTorch, calibration_dataloader: List): + """ + Catch the output dynamic ranges of each layer that will be smoothed by running + forward passes with calibration_dataloader + """ + _LOGGER.info("Running SmoothQuant scale calibration...") + if not calibration_dataloader: + raise ValueError( + "Calibration data loader not set, must populate the calib_data field of" + " SparseSession to run the SmoothQuant modifier" + ) + + run_calibration_forward( + model.model, + calibration_dataloader, + self.num_calibration_steps, + self.calibration_function, + ) + + # remove the hooks now that we are done calibrating + for hook in self.hooks_: + hook.remove() + del self.hooks_ + + @torch.no_grad() + def _apply_smoothing(self): + """ + After calibration, apply smoothing to the activations and push the transform + into the following weights 
by applying the inverse to each balance weight. + + Y = (Xdiag(scales)^(-1) * diag(scales)W) where W is the to_balance weights and + X is the to_smooth weights + + This modifies the weights of the model in-place. + """ + _LOGGER.info("Smoothing activation scales...") + for mapping in self.resolved_mappings_: + activation_scales = ( # get dynamic range for each activation channel + self.scales_[mapping.smooth_name].max_channel_vals + - self.scales_[mapping.smooth_name].min_channel_vals + ) + smooth_layer = mapping.smooth_layer + balance_layers = mapping.balance_layers + + scales = self._calculate_smoothing_scales(balance_layers, activation_scales) + + # invert the smoothing in the following layers + for layer in balance_layers: + layer.weight.mul_(scales.view(1, -1)) + + # apply the smoothing + if smooth_layer.weight.ndim == 1: + smooth_layer.weight.div_(scales) + else: + smooth_layer.weight.div_(scales.view(-1, 1)) + if hasattr(smooth_layer, "bias"): + smooth_layer.bias.div_(scales) + + def _calculate_smoothing_scales( + self, balance_layers: List[Module], activation_scales: torch.Tensor + ) -> List[float]: + """ + Calculate how much smoothing to apply to each channel based on the dynamic + range of the activation and the following weights + + :param balance_layers: layers to offset activation smoothing to + :param activation_scales: channel-wise dynamic range of activation to smooth + :return: channel-wise scales to use for smoothing activation + """ + # get the channel-wise dynamic range for each layer to be balanced + weight_scales = [] + for layer in balance_layers: + scale = layer.weight.abs().max(dim=0, keepdim=True)[0] + weight_scales.append(scale) + weight_scales = 2.0 * torch.cat(weight_scales, dim=0).max(dim=0)[0] + + # calculate the amount of smoothing to apply + # s_j = max(|X_j|)^alpha / max(|W_j|)^(1-alpha) + # where j is the input channel, alpha is smoothing strength + scales = activation_scales.pow(self.smoothing_strength) / weight_scales.pow( + 1 - self.smoothing_strength + ) + return scales diff --git a/src/sparseml/modifiers/utils/pytorch_helpers.py b/src/sparseml/modifiers/utils/pytorch_helpers.py new file mode 100644 index 00000000000..d8309468756 --- /dev/null +++ b/src/sparseml/modifiers/utils/pytorch_helpers.py @@ -0,0 +1,59 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
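+
+# Usage sketch (mirrors the call sites in the quantization and SmoothQuant
+# modifiers; `model` and `dataloader` are placeholder names):
+#
+#   run_calibration_forward(
+#       model,                      # torch.nn.Module, run in eval mode
+#       dataloader,                 # iterable of calibration batches
+#       num_calibration_steps=512,  # or None to consume the full dataloader
+#       calibration_function=None,  # defaults to tensors_module_forward
+#   )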
+
+from itertools import cycle
+from typing import Callable, List, Optional
+
+import torch
+from torch.nn import Module
+
+from sparseml.pytorch.utils import tensors_module_forward, tensors_to_device
+
+
+def run_calibration_forward(
+    model: Module,
+    calibration_dataloader: List,
+    num_calibration_steps: Optional[int] = None,
+    calibration_function: Optional[Callable] = None,
+):
+    """
+    Helper function used by one-shot modifiers; runs calibration data through a model
+    to update modifier statistics and trigger hooks
+
+    :param model: PyTorch model to run
+    :param calibration_dataloader: data to use for calibration
+    :param num_calibration_steps: number of items in calibration_dataloader to process,
+        None to process all available data
+    :param calibration_function: optional custom forward function to use for the model
+    """
+    model.eval()
+
+    forward_fn: Callable = (
+        calibration_function if calibration_function else tensors_module_forward
+    )
+
+    model_device = next(model.parameters()).device
+    _dataloader = (
+        calibration_dataloader
+        if num_calibration_steps is None
+        else cycle(calibration_dataloader)
+    )
+
+    # run through the calibration data
+    for batch_idx, batch in enumerate(_dataloader):
+        if num_calibration_steps and batch_idx >= num_calibration_steps:
+            break
+        batch = tensors_to_device(batch, model_device)
+        with torch.no_grad():
+            forward_fn(batch, module=model)
diff --git a/src/sparseml/pytorch/torchvision/train.py b/src/sparseml/pytorch/torchvision/train.py
index d3620d503eb..d3b3a7a0e22 100644
--- a/src/sparseml/pytorch/torchvision/train.py
+++ b/src/sparseml/pytorch/torchvision/train.py
@@ -401,30 +401,27 @@ def collate_fn(batch):
 
     _LOGGER.info("Creating model")
     local_rank = int(os.environ["LOCAL_RANK"]) if args.distributed else None
-    model, arch_key, maybe_dp_device = _create_model(
+    model, arch_key = _create_model(
         arch_key=args.arch_key,
         local_rank=local_rank,
         pretrained=args.pretrained,
         checkpoint_path=args.checkpoint_path,
         pretrained_dataset=args.pretrained_dataset,
-        device=device,
         num_classes=num_classes,
     )
 
     if args.distill_teacher not in ["self", "disable", None]:
         _LOGGER.info("Instantiating teacher")
-        distill_teacher, _, _ = _create_model(
+        distill_teacher, _ = _create_model(
             arch_key=args.teacher_arch_key,
             local_rank=local_rank,
             pretrained=True,  # teacher is always pretrained
             pretrained_dataset=args.pretrained_teacher_dataset,
             checkpoint_path=args.distill_teacher,
-            device=device,
             num_classes=num_classes,
         )
     else:
         distill_teacher = args.distill_teacher
-    device = maybe_dp_device
 
     if args.distributed and args.sync_bn:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
@@ -507,7 +504,7 @@ def collate_fn(batch):
         alpha = 1.0 - args.model_ema_decay
         alpha = min(1.0, alpha * adjust)
         model_ema = utils.ExponentialMovingAverage(
-            model, device=device, decay=1.0 - alpha
+            model, device=model.device, decay=1.0 - alpha
         )
 
     manager = checkpoint_manager = None
@@ -651,9 +648,17 @@ def log_metrics(tag: str, metrics: utils.MetricLogger, epoch: int, epoch_step: i
         args, optimizer, checkpoint=checkpoint, manager=manager
     )
 
-    model_without_ddp = model
     if args.distributed:
-        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+        ddp = True
+        device = local_rank
+    else:
+        ddp = False
+
+    model, device, _ = model_to_device(model, device, ddp)
+    if distill_teacher is not None:
+        distill_teacher, _, _ = model_to_device(distill_teacher, device, ddp)
+
+    if args.distributed:
         model_without_ddp = model.module
 
     best_top1_acc = -math.inf
@@ -760,7 
+765,6 @@ def _create_model( pretrained: Optional[bool] = False, checkpoint_path: Optional[str] = None, pretrained_dataset: Optional[str] = None, - device=None, num_classes=None, ): if not arch_key or arch_key in ModelRegistry.available_keys(): @@ -811,17 +815,7 @@ def _create_model( raise ValueError( f"Unable to find {arch_key} in ModelRegistry or in torchvision.models" ) - ddp = False - if local_rank is not None: - torch.cuda.set_device(local_rank) - device = local_rank - ddp = True - model, device, _ = model_to_device( - model=model, - device=device, - ddp=ddp, - ) - return model, arch_key, device + return model, arch_key def _get_lr_scheduler(args, optimizer, checkpoint=None, manager=None): diff --git a/src/sparseml/transformers/sparsification/obcq/example.yaml b/src/sparseml/transformers/sparsification/obcq/example.yaml index 6f899535d45..ddf9d6ee3ce 100644 --- a/src/sparseml/transformers/sparsification/obcq/example.yaml +++ b/src/sparseml/transformers/sparsification/obcq/example.yaml @@ -5,21 +5,28 @@ metadata: test_stage: obcq_modifiers: + SmoothQuantModifier: + smoothing_strength: 0.5 + mappings: [ + [["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*self_attn_layer_norm"], + [["re:.*fc1"], "re:.*final_layer_norm"] + ] + ignore: ["model.decoder.final_layer_norm"] + QuantizationModifier: + ignore: ["lm_head", "Embedding", "OPTLearnedPositionalEmbedding", "QuantizableBatchMatMul", "BMMLeftInput_QK", "BMMRightInput_QK", "BMMOutput_QK", "BMMLeftInput_PV", "BMMRightInput_PV", "BMMOutput_PV"] + post_oneshot_calibration: True + scheme_overrides: + ReLU: + input_activations: null + output_activations: null + LayerNorm: + input_activations: null + output_activations: null SparseGPTModifier: sparsity: 0.5 block_size: 128 sequential_update: False - quantize: - QuantizationModifier: - ignore: ["lm_head", "Embedding", "OPTLearnedPositionalEmbedding", "QuantizableBatchMatMul", "BMMLeftInput_QK", "BMMRightInput_QK", "BMMOutput_QK", "BMMLeftInput_PV", "BMMRightInput_PV", "BMMOutput_PV"] - post_oneshot_calibration: True - scheme_overrides: - ReLU: - input_activations: null - output_activations: null - LayerNorm: - input_activations: null - output_activations: null + quantize: True percdamp: 0.01 prunen: 0 prunem: 0 diff --git a/src/sparseml/transformers/sparsification/obcq/example_llama.yaml b/src/sparseml/transformers/sparsification/obcq/example_llama.yaml index 6cef1ebeb40..ea3f4ae5cd1 100644 --- a/src/sparseml/transformers/sparsification/obcq/example_llama.yaml +++ b/src/sparseml/transformers/sparsification/obcq/example_llama.yaml @@ -1,5 +1,11 @@ test_stage: obcq_modifiers: + SmoothQuantModifier: + smoothing_strength: 0.5 + mappings: [ + [["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*input_layernorm"], + [["re:.*gate_proj", "re:.*up_proj"], "re:.*post_attention_layernorm"] + ] QuantizationModifier: ignore: - LlamaRotaryEmbedding diff --git a/src/sparseml/transformers/sparsification/obcq/obcq.py b/src/sparseml/transformers/sparsification/obcq/obcq.py index 7a60b14a5b8..8a973cdb803 100644 --- a/src/sparseml/transformers/sparsification/obcq/obcq.py +++ b/src/sparseml/transformers/sparsification/obcq/obcq.py @@ -19,6 +19,7 @@ from typing import Optional from torch.nn import Module +from transformers import AutoConfig import sparseml.core.session as session_manager from sparseml.core.framework import Framework @@ -36,7 +37,7 @@ _LOGGER = logging.getLogger(__name__) SUPPORTED_DATASETS = ["wikitext2", "ptb", "c4", "open_platypus"] -SUPPORTED_MODELS = ["opt", "llama"] +SUPPORTED_MODELS = 
["opt", "llama", "mistral"] def one_shot( @@ -70,14 +71,21 @@ def one_shot( if deploy_dir.exists(): raise RuntimeError(f"deploy_dir={deploy_dir} already exists") + # Load the configuration from the model path + config = AutoConfig.from_pretrained(model_path) + model_type = config.model_type.lower() + model_loader_fn = None forward_fn = None - if "opt" in model_path.lower(): + if "opt" in model_type: model_loader_fn = SparseCasualLM.opt_model_from_pretrained forward_fn = opt_forward - elif "llama" in model_path.lower(): + elif "llama" in model_type: model_loader_fn = SparseCasualLM.llama_model_from_pretrained forward_fn = llama_forward + elif "mistral" in model_type: + model_loader_fn = SparseCasualLM.auto_model_from_pretrained + forward_fn = llama_forward else: raise ValueError(f"model_path={model_path} should be one of {SUPPORTED_MODELS}") model = model_loader_fn(model_path) diff --git a/src/sparseml/transformers/utils/model.py b/src/sparseml/transformers/utils/model.py index 3f89f7b4127..4e23d91c1e7 100644 --- a/src/sparseml/transformers/utils/model.py +++ b/src/sparseml/transformers/utils/model.py @@ -462,6 +462,19 @@ def llama_model_from_pretrained(model_path: str) -> torch.nn.Module: model.seqlen = model.config.max_position_embeddings return model + @staticmethod + def auto_model_from_pretrained(model_path: str) -> torch.nn.Module: + """ + Load a pretrained model using auto from the specified hugging face path + + :param model_path: hugging face path to model + :return: loaded pretrained model + """ + model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto") + model.eval() + model.seqlen = model.config.max_position_embeddings + return model + def get_shared_tokenizer_src(student: Module, teacher: Optional[Module]) -> str: """ diff --git a/src/sparseml/utils/pytorch/module.py b/src/sparseml/utils/pytorch/module.py index fb3956016b0..bc95c6dc888 100644 --- a/src/sparseml/utils/pytorch/module.py +++ b/src/sparseml/utils/pytorch/module.py @@ -16,8 +16,9 @@ Utility / helper functions """ +import difflib import re -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import torch from packaging import version @@ -67,6 +68,7 @@ "get_quantizable_layers", "qat_active", "get_layers_params", + "get_matching_layer", ] @@ -271,3 +273,33 @@ def get_layers_params( parameterized_layers[name] = param_layer return parameterized_layers + + +def get_matching_layer( + target: str, name_to_match: str, module: Module +) -> Optional[Tuple[str, Module]]: + """ + Given a target regex, find the layer name in the module that most closely matches + the name_to_match string. This is used to matches submodules in the same layer, for + instance matching "re.*k_proj" to "model.decoder.layer.0.q_proj" to find the k_proj + that exists in layer 0. 
+ + :param target: regex to search for + :param name_to_match: full layer name to match to, should exist in module + :param module: module to search for target in + :return: Tuple containing the layer name and module that fits the target regex and + best matches name_to_match, or None if no match can be found + """ + potential_matches = get_layers(target, module) + largest_substring = 0 + match = None + for name, module in potential_matches.items(): + seq_matcher = difflib.SequenceMatcher(None, name, name_to_match) + _, _, match_length = seq_matcher.find_longest_match( + 0, len(name), 0, len(name_to_match) + ) + if match_length > largest_substring: + match = (name, module) + largest_substring = match_length + + return match diff --git a/tests/sparseml/transformers/obcq/test_obcq.py b/tests/sparseml/transformers/obcq/test_obcq.py new file mode 100644 index 00000000000..f41aaafd59e --- /dev/null +++ b/tests/sparseml/transformers/obcq/test_obcq.py @@ -0,0 +1,49 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from sparseml.modifiers.obcq.utils.helpers import ppl_eval_general +from sparseml.transformers.data import TransformersDataset +from sparseml.transformers.sparsification.obcq.obcq import one_shot +from sparseml.transformers.sparsification.obcq.utils.helpers import llama_forward + + +def test_obcq_tinystories(): + tiny_model_path = "Xenova/llama2.c-stories15M" + device = "cuda:0" + + # test recipe with 50% sparsity, quantization and smoothquant + tiny_model = one_shot( + model_path=tiny_model_path, + dataset_name="open_platypus", + num_samples=64, + device=device, + recipe_file="tests/sparseml/transformers/obcq/test_tiny.yaml", + ) + + dataset = TransformersDataset.load_from_registry( + "wikitext2", + model=tiny_model_path, + seqlen=tiny_model.seqlen, + nsamples=64, + seed=0, + split="test", + ) + test_data = dataset.loader + perplexity = ppl_eval_general( + llama_forward, tiny_model, test_data, device, max_samples_per_iteration=8 + ) + + # we aren't expecting good results from this tiny model, but this should catch any + # egregious errors with the OBCQ algorithm + assert perplexity < 10000.0 diff --git a/tests/sparseml/transformers/obcq/test_tiny.yaml b/tests/sparseml/transformers/obcq/test_tiny.yaml new file mode 100644 index 00000000000..95771adb20a --- /dev/null +++ b/tests/sparseml/transformers/obcq/test_tiny.yaml @@ -0,0 +1,43 @@ +test_stage: + obcq_modifiers: + SmoothQuantModifier: + smoothing_strength: 0.5 + mappings: [ + [["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*input_layernorm"], + [["re:.*gate_proj", "re:.*up_proj"], "re:.*post_attention_layernorm"] + ] + QuantizationModifier: + ignore: + - LlamaRotaryEmbedding + - LlamaRMSNorm + - SiLUActivation + - model.layers.0.mlp.down_proj + - model.layers.1.mlp.down_proj + - model.layers.2.mlp.down_proj + - model.layers.3.mlp.down_proj + - model.layers.4.mlp.down_proj + - model.layers.5.mlp.down_proj + post_oneshot_calibration: True + scheme_overrides: 
+ Embedding: + input_activations: null + weights: + num_bits: 8 + symmetric: False + SparseGPTModifier: + sparsity: 0.5 + block_size: 128 + sequential_update: False + quantize: True + percdamp: 0.01 + prunen: 0 + prunem: 0 + targets: [ + "model.layers.0", + "model.layers.1", + "model.layers.2", + "model.layers.3", + "model.layers.4", + "model.layers.5" + ] + target_ids: ["attention_mask", "position_ids"] \ No newline at end of file
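
For reference, below is a minimal, self-contained sketch of the smoothing arithmetic that `_calculate_smoothing_scales` and `_apply_smoothing` perform for a single resolved mapping; the toy `LayerNorm`/`Linear` modules and the random `activation_scales` are illustrative stand-ins, not part of this change:

```python
import torch
from torch.nn import LayerNorm, Linear

alpha = 0.5  # smoothing_strength from the recipe

# one resolved mapping: an activation-producing LayerNorm feeding two projections
smooth_layer = LayerNorm(8)
balance_layers = [Linear(8, 8), Linear(8, 8)]

# stand-in for the per-channel dynamic range collected by the calibration hooks
activation_scales = torch.rand(8) + 0.1

# channel-wise weight range across all balance layers
weight_scales = 2.0 * torch.cat(
    [layer.weight.abs().max(dim=0, keepdim=True)[0] for layer in balance_layers],
    dim=0,
).max(dim=0)[0]

# s_j = max(|X_j|)^alpha / max(|W_j|)^(1 - alpha) for input channel j
scales = activation_scales.pow(alpha) / weight_scales.pow(1 - alpha)

with torch.no_grad():
    for layer in balance_layers:
        layer.weight.mul_(scales.view(1, -1))  # fold the inverse into the next weights
    smooth_layer.weight.div_(scales)  # shrink the activation-producing parameters
    smooth_layer.bias.div_(scales)
```

The balance layers absorb `diag(scales)` on their input channels while the smooth layer divides it out, so the composed computation is unchanged while the activation outliers are reduced.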