convert old modifier to legacy
Sara Adkins committed May 28, 2024
1 parent c530bbf commit 63a369f
Showing 28 changed files with 53 additions and 53 deletions.
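The net effect of this commit: the old QAT-style modifier keeps working under a "Legacy" prefix, while the newer vLLM-oriented modifier takes over the unprefixed name. A sketch of the rename map and its effect on recipes, with recipe bodies abridged from the fixture updates further down in this diff:

# Class and recipe-key renames in this commit (names taken from the hunks below):
#   QuantizationModifier            -> LegacyQuantizationModifier
#   QuantizationModifierPyTorch     -> LegacyQuantizationModifierPyTorch
#   vLLMQuantizationModifier        -> QuantizationModifier
#   vLLMQuantizationModifierPyTorch -> QuantizationModifierPyTorch

# A recipe that targeted the old QAT modifier now uses the Legacy name:
legacy_recipe = """
test_stage:
  quant_modifiers:
    LegacyQuantizationModifier:
      ignore:
        - model.layers.0.mlp.down_proj
        - lm_head
"""

# A recipe that targeted vLLMQuantizationModifier now uses the plain name:
new_recipe = """
test_stage:
  quant_modifiers:
    QuantizationModifier:
      ignore: ["lm_head", "model.layers.0.mlp.down_proj"]
"""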
6 changes: 3 additions & 3 deletions src/sparseml/modifiers/quantization/base.py
@@ -17,17 +17,17 @@
 from sparseml.core import Event, Modifier
 
 
-__all__ = ["QuantizationModifier"]
+__all__ = ["LegacyQuantizationModifier"]
 
 
-class QuantizationModifier(Modifier):
+class LegacyQuantizationModifier(Modifier):
     """
     Enables quantization aware training (QAT) for a given module or its submodules
     After the start epoch, the specified module(s) forward pass will emulate
     quantized execution and the modifier will be enabled until training is completed.
     | Sample yaml:
-    |   QuantizationModifier:
+    |   LegacyQuantizationModifier:
     |       start: 0.0
     |       scheme:
     |           input_activations:
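For reference, a minimal sketch of constructing the renamed modifier directly. The keyword arguments mirror the updated unit tests in tests/sparseml/modifiers/quantization/test_base.py below; the `scheme` value is a placeholder, since the full scheme shape is truncated in this view:

from sparseml.modifiers.quantization import LegacyQuantizationModifier

# kwargs mirror test_end_epochs below; `scheme` is a placeholder value,
# not a prescribed configuration.
obj_modifier = LegacyQuantizationModifier(
    start=0.0,
    scheme=None,
    disable_quantization_observer_epoch=3.5,
)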
2 changes: 1 addition & 1 deletion src/sparseml/modifiers/quantization/gptq/base.py
@@ -194,7 +194,7 @@ def _build_quant_modifier(self, framework):
         )
         quant_args["config_groups"] = {"config_group_0": default_quant_scheme}
         _LOGGER.info(f"Building quantization modifier with args: {quant_args}")
-        vllm_quant_config = {"vLLMQuantizationModifier": quant_args}
+        vllm_quant_config = {"QuantizationModifier": quant_args}
         self._build_quant_modifier_from_dict(vllm_quant_config, framework)
 
     def compressible_layers(self) -> Dict:
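For context on the hunk above: the dict GPTQ builds is a one-key mapping from the modifier's registered name to its arguments, so this change only swaps the key. A sketch of its shape, with illustrative values patterned on the GPTQ test assertions further down (the exact weight fields are assumptions):

# Illustrative shape only; real values come from the GPTQ modifier's config.
quant_args = {
    "config_groups": {
        "config_group_0": {
            "weights": {"num_bits": 8},  # assumed fields, see test assertions below
            "input_activations": {"num_bits": 8, "symmetric": True},
        }
    },
}
# After this commit the wrapper dict is keyed by the unprefixed name:
vllm_quant_config = {"QuantizationModifier": quant_args}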
4 changes: 2 additions & 2 deletions src/sparseml/modifiers/quantization/pytorch.py
@@ -19,7 +19,7 @@
 from torch.nn import Module
 
 from sparseml.core import Event, EventType, State
-from sparseml.modifiers.quantization.base import QuantizationModifier
+from sparseml.modifiers.quantization.base import LegacyQuantizationModifier
 from sparseml.modifiers.quantization.modification import modify_model
 from sparseml.modifiers.quantization.utils.helpers import (
     configure_module_bn_wrappers,
@@ -42,7 +42,7 @@
 _LOGGER = logging.getLogger(__name__)
 
 
-class QuantizationModifierPyTorch(QuantizationModifier):
+class LegacyQuantizationModifierPyTorch(LegacyQuantizationModifier):
     """
     Pytorch-specific implementation of quantization modifier
4 changes: 2 additions & 2 deletions src/sparseml/modifiers/quantization_vllm/base.py
@@ -24,10 +24,10 @@
 from sparseml.core import Event, Modifier
 
 
-__all__ = ["vLLMQuantizationModifier"]
+__all__ = ["QuantizationModifier"]
 
 
-class vLLMQuantizationModifier(Modifier):
+class QuantizationModifier(Modifier):
     """
     Enables post training quantization (PTQ) and quantization aware training (QAT) for a
     given module or its submodules. After calibration (PTQ) or the start epoch (QAT),
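A recipe exercising the modifier under its new name, assembled from the recipe fixtures updated later in this diff. The fields under group_0 are illustrative, since the fixture bodies are truncated in this view:

recipe = """
test_stage:
  quant_modifiers:
    QuantizationModifier:
      ignore: ["lm_head", "model.layers.0.mlp.down_proj"]
      config_groups:
        group_0:
          weights:
            num_bits: 8  # illustrative; fixture truncated in this view
"""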
6 changes: 3 additions & 3 deletions src/sparseml/modifiers/quantization_vllm/pytorch.py
@@ -23,16 +23,16 @@
     set_module_for_calibration,
 )
 from sparseml.core import Event, EventType, State
-from sparseml.modifiers.quantization_vllm.base import vLLMQuantizationModifier
+from sparseml.modifiers.quantization_vllm.base import QuantizationModifier
 from sparseml.modifiers.utils.pytorch_helpers import run_calibration_forward
 
 
 _LOGGER = logging.getLogger(__name__)
 
 
-class vLLMQuantizationModifierPyTorch(vLLMQuantizationModifier):
+class QuantizationModifierPyTorch(QuantizationModifier):
     """
-    PyTorch specific implementation of vLLMQuantizationModifier
+    PyTorch specific implementation of QuantizationModifier
     Enables post training quantization (PTQ) and quantization aware training (QAT) for a
     given module or its submodules. After calibration (PTQ) or the start epoch (QAT),
@@ -83,10 +83,10 @@ def save_pretrained_wrapper(
         # check if we are in the old quantization framework
         if qat_active(model) and not is_model_quantized(model):
             _LOGGER.info(
-                "Compression for models quantized with QuantizationModifer is not "
-                "supported. Save will be run without compression and no sparsity "
-                "statistics will be calculated. To save a quantized model in a "
-                "compressed state please use vLLMQuantizationModifier instead."
+                "Compression for models quantized with LegacyQuantizationModifer "
+                "is not supported. Save will be run without compression and no "
+                "sparsity statistics will be calculated. To save a quantized model "
+                "in a compressed state please use QuantizationModifier instead."
             )
 
         original_save_pretrained.__get__(model, model_class)(
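To make the branch above explicit: `qat_active(model)` flags the legacy QAT flow, while `is_model_quantized(model)` flags the new framework, and only the latter can be saved compressed. A standalone restatement of that guard (function name and log text here are illustrative, not from the source):

import logging

_LOGGER = logging.getLogger(__name__)


def can_save_compressed(qat_is_active: bool, model_is_quantized: bool) -> bool:
    # Mirrors the guard in save_pretrained_wrapper above: a model quantized
    # through the legacy QAT flow reports qat_active but not
    # is_model_quantized, and such models fall back to an uncompressed save.
    if qat_is_active and not model_is_quantized:
        _LOGGER.info("Legacy-quantized model: saving without compression.")
        return False
    return True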
10 changes: 5 additions & 5 deletions tests/sparseml/modifiers/quantization/test_base.py
@@ -19,7 +19,7 @@
 from sparseml.core.event import Event
 from sparseml.core.factory import ModifierFactory
 from sparseml.core.framework import Framework
-from sparseml.modifiers.quantization import QuantizationModifier
+from sparseml.modifiers.quantization import LegacyQuantizationModifier
 from tests.sparseml.modifiers.conf import setup_modifier_factory
 
 
@@ -31,14 +31,14 @@ def setUp(self):
 
     def test_quantization_registered(self):
         quant_obj = ModifierFactory.create(
-            type_="QuantizationModifier",
+            type_="LegacyQuantizationModifier",
             framework=Framework.general,
             allow_experimental=False,
             allow_registered=True,
             **self.kwargs,
         )
 
-        self.assertIsInstance(quant_obj, QuantizationModifier)
+        self.assertIsInstance(quant_obj, LegacyQuantizationModifier)
 
 
 @pytest.mark.unit
@@ -52,7 +52,7 @@ def setUp(self):
 
     def test_end_epochs(self):
         disable_quant_epoch, freeze_bn_epoch = None, None
-        obj_modifier = QuantizationModifier(
+        obj_modifier = LegacyQuantizationModifier(
             start=self.start,
             scheme=self.scheme,
             disable_quantization_observer_epoch=disable_quant_epoch,
@@ -68,7 +68,7 @@
         assert not obj_modifier.check_should_freeze_bn_stats(event)
 
         disable_quant_epoch, freeze_bn_epoch = 3.5, 5.0
-        obj_modifier = QuantizationModifier(
+        obj_modifier = LegacyQuantizationModifier(
             start=self.start,
             scheme=self.scheme,
             disable_quantization_observer_epoch=disable_quant_epoch,
@@ -21,8 +21,8 @@
 from sparseml.core.model import ModifiableModel
 from sparseml.modifiers.obcq.pytorch import SparseGPTModifierPyTorch
 from sparseml.modifiers.quantization.gptq.pytorch import GPTQModifierPyTorch
-from sparseml.modifiers.quantization.pytorch import QuantizationModifierPyTorch
-from sparseml.modifiers.quantization_vllm.base import vLLMQuantizationModifier
+from sparseml.modifiers.quantization.pytorch import LegacyQuantizationModifierPyTorch
+from sparseml.modifiers.quantization_vllm.base import QuantizationModifier
 from tests.sparseml.modifiers.conf import LifecyleTestingHarness, setup_modifier_factory
 from tests.sparseml.pytorch.helpers import LinearNet
 from tests.testing_utils import requires_torch
@@ -92,13 +92,13 @@ def test_create_default_quant_modifier(self):
         testing_harness = LifecyleTestingHarness(model=LinearNet())
         modifier.on_initialize_structure(testing_harness.get_state())
         assert modifier.quantize
-        assert isinstance(modifier.quantization_modifier_, vLLMQuantizationModifier)
+        assert isinstance(modifier.quantization_modifier_, QuantizationModifier)
         default_config_group_name = "config_group_0"
         should_be_default_quant_scheme = modifier.quantization_modifier_.config_groups[
             default_config_group_name
         ]
         self.assertEqual(should_be_default_quant_scheme.input_activations.num_bits, 8)
-        # input activations are symmetric by default in vLLMQuantizationModifier
+        # input activations are symmetric by default in QuantizationModifier
         assert should_be_default_quant_scheme.input_activations.symmetric
 
         self.assertEqual(should_be_default_quant_scheme.weights.num_bits, 8)
@@ -120,7 +120,7 @@ def test_set_quant_if_modifer_already_exists(self):
             ),
         )
 
-        modifier = QuantizationModifierPyTorch(**kwargs)
+        modifier = LegacyQuantizationModifierPyTorch(**kwargs)
         testing_harness = LifecyleTestingHarness(model=model, start=-1)
 
         assert not testing_harness.get_state().model.qat_active()
@@ -159,7 +159,7 @@ def setUp(self):
                 }
             }
         }
-        self.quant_config = {"vLLMQuantizationModifier": self.quant_kwargs}
+        self.quant_config = {"QuantizationModifier": self.quant_kwargs}
 
     def test_set_quant_in_gptq(self):
         kwargs = dict(block_size=128, quantize=self.quant_config)
@@ -170,7 +170,7 @@ def test_set_quant_in_gptq(self):
         testing_harness = LifecyleTestingHarness(model=LinearNet())
         modifier.on_initialize_structure(testing_harness.get_state())
         assert modifier.quantize
-        self.assertIsInstance(modifier.quantization_modifier_, vLLMQuantizationModifier)
+        self.assertIsInstance(modifier.quantization_modifier_, QuantizationModifier)
 
         dict_scheme = dict(modifier.quantization_modifier_.config_groups)
         self._check_config(
10 changes: 5 additions & 5 deletions tests/sparseml/pytorch/modifiers/quantization/test_pytorch.py
@@ -21,7 +21,7 @@
 from sparseml.core.event import Event, EventType
 from sparseml.core.factory import ModifierFactory
 from sparseml.core.framework import Framework
-from sparseml.modifiers.quantization.pytorch import QuantizationModifierPyTorch
+from sparseml.modifiers.quantization.pytorch import LegacyQuantizationModifierPyTorch
 from sparseml.pytorch.sparsification.quantization.quantize import (
     is_qat_helper_module,
     is_quantizable_module,
@@ -45,14 +45,14 @@ def setUp(self):
 
     def test_quantization_registered(self):
         quant_obj = ModifierFactory.create(
-            type_="QuantizationModifier",
+            type_="LegacyQuantizationModifier",
            framework=Framework.pytorch,
             allow_experimental=False,
             allow_registered=True,
             **self.kwargs,
         )
 
-        self.assertIsInstance(quant_obj, QuantizationModifierPyTorch)
+        self.assertIsInstance(quant_obj, LegacyQuantizationModifierPyTorch)
 
 
 @pytest.mark.unit
@@ -71,7 +71,7 @@ def test_quantization_oneshot(self, model_class):
         state = State(framework=Framework.pytorch, start_event=Event())
         state.update(model=model, start=-1)
 
-        modifier = QuantizationModifierPyTorch(**self.kwargs)
+        modifier = LegacyQuantizationModifierPyTorch(**self.kwargs)
 
         modifier.initialize(state)
 
@@ -108,7 +108,7 @@ def setUp(self):
     def test_quantization_training(self, model_class):
         model = model_class()
 
-        modifier = QuantizationModifierPyTorch(**self.kwargs)
+        modifier = LegacyQuantizationModifierPyTorch(**self.kwargs)
 
         testing_harness = LifecyleTestingHarness(model=model)
         modifier.initialize(testing_harness.get_state())
@@ -1,6 +1,6 @@
 test_stage:
   quant_modifiers:
-    vLLMQuantizationModifier:
+    QuantizationModifier:
       ignore: ["lm_head", "model.layers.0.mlp.down_proj"]
      config_groups:
         group_0:
@@ -1,6 +1,6 @@
 test_stage:
   quant_modifiers:
-    vLLMQuantizationModifier:
+    QuantizationModifier:
       ignore: ["lm_head", "model.layers.0.mlp.down_proj"]
       config_groups:
         group_0:
@@ -1,6 +1,6 @@
 test_stage:
   quant_modifiers:
-    vLLMQuantizationModifier:
+    QuantizationModifier:
       ignore: ["lm_head"]
       config_groups:
         group_0:
@@ -1,6 +1,6 @@
 test_stage:
   quant_modifiers:
-    vLLMQuantizationModifier:
+    QuantizationModifier:
       ignore: ["lm_head", "model.layers.0.mlp.down_proj"]
       config_groups:
         group_0:
@@ -1,6 +1,6 @@
 test_stage:
   quant_modifiers:
-    QuantizationModifier:
+    LegacyQuantizationModifier:
       ignore:
         - model.layers.0.mlp.down_proj
         - lm_head
@@ -1,6 +1,6 @@
 test_stage:
   quant_modifiers:
-    QuantizationModifier:
+    LegacyQuantizationModifier:
       ignore:
         - model.layers.0.mlp.down_proj
         - lm_head
@@ -1,6 +1,6 @@
 test_stage:
   quant_modifiers:
-    QuantizationModifier:
+    LegacyQuantizationModifier:
       ignore:
         - model.layers.0.mlp.down_proj
         - lm_head
@@ -1,6 +1,6 @@
 test_stage:
   quant_modifiers:
-    QuantizationModifier:
+    LegacyQuantizationModifier:
       ignore:
         - LlamaRotaryEmbedding
         - LlamaRMSNorm
@@ -5,7 +5,7 @@ dataset: open_platypus
 first_recipe: |
   first_stage:
     quant_modifiers:
-      QuantizationModifier:
+      LegacyQuantizationModifier:
        ignore:
           - LlamaRotaryEmbedding
           - LlamaRMSNorm
@@ -17,7 +17,7 @@ first_recipe: |
 second_recipe: |
   second_stage:
     quant_modifiers:
-      QuantizationModifier:
+      LegacyQuantizationModifier:
         ignore:
           - LlamaRotaryEmbedding
           - LlamaRMSNorm
@@ -5,7 +5,7 @@ dataset: open_platypus
 first_recipe: |
   first_stage:
     quant_modifiers:
-      QuantizationModifier:
+      LegacyQuantizationModifier:
         ignore:
           - LlamaRotaryEmbedding
           - LlamaRMSNorm
@@ -17,7 +17,7 @@ first_recipe: |
 second_recipe: |
   second_stage:
     quant_modifiers:
-      QuantizationModifier:
+      LegacyQuantizationModifier:
         ignore:
           - LlamaRotaryEmbedding
           - LlamaRMSNorm
@@ -6,7 +6,7 @@ test_stage:
         [["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*input_layernorm"],
         [["re:.*gate_proj", "re:.*up_proj"], "re:.*post_attention_layernorm"]
       ]
-    QuantizationModifier:
+    LegacyQuantizationModifier:
       ignore:
         - LlamaRotaryEmbedding
         - LlamaRMSNorm
2 changes: 1 addition & 1 deletion tests/sparseml/transformers/obcq/recipes/quant.yaml
@@ -6,7 +6,7 @@ test_stage:
         [["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*input_layernorm"],
         [["re:.*gate_proj", "re:.*up_proj"], "re:.*post_attention_layernorm"]
       ]
-    QuantizationModifier:
+    LegacyQuantizationModifier:
       ignore:
         - LlamaRotaryEmbedding
         - LlamaRMSNorm
@@ -6,7 +6,7 @@ test_stage:
         [["re:.*q_proj", "re:.*k_proj", "re:.*v_proj"], "re:.*input_layernorm"],
         [["re:.*gate_proj", "re:.*up_proj"], "re:.*post_attention_layernorm"]
       ]
-    QuantizationModifier:
+    LegacyQuantizationModifier:
       ignore:
         - LlamaRotaryEmbedding
         - LlamaRMSNorm
@@ -37,7 +37,7 @@ def setUp(self):
         self.recipe = """
         first_stage:
             quant_modifiers:
-                QuantizationModifier:
+                LegacyQuantizationModifier:
                     ignore:
                         - Embedding
                    scheme_overrides:
@@ -23,7 +23,7 @@
 def llama_recipe():
     return """test_stage:
     quant_modifiers:
-        QuantizationModifier:
+        LegacyQuantizationModifier:
            ignore:
                - MatMulRightInput_QK
                - MatMulLeftInput_QK
@@ -23,7 +23,7 @@
 def mistral_recipe():
     return """test_stage:
     quant_modifiers:
-        QuantizationModifier:
+        LegacyQuantizationModifier:
            ignore:
                - MatMulRightInput_QK
                - MatMulLeftInput_QK
(remaining changed files not shown in this view)