Replace Quantization Modifier (#2307)
* convert old modifier to legacy
* redo folder structure
* fixing imports
* update import
* fix imports
Sara Adkins authored May 29, 2024
1 parent c530bbf commit 1fcec3f
Showing 50 changed files with 129 additions and 82 deletions.
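In short, the commit renames the old QAT modifier to LegacyQuantizationModifier (moving it under quantization_legacy) and promotes vLLMQuantizationModifier to the plain QuantizationModifier name. A hedged before/after import sketch, using only paths that appear in the hunks below:

# Before this commit (removed paths, per the hunks below):
# from sparseml.modifiers.quantization import QuantizationModifier
# from sparseml.modifiers.quantization_vllm.base import vLLMQuantizationModifier

# After this commit (added paths, per the hunks below):
from sparseml.modifiers.quantization_legacy import LegacyQuantizationModifier
from sparseml.modifiers.quantization.quantization.base import QuantizationModifier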
2 changes: 0 additions & 2 deletions src/sparseml/modifiers/quantization/__init__.py
@@ -13,5 +13,3 @@
 # limitations under the License.

 # flake8: noqa
-
-from .base import *
2 changes: 1 addition & 1 deletion src/sparseml/modifiers/quantization/gptq/base.py
@@ -194,7 +194,7 @@ def _build_quant_modifier(self, framework):
         )
         quant_args["config_groups"] = {"config_group_0": default_quant_scheme}
         _LOGGER.info(f"Building quantization modifier with args: {quant_args}")
-        vllm_quant_config = {"vLLMQuantizationModifier": quant_args}
+        vllm_quant_config = {"QuantizationModifier": quant_args}
         self._build_quant_modifier_from_dict(vllm_quant_config, framework)

     def compressible_layers(self) -> Dict:
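The key in that dict is the registered modifier type name, so only the key changes for GPTQ here. A minimal sketch of the shape it builds; the inner scheme values are assumptions for illustration, not taken from sparseml:

# Hypothetical shape of the config handed to _build_quant_modifier_from_dict;
# the scheme contents below are illustrative only.
quant_args = {"config_groups": {"config_group_0": {"weights": {"num_bits": 8}}}}
vllm_quant_config = {"QuantizationModifier": quant_args}  # key was "vLLMQuantizationModifier"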
@@ -24,10 +24,10 @@
 from sparseml.core import Event, Modifier


-__all__ = ["vLLMQuantizationModifier"]
+__all__ = ["QuantizationModifier"]


-class vLLMQuantizationModifier(Modifier):
+class QuantizationModifier(Modifier):
     """
     Enables post training quantization (PTQ) and quantization aware training (QAT) for a
     given module or its submodules. After calibration (PTQ) or the start epoch (QAT),
@@ -23,16 +23,16 @@
     set_module_for_calibration,
 )
 from sparseml.core import Event, EventType, State
-from sparseml.modifiers.quantization_vllm.base import vLLMQuantizationModifier
+from sparseml.modifiers.quantization.quantization.base import QuantizationModifier
 from sparseml.modifiers.utils.pytorch_helpers import run_calibration_forward


 _LOGGER = logging.getLogger(__name__)


-class vLLMQuantizationModifierPyTorch(vLLMQuantizationModifier):
+class QuantizationModifierPyTorch(QuantizationModifier):
     """
-    PyTorch specific implementation of vLLMQuantizationModifier
+    PyTorch specific implementation of QuantizationModifier
     Enables post training quantization (PTQ) and quantization aware training (QAT) for a
     given module or its submodules. After calibration (PTQ) or the start epoch (QAT),
17 changes: 17 additions & 0 deletions src/sparseml/modifiers/quantization_legacy/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# flake8: noqa
+
+from .base import *
@@ -17,17 +17,17 @@
 from sparseml.core import Event, Modifier


-__all__ = ["QuantizationModifier"]
+__all__ = ["LegacyQuantizationModifier"]


-class QuantizationModifier(Modifier):
+class LegacyQuantizationModifier(Modifier):
     """
     Enables quantization aware training (QAT) for a given module or its submodules
     After the start epoch, the specified module(s) forward pass will emulate
     quantized execution and the modifier will be enabled until training is completed.

     | Sample yaml:
-    |   QuantizationModifier:
+    |   LegacyQuantizationModifier:
     |       start: 0.0
     |       scheme:
     |           input_activations:
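Recipes that previously referenced the old name need the new one. A minimal, hypothetical recipe string mirroring the truncated "Sample yaml" above; fields beyond start are elided there, so they are omitted here too:

# Hypothetical minimal recipe for the renamed legacy modifier.
legacy_recipe = """
LegacyQuantizationModifier:
    start: 0.0
"""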
@@ -15,7 +15,9 @@
 import logging
 import os

-from sparseml.modifiers.quantization.modification.registry import ModificationRegistry
+from sparseml.modifiers.quantization_legacy.modification.registry import (
+    ModificationRegistry,
+)


 _LOGGER = logging.getLogger(__name__)
@@ -19,18 +19,18 @@
 from torch.nn import Module

 from sparseml.core import Event, EventType, State
-from sparseml.modifiers.quantization.base import QuantizationModifier
-from sparseml.modifiers.quantization.modification import modify_model
-from sparseml.modifiers.quantization.utils.helpers import (
+from sparseml.modifiers.quantization_legacy.base import LegacyQuantizationModifier
+from sparseml.modifiers.quantization_legacy.modification import modify_model
+from sparseml.modifiers.quantization_legacy.utils.helpers import (
     configure_module_bn_wrappers,
     freeze_bn_stats,
     fuse_module_conv_bn_relus,
 )
-from sparseml.modifiers.quantization.utils.quantization_scheme import (
+from sparseml.modifiers.quantization_legacy.utils.quantization_scheme import (
     QuantizationScheme,
     QuantizationSchemeLoadable,
 )
-from sparseml.modifiers.quantization.utils.quantize import (
+from sparseml.modifiers.quantization_legacy.utils.quantize import (
     convert_module_qat_from_schemes,
     raise_if_torch_quantization_not_available,
     set_quantization_schemes,
@@ -42,7 +42,7 @@
 _LOGGER = logging.getLogger(__name__)


-class QuantizationModifierPyTorch(QuantizationModifier):
+class LegacyQuantizationModifierPyTorch(LegacyQuantizationModifier):
     """
     Pytorch-specific implementation of quantization modifier
@@ -26,7 +26,7 @@
 from torch import quantization as torch_quantization
 from torch.nn import BatchNorm2d, Conv2d, Embedding, Module, ReLU

-from sparseml.modifiers.quantization.utils.quantization_scheme import (
+from sparseml.modifiers.quantization_legacy.utils.quantization_scheme import (
     QuantizationArgs,
     QuantizationScheme,
     get_observer,
@@ -30,7 +30,9 @@
 except Exception:
     torch_quantization = None

-from sparseml.modifiers.quantization.utils.fake_quant_wrapper import FakeQuantizeWrapper
+from sparseml.modifiers.quantization_legacy.utils.fake_quant_wrapper import (
+    FakeQuantizeWrapper,
+)


 __all__ = [
@@ -22,17 +22,21 @@
 from packaging import version
 from torch.nn import Identity, Module

-from sparseml.modifiers.quantization.utils.constants import (
+from sparseml.modifiers.quantization_legacy.utils.constants import (
     FUSED_MODULE_NAMES,
     NON_QUANTIZABLE_MODULE_NAMES,
 )
-from sparseml.modifiers.quantization.utils.fake_quant_wrapper import FakeQuantizeWrapper
-from sparseml.modifiers.quantization.utils.helpers import (
+from sparseml.modifiers.quantization_legacy.utils.fake_quant_wrapper import (
+    FakeQuantizeWrapper,
+)
+from sparseml.modifiers.quantization_legacy.utils.helpers import (
     QATWrapper,
     configure_module_default_qconfigs,
     prepare_embeddings_qat,
 )
-from sparseml.modifiers.quantization.utils.quantization_scheme import QuantizationScheme
+from sparseml.modifiers.quantization_legacy.utils.quantization_scheme import (
+    QuantizationScheme,
+)
 from sparseml.pytorch.utils import get_layer
 from sparseml.utils.fsdp.context import fix_fsdp_module_name
@@ -83,10 +83,10 @@ def save_pretrained_wrapper(
         # check if we are in the old quantization framework
         if qat_active(model) and not is_model_quantized(model):
             _LOGGER.info(
-                "Compression for models quantized with QuantizationModifer is not "
-                "supported. Save will be run without compression and no sparsity "
-                "statistics will be calculated. To save a quantized model in a "
-                "compressed state please use vLLMQuantizationModifier instead."
+                "Compression for models quantized with LegacyQuantizationModifer "
+                "is not supported. Save will be run without compression and no "
+                "sparsity statistics will be calculated. To save a quantized model "
+                "in a compressed state please use QuantizationModifier instead."
             )

         original_save_pretrained.__get__(model, model_class)(
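That branch is the compatibility seam between the two frameworks: a model quantized with the legacy modifier falls through to an uncompressed save. A condensed sketch of the control flow, assuming qat_active and is_model_quantized behave as their names suggest (both appear in the hunk above; save_directory is an assumed parameter name):

# Condensed sketch of save_pretrained_wrapper's guard; call arguments are assumed.
if qat_active(model) and not is_model_quantized(model):
    # legacy QAT graph detected: log the warning, then save without compression
    original_save_pretrained.__get__(model, model_class)(save_directory)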
@@ -25,8 +25,12 @@
 from torch import nn
 from transformers.models.bert.modeling_bert import BertSelfAttention

-from sparseml.modifiers.quantization.modification.modification_objects import QATMatMul
-from sparseml.modifiers.quantization.modification.registry import ModificationRegistry
+from sparseml.modifiers.quantization_legacy.modification.modification_objects import (
+    QATMatMul,
+)
+from sparseml.modifiers.quantization_legacy.modification.registry import (
+    ModificationRegistry,
+)
 from sparseml.pytorch.utils.helpers import swap_modules
 from sparseml.transformers.sparsification.modification.base import (
     check_transformers_version,
@@ -27,8 +27,12 @@
     MultiHeadSelfAttention,
 )

-from sparseml.modifiers.quantization.modification.modification_objects import QATMatMul
-from sparseml.modifiers.quantization.modification.registry import ModificationRegistry
+from sparseml.modifiers.quantization_legacy.modification.modification_objects import (
+    QATMatMul,
+)
+from sparseml.modifiers.quantization_legacy.modification.registry import (
+    ModificationRegistry,
+)
 from sparseml.pytorch.utils.helpers import swap_modules
 from sparseml.transformers.sparsification.modification.base import (
     check_transformers_version,
@@ -32,11 +32,13 @@
     repeat_kv,
 )

-from sparseml.modifiers.quantization.modification.modification_objects import (
+from sparseml.modifiers.quantization_legacy.modification.modification_objects import (
     QuantizableIdentity,
     QuantizableMatMul,
 )
-from sparseml.modifiers.quantization.modification.registry import ModificationRegistry
+from sparseml.modifiers.quantization_legacy.modification.registry import (
+    ModificationRegistry,
+)
 from sparseml.pytorch.utils.helpers import swap_modules
 from sparseml.transformers.sparsification.modification.base import (
     check_transformers_version,
@@ -32,11 +32,13 @@
     repeat_kv,
 )

-from sparseml.modifiers.quantization.modification.modification_objects import (
+from sparseml.modifiers.quantization_legacy.modification.modification_objects import (
     QuantizableIdentity,
     QuantizableMatMul,
 )
-from sparseml.modifiers.quantization.modification.registry import ModificationRegistry
+from sparseml.modifiers.quantization_legacy.modification.registry import (
+    ModificationRegistry,
+)
 from sparseml.pytorch.utils.helpers import swap_modules
 from sparseml.transformers.sparsification.modification.base import (
     check_transformers_version,
@@ -20,8 +20,12 @@
 from torch import nn
 from transformers.models.mobilebert.modeling_mobilebert import MobileBertEmbeddings

-from sparseml.modifiers.quantization.modification.modification_objects import QATLinear
-from sparseml.modifiers.quantization.modification.registry import ModificationRegistry
+from sparseml.modifiers.quantization_legacy.modification.modification_objects import (
+    QATLinear,
+)
+from sparseml.modifiers.quantization_legacy.modification.registry import (
+    ModificationRegistry,
+)
 from sparseml.pytorch.utils.helpers import swap_modules
 from sparseml.transformers.sparsification.modification.base import (
     check_transformers_version,
@@ -23,11 +23,13 @@
 from torch import nn
 from transformers.models.opt.modeling_opt import OPTAttention, OptFlashAttention2

-from sparseml.modifiers.quantization.modification.modification_objects import (
+from sparseml.modifiers.quantization_legacy.modification.modification_objects import (
     QuantizableBatchMatmul,
     QuantizableIdentity,
 )
-from sparseml.modifiers.quantization.modification.registry import ModificationRegistry
+from sparseml.modifiers.quantization_legacy.modification.registry import (
+    ModificationRegistry,
+)
 from sparseml.pytorch.utils.helpers import swap_modules
 from sparseml.transformers.sparsification.modification.base import (
     check_transformers_version,
2 changes: 1 addition & 1 deletion src/sparseml/transformers/sparsification/sparse_model.py
@@ -31,7 +31,7 @@
 from transformers.file_utils import WEIGHTS_NAME

 from compressed_tensors.compressors import ModelCompressor
-from sparseml.modifiers.quantization.modification import modify_model
+from sparseml.modifiers.quantization_legacy.modification import modify_model
 from sparseml.pytorch.model_load.helpers import (
     apply_recipe_structure_to_model,
     log_model_load,
@@ -17,8 +17,10 @@

 import pytest

-from sparseml.modifiers.quantization.modification import modify_model
-from sparseml.modifiers.quantization.modification.registry import ModificationRegistry
+from sparseml.modifiers.quantization_legacy.modification import modify_model
+from sparseml.modifiers.quantization_legacy.modification.registry import (
+    ModificationRegistry,
+)
 from sparsezoo.utils.registry import _ALIAS_REGISTRY, _REGISTRY, standardize_lookup_name
10 changes: 5 additions & 5 deletions tests/sparseml/modifiers/quantization/test_base.py
@@ -19,7 +19,7 @@
 from sparseml.core.event import Event
 from sparseml.core.factory import ModifierFactory
 from sparseml.core.framework import Framework
-from sparseml.modifiers.quantization import QuantizationModifier
+from sparseml.modifiers.quantization_legacy import LegacyQuantizationModifier
 from tests.sparseml.modifiers.conf import setup_modifier_factory
@@ -31,14 +31,14 @@ def setUp(self):

     def test_quantization_registered(self):
         quant_obj = ModifierFactory.create(
-            type_="QuantizationModifier",
+            type_="LegacyQuantizationModifier",
             framework=Framework.general,
             allow_experimental=False,
             allow_registered=True,
             **self.kwargs,
         )

-        self.assertIsInstance(quant_obj, QuantizationModifier)
+        self.assertIsInstance(quant_obj, LegacyQuantizationModifier)


 @pytest.mark.unit
@@ -52,7 +52,7 @@ def setUp(self):

     def test_end_epochs(self):
         disable_quant_epoch, freeze_bn_epoch = None, None
-        obj_modifier = QuantizationModifier(
+        obj_modifier = LegacyQuantizationModifier(
             start=self.start,
             scheme=self.scheme,
             disable_quantization_observer_epoch=disable_quant_epoch,
@@ -68,7 +68,7 @@ def test_end_epochs(self):
         assert not obj_modifier.check_should_freeze_bn_stats(event)

         disable_quant_epoch, freeze_bn_epoch = 3.5, 5.0
-        obj_modifier = QuantizationModifier(
+        obj_modifier = LegacyQuantizationModifier(
             start=self.start,
             scheme=self.scheme,
             disable_quantization_observer_epoch=disable_quant_epoch,
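A counterpart registration check for the renamed new-framework modifier would mirror the legacy test above; this sketch assumes QuantizationModifier is registered under its new name (the gptq hunk keys its generated config on that name):

# Hypothetical mirror of test_quantization_registered for the renamed
# new-framework class; assumes it is registered as "QuantizationModifier".
quant_obj = ModifierFactory.create(
    type_="QuantizationModifier",
    framework=Framework.general,
    allow_experimental=False,
    allow_registered=True,
    **self.kwargs,
)
assert isinstance(quant_obj, QuantizationModifier)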