Merge branch 'main' into sa/fp8
Sara Adkins committed May 29, 2024
2 parents 3aa99d1 + c076d41 commit 0f1a839
Showing 65 changed files with 340 additions and 162 deletions.
35 changes: 35 additions & 0 deletions examples/llama7b_sparse_quantized/2:4_w4a16_group-128_recipe.yaml
@@ -0,0 +1,35 @@
sparsity_stage:
  run_type: oneshot
  sparsity_modifiers:
    SparseGPTModifier:
      sparsity: 0.5
      mask_structure: "2:4"
      sequential_update: false
finetuning_stage:
  run_type: train
  finetuning_modifiers:
    ConstantPruningModifier:
      targets: [
        're:.*q_proj.weight',
        're:.*k_proj.weight',
        're:.*v_proj.weight',
        're:.*o_proj.weight',
        're:.*gate_proj.weight',
        're:.*up_proj.weight',
        're:.*down_proj.weight',
      ]
      start: 0
quantization_stage:
  run_type: oneshot
  quantization_modifiers:
    GPTQModifier:
      sequential_update: false
      config_groups:
        group_0:
          weights:
            num_bits: 4
            type: "int"
            symmetric: true
            strategy: "group"
            group_size: 128
          targets: ["Linear"]
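For context: this staged recipe runs one-shot SparseGPT pruning, finetuning with the pruned 2:4 mask held constant, and one-shot GPTQ quantization. Below is a minimal sketch of how such a staged recipe might be applied, assuming the `apply` entrypoint exported by `sparseml.transformers`; the model stub, dataset name, sequence length, and output directory are illustrative placeholders, not part of this commit:

```python
import torch

from sparseml.transformers import SparseAutoModelForCausalLM, apply

# illustrative model stub; any causal LM checkpoint could be used here
model = SparseAutoModelForCausalLM.from_pretrained(
    "zoo:llama2-7b-ultrachat200k_llama2_pretrain-base",
    torch_dtype=torch.bfloat16,
)

# runs the sparsity, finetuning, and quantization stages of the recipe in order
apply(
    model=model,
    dataset="ultrachat-200k",  # illustrative calibration/finetuning dataset
    recipe="2:4_w4a16_group-128_recipe.yaml",
    output_dir="./output_llama7b_2of4_w4a16",  # illustrative path
    max_seq_length=512,
    num_calibration_samples=512,
)
```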
5 changes: 5 additions & 0 deletions examples/llama7b_sparse_quantized/README.md
@@ -45,3 +45,8 @@ from sparseml import SparseAutoModelForCausalLM
model = SparseAutoModelForCausalLM.from_pretrained(output_dir, torch_dtype=torch.bfloat16)
model.save_pretrained(compressed_output_dir, save_compressed=True)
```

### Custom Quantization
The current repo supports multiple quantization strategies, configured using a recipe. Supported strategies are `tensor`, `group`, and `channel`.
The above recipe (`2:4_w4a16_recipe.yaml`) uses channel-wise quantization, specified by `strategy: "channel"` in its config group.
To quantize per tensor, change the strategy from `channel` to `tensor`. To use group-size quantization, change the strategy from `channel` to `group` and specify the group size, for example 128, by including `group_size: 128`. A group-size quantization example is shown in `2:4_w4a16_group-128_recipe.yaml`.
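For reference, a minimal sketch of how the three strategies appear inside a config group's `weights` section; the variants are shown as separate YAML documents and the values are illustrative:

```yaml
# one scale and zero-point for the entire weight tensor
strategy: "tensor"
---
# one scale and zero-point per output channel
strategy: "channel"
---
# one scale and zero-point per group of 128 weights
strategy: "group"
group_size: 128
```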
2 changes: 0 additions & 2 deletions src/sparseml/modifiers/quantization/__init__.py
@@ -13,5 +13,3 @@
# limitations under the License.

# flake8: noqa

from .base import *
2 changes: 1 addition & 1 deletion src/sparseml/modifiers/quantization/gptq/base.py
@@ -194,7 +194,7 @@ def _build_quant_modifier(self, framework):
        )
        quant_args["config_groups"] = {"config_group_0": default_quant_scheme}
        _LOGGER.info(f"Building quantization modifier with args: {quant_args}")
        vllm_quant_config = {"vLLMQuantizationModifier": quant_args}
        vllm_quant_config = {"QuantizationModifier": quant_args}
        self._build_quant_modifier_from_dict(vllm_quant_config, framework)

    def compressible_layers(self) -> Dict:
3 changes: 2 additions & 1 deletion src/sparseml/modifiers/quantization/gptq/pytorch.py
@@ -23,6 +23,7 @@
from sparseml.modifiers.quantization.gptq.utils.gptq_wrapper import GPTQWrapper
from sparseml.modifiers.utils.layer_compressor import LayerCompressor
from sparseml.modifiers.utils.pytorch_helpers import run_calibration_forward
from sparseml.utils.fsdp.context import fix_fsdp_module_name


__all__ = ["GPTQModifierPyTorch"]
@@ -116,6 +117,7 @@ def initialize_compression(
        self.layer_compressors_ = []

        for idx, (name, layer) in enumerate(self.compressible_layers_.items()):
            name = fix_fsdp_module_name(name)
            _LOGGER.info(f"Preparing {name} for compression")
            args = self._pruning_arguments()
            comp_cls = self._compression_class()
@@ -174,7 +176,6 @@ def _pruning_arguments(self):
"""
Gather the parameters needed for root module compression in a dict
:param sparsity: target sparsity
:return: dict of params for pruning
"""
return {
@@ -195,11 +195,6 @@ def fasterprune(
                    quant_scheme.weights,
                )
            else:  # strategy == QuantizationStrategy.GROUP
                # TODO: for grouped quantization its always 3d but the last
                # dim is always 1. Can we just make it 2d instead and avoid?
                scale = scale[:, :, 0]
                zero_point = zero_point[:, :, 0]

                # get the group index for the current column
                column_idx = i1 + i
                input_dim_group = (
@@ -24,10 +24,10 @@
from sparseml.core import Event, Modifier


__all__ = ["vLLMQuantizationModifier"]
__all__ = ["QuantizationModifier"]


class vLLMQuantizationModifier(Modifier):
class QuantizationModifier(Modifier):
"""
Enables post training quantization (PTQ) and quantization aware training (QAT) for a
given module or its submodules. After calibration (PTQ) or the start epoch (QAT),
@@ -23,16 +23,16 @@
    set_module_for_calibration,
)
from sparseml.core import Event, EventType, State
from sparseml.modifiers.quantization_vllm.base import vLLMQuantizationModifier
from sparseml.modifiers.quantization.quantization.base import QuantizationModifier
from sparseml.modifiers.utils.pytorch_helpers import run_calibration_forward


_LOGGER = logging.getLogger(__name__)


class vLLMQuantizationModifierPyTorch(vLLMQuantizationModifier):
class QuantizationModifierPyTorch(QuantizationModifier):
"""
PyTorch specific implementation of vLLMQuantizationModifier
PyTorch specific implementation of QuantizationModifier
Enables post training quantization (PTQ) and quantization aware training (QAT) for a
given module or its submodules. After calibration (PTQ) or the start epoch (QAT),
17 changes: 17 additions & 0 deletions src/sparseml/modifiers/quantization_legacy/__init__.py
@@ -0,0 +1,17 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa

from .base import *
@@ -17,17 +17,17 @@
from sparseml.core import Event, Modifier


__all__ = ["QuantizationModifier"]
__all__ = ["LegacyQuantizationModifier"]


class QuantizationModifier(Modifier):
class LegacyQuantizationModifier(Modifier):
"""
Enables quantization aware training (QAT) for a given module or its submodules
After the start epoch, the specified module(s) forward pass will emulate
quantized execution and the modifier will be enabled until training is completed.
| Sample yaml:
| QuantizationModifier:
| LegacyQuantizationModifier:
| start: 0.0
| scheme:
| input_activations:
@@ -15,7 +15,9 @@
import logging
import os

from sparseml.modifiers.quantization.modification.registry import ModificationRegistry
from sparseml.modifiers.quantization_legacy.modification.registry import (
    ModificationRegistry,
)


_LOGGER = logging.getLogger(__name__)
@@ -19,18 +19,18 @@
from torch.nn import Module

from sparseml.core import Event, EventType, State
from sparseml.modifiers.quantization.base import QuantizationModifier
from sparseml.modifiers.quantization.modification import modify_model
from sparseml.modifiers.quantization.utils.helpers import (
from sparseml.modifiers.quantization_legacy.base import LegacyQuantizationModifier
from sparseml.modifiers.quantization_legacy.modification import modify_model
from sparseml.modifiers.quantization_legacy.utils.helpers import (
    configure_module_bn_wrappers,
    freeze_bn_stats,
    fuse_module_conv_bn_relus,
)
from sparseml.modifiers.quantization.utils.quantization_scheme import (
from sparseml.modifiers.quantization_legacy.utils.quantization_scheme import (
    QuantizationScheme,
    QuantizationSchemeLoadable,
)
from sparseml.modifiers.quantization.utils.quantize import (
from sparseml.modifiers.quantization_legacy.utils.quantize import (
    convert_module_qat_from_schemes,
    raise_if_torch_quantization_not_available,
    set_quantization_schemes,
@@ -42,7 +42,7 @@
_LOGGER = logging.getLogger(__name__)


class QuantizationModifierPyTorch(QuantizationModifier):
class LegacyQuantizationModifierPyTorch(LegacyQuantizationModifier):
"""
Pytorch-specific implementation of quantization modifier
@@ -26,7 +26,7 @@
from torch import quantization as torch_quantization
from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU

from sparseml.modifiers.quantization.utils.quantization_scheme import (
from sparseml.modifiers.quantization_legacy.utils.quantization_scheme import (
    QuantizationArgs,
    QuantizationScheme,
    get_observer,
@@ -30,7 +30,9 @@
except Exception:
    torch_quantization = None

from sparseml.modifiers.quantization.utils.fake_quant_wrapper import FakeQuantizeWrapper
from sparseml.modifiers.quantization_legacy.utils.fake_quant_wrapper import (
    FakeQuantizeWrapper,
)


__all__ = [
@@ -22,17 +22,21 @@
from packaging import version
from torch.nn import Identity, Module

from sparseml.modifiers.quantization.utils.constants import (
from sparseml.modifiers.quantization_legacy.utils.constants import (
    FUSED_MODULE_NAMES,
    NON_QUANTIZABLE_MODULE_NAMES,
)
from sparseml.modifiers.quantization.utils.fake_quant_wrapper import FakeQuantizeWrapper
from sparseml.modifiers.quantization.utils.helpers import (
from sparseml.modifiers.quantization_legacy.utils.fake_quant_wrapper import (
    FakeQuantizeWrapper,
)
from sparseml.modifiers.quantization_legacy.utils.helpers import (
    QATWrapper,
    configure_module_default_qconfigs,
    prepare_embeddings_qat,
)
from sparseml.modifiers.quantization.utils.quantization_scheme import QuantizationScheme
from sparseml.modifiers.quantization_legacy.utils.quantization_scheme import (
    QuantizationScheme,
)
from sparseml.pytorch.utils import get_layer
from sparseml.utils.fsdp.context import fix_fsdp_module_name

4 changes: 4 additions & 0 deletions src/sparseml/pytorch/utils/sparsification.py
@@ -69,6 +69,10 @@ def __init__(
        self.state_dict = state_dict

        if self.state_dict is not None:
            # when analyzing an FSDP model, the state_dict does not differentiate
            # between trainable and non-trainable parameters
            # (e.g. it can contain buffers), so self.trainable_params
            # may be overestimated
            self.trainable_params = [param for _, param in state_dict.items()]
        else:
            self.trainable_params = list(
@@ -83,10 +83,10 @@ def save_pretrained_wrapper(
        # check if we are in the old quantization framework
        if qat_active(model) and not is_model_quantized(model):
            _LOGGER.info(
                "Compression for models quantized with QuantizationModifer is not "
                "supported. Save will be run without compression and no sparsity "
                "statistics will be calculated. To save a quantized model in a "
                "compressed state please use vLLMQuantizationModifier instead."
                "Compression for models quantized with LegacyQuantizationModifier "
                "is not supported. Save will be run without compression and no "
                "sparsity statistics will be calculated. To save a quantized model "
                "in a compressed state please use QuantizationModifier instead."
            )

            original_save_pretrained.__get__(model, model_class)(
@@ -25,8 +25,12 @@
from torch import nn
from transformers.models.bert.modeling_bert import BertSelfAttention

from sparseml.modifiers.quantization.modification.modification_objects import QATMatMul
from sparseml.modifiers.quantization.modification.registry import ModificationRegistry
from sparseml.modifiers.quantization_legacy.modification.modification_objects import (
    QATMatMul,
)
from sparseml.modifiers.quantization_legacy.modification.registry import (
    ModificationRegistry,
)
from sparseml.pytorch.utils.helpers import swap_modules
from sparseml.transformers.sparsification.modification.base import (
check_transformers_version,
@@ -27,8 +27,12 @@
    MultiHeadSelfAttention,
)

from sparseml.modifiers.quantization.modification.modification_objects import QATMatMul
from sparseml.modifiers.quantization.modification.registry import ModificationRegistry
from sparseml.modifiers.quantization_legacy.modification.modification_objects import (
    QATMatMul,
)
from sparseml.modifiers.quantization_legacy.modification.registry import (
    ModificationRegistry,
)
from sparseml.pytorch.utils.helpers import swap_modules
from sparseml.transformers.sparsification.modification.base import (
check_transformers_version,
@@ -32,11 +32,13 @@
    repeat_kv,
)

from sparseml.modifiers.quantization.modification.modification_objects import (
from sparseml.modifiers.quantization_legacy.modification.modification_objects import (
    QuantizableIdentity,
    QuantizableMatMul,
)
from sparseml.modifiers.quantization.modification.registry import ModificationRegistry
from sparseml.modifiers.quantization_legacy.modification.registry import (
    ModificationRegistry,
)
from sparseml.pytorch.utils.helpers import swap_modules
from sparseml.transformers.sparsification.modification.base import (
check_transformers_version,
@@ -32,11 +32,13 @@
    repeat_kv,
)

from sparseml.modifiers.quantization.modification.modification_objects import (
from sparseml.modifiers.quantization_legacy.modification.modification_objects import (
    QuantizableIdentity,
    QuantizableMatMul,
)
from sparseml.modifiers.quantization.modification.registry import ModificationRegistry
from sparseml.modifiers.quantization_legacy.modification.registry import (
    ModificationRegistry,
)
from sparseml.pytorch.utils.helpers import swap_modules
from sparseml.transformers.sparsification.modification.base import (
check_transformers_version,
@@ -20,8 +20,12 @@
from torch import nn
from transformers.models.mobilebert.modeling_mobilebert import MobileBertEmbeddings

from sparseml.modifiers.quantization.modification.modification_objects import QATLinear
from sparseml.modifiers.quantization.modification.registry import ModificationRegistry
from sparseml.modifiers.quantization_legacy.modification.modification_objects import (
    QATLinear,
)
from sparseml.modifiers.quantization_legacy.modification.registry import (
    ModificationRegistry,
)
from sparseml.pytorch.utils.helpers import swap_modules
from sparseml.transformers.sparsification.modification.base import (
check_transformers_version,
@@ -23,11 +23,13 @@
from torch import nn
from transformers.models.opt.modeling_opt import OPTAttention, OptFlashAttention2

from sparseml.modifiers.quantization.modification.modification_objects import (
from sparseml.modifiers.quantization_legacy.modification.modification_objects import (
    QuantizableBatchMatmul,
    QuantizableIdentity,
)
from sparseml.modifiers.quantization.modification.registry import ModificationRegistry
from sparseml.modifiers.quantization_legacy.modification.registry import (
    ModificationRegistry,
)
from sparseml.pytorch.utils.helpers import swap_modules
from sparseml.transformers.sparsification.modification.base import (
check_transformers_version,
2 changes: 1 addition & 1 deletion src/sparseml/transformers/sparsification/sparse_model.py
@@ -31,7 +31,7 @@
from transformers.file_utils import WEIGHTS_NAME

from compressed_tensors.compressors import ModelCompressor
from sparseml.modifiers.quantization.modification import modify_model
from sparseml.modifiers.quantization_legacy.modification import modify_model
from sparseml.pytorch.model_load.helpers import (
    apply_recipe_structure_to_model,
    log_model_load,