Dynamic Quantization (#15)
* [WIP] Dynamic Quantization

* update imports post rename

* update dynamic bool

* move dynamic control to Quant Args

* Apply suggestions from code review

* docstring and test
bfineran authored Apr 25, 2024
1 parent dd2bd7f commit d707c5b
Showing 6 changed files with 164 additions and 27 deletions.
30 changes: 18 additions & 12 deletions src/compressed_tensors/quantization/lifecycle/forward.py
@@ -111,7 +111,7 @@ def wrapped_forward(self, *args, **kwargs):
 
 
 def _maybe_calibrate_or_quantize(
-    module: Module, value: Module, base_name: str, args: "QuantizationArgs"
+    module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs"
 ) -> torch.Tensor:
     # only run quantized for the included stages
     if module.quantization_status not in {
@@ -120,17 +120,23 @@ def _maybe_calibrate_or_quantize(
     }:
         return value
 
-    device = next(module.parameters()).device
-    scale = getattr(module, f"{base_name}_scale")
-    zero_point = getattr(module, f"{base_name}_zero_point")
-
-    if module.quantization_status == QuantizationStatus.CALIBRATION:
-        # get observer and get new quant params from observation
-        observer = getattr(module, f"{base_name}_observer")
-        updated_scale, updated_zero_point = observer(value)
-
-        # update scale and zero point
-        scale.data = updated_scale.to(device)
-        zero_point.data = updated_zero_point.to(device)
+    if args.dynamic:
+        # dynamic quantization - get scale and zero point directly from observer
+        observer = getattr(module, f"{base_name}_observer")
+        scale, zero_point = observer(value)
+    else:
+        # static quantization - get previous scale and zero point from layer
+        scale = getattr(module, f"{base_name}_scale")
+        zero_point = getattr(module, f"{base_name}_zero_point")
+
+        if module.quantization_status == QuantizationStatus.CALIBRATION:
+            # calibration mode - get new quant params from observer
+            observer = getattr(module, f"{base_name}_observer")
+            updated_scale, updated_zero_point = observer(value)
+
+            # update scale and zero point
+            device = next(module.parameters()).device
+            scale.data = updated_scale.to(device)
+            zero_point.data = updated_zero_point.to(device)
 
     return fake_quantize(value, scale, zero_point, args)
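
For context on the final call above, the sketch below shows the standard quantize-dequantize round trip that a tensor-strategy fake-quantize performs with the scale and zero point produced either dynamically or statically. It illustrates the math only, not the library's fake_quantize implementation; the helper name and the symmetric 8-bit defaults are assumptions.

    import torch

    def fake_quantize_sketch(value: torch.Tensor, scale, zero_point, num_bits: int = 8):
        # quantize: map the float value onto the integer grid defined by scale/zero point
        q_min, q_max = -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1
        q = torch.clamp(torch.round(value / scale) + zero_point, q_min, q_max)
        # dequantize: return to the float domain so downstream ops see the quantization error
        return (q - zero_point) * scale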
18 changes: 9 additions & 9 deletions src/compressed_tensors/quantization/lifecycle/frozen.py
@@ -30,17 +30,17 @@ def freeze_module_quantization(module: Module):
     :param module: module to freeze quantization for
     """
-    if not getattr(module, "quantization_scheme", None):
+    scheme = getattr(module, "quantization_scheme", None)
+    if not scheme:
         # no quantization scheme nothing to do
         return
 
-    # delete observers from module
-    observer_names = []
-    for submodule_name, _ in module.named_modules():
-        if "." not in submodule_name and submodule_name.endswith("_observer"):
-            # delete any observers that belong directly to this module
-            observer_names.append(submodule_name)
-
-    for observer_name in observer_names:
-        delattr(module, observer_name)
+    # delete observers from module if not dynamic
+    if scheme.input_activations and not scheme.input_activations.dynamic:
+        delattr(module, "input_observer")
+    if scheme.weights and not scheme.weights.dynamic:
+        delattr(module, "weight_observer")
+    if scheme.output_activations and not scheme.output_activations.dynamic:
+        delattr(module, "output_observer")
 
     module.quantization_status = QuantizationStatus.FROZEN
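
A short usage sketch of the new freeze behavior, mirroring the expectations in the test added below. It assumes `model` has already been prepared with apply_quantization_config using static weights and dynamic input activations; treat it as illustrative rather than canonical.

    from compressed_tensors.quantization.lifecycle import freeze_module_quantization

    model.apply(freeze_module_quantization)

    for module in model.modules():
        scheme = getattr(module, "quantization_scheme", None)
        if scheme is None:
            continue
        # static weight observers are deleted once calibration is frozen
        assert not hasattr(module, "weight_observer")
        # dynamic input observers survive the freeze - they are still needed at inference
        if scheme.input_activations is not None and scheme.input_activations.dynamic:
            assert hasattr(module, "input_observer")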
11 changes: 7 additions & 4 deletions src/compressed_tensors/quantization/lifecycle/initialize.py
@@ -80,6 +80,13 @@ def initialize_module_for_quantization(
 def _initialize_scale_zero_point_observer(
     module: Module, base_name: str, quantization_args: QuantizationArgs
 ):
+    # initialize observer module and attach as submodule
+    observer = quantization_args.get_observer()
+    module.register_module(f"{base_name}_observer", observer)
+
+    if quantization_args.dynamic:
+        return  # no need to register a scale and zero point for a dynamic observer
+
     device = next(module.parameters()).device
 
     # initializes empty scale and zero point parameters for the module
@@ -90,7 +97,3 @@ def _initialize_scale_zero_point_observer(
         torch.empty(0, device=device, dtype=int), requires_grad=False
     )
     module.register_parameter(f"{base_name}_zero_point", init_zero_point)
-
-    # initialize observer module and attach as submodule
-    observer = quantization_args.get_observer()
-    module.register_module(f"{base_name}_observer", observer)
4 changes: 2 additions & 2 deletions src/compressed_tensors/quantization/observers/memoryless.py
@@ -23,10 +23,10 @@
 __all__ = ["MemorylessObserver"]
 
 
-@Observer.register("memoryless")
+@Observer.register("memoryless", alias=["dynamic"])
 class MemorylessObserver(Observer):
     """
-    Implements a dynamic quantization observer that sets the scale and
+    Implements a quantization observer that sets the scale and
     zero point based on the latest observed value without tracking state
     """
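
To make the new "dynamic" alias concrete, here is a simplified sketch of what a stateless min-max observation computes on each call. It illustrates the idea rather than the library's MemorylessObserver code, and the symmetric 8-bit math is an assumption.

    import torch

    def memoryless_minmax(value: torch.Tensor, num_bits: int = 8):
        # compute scale and zero point from this tensor only; nothing is cached,
        # so every forward pass observes a fresh quantization range
        min_val, max_val = value.min(), value.max()
        q_max = 2 ** (num_bits - 1) - 1
        scale = torch.max(max_val.abs(), min_val.abs()) / q_max  # symmetric range
        zero_point = torch.zeros_like(scale, dtype=torch.int64)
        return scale, zero_point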
11 changes: 11 additions & 0 deletions src/compressed_tensors/quantization/quant_args.py
@@ -53,6 +53,11 @@ class QuantizationArgs(BaseModel):
     :param group_size: group length to use for the group strategy
     :param block_structure: 2d block structure to use for the block strategy, must be
         of the format "2x4", "8x16", etc.
+    :param dynamic: set True to perform dynamic quantization - values will not be
+        calibrated during the calibration phase; instead, new quantization ranges
+        are observed for every sample at inference time. Defaults to False for
+        static quantization. Note that enabling dynamic quantization changes the
+        default observer to a memoryless one
     """
 
     num_bits: int = 8
@@ -61,6 +66,7 @@ class QuantizationArgs(BaseModel):
     strategy: QuantizationStrategy = QuantizationStrategy.TENSOR
     group_size: Optional[int] = None
     block_structure: Optional[str] = None
+    dynamic: bool = False
     observer: str = Field(
         default="minmax",
         description=(
@@ -82,4 +88,9 @@ def get_observer(self):
         """
         from compressed_tensors.quantization.observers.base import Observer
 
+        if self.observer == "minmax" and self.dynamic:
+            # override the default observer for dynamic quantization; minmax keeps
+            # state across samples, which is never wanted for dynamic
+            self.observer = "memoryless"
+
         return Observer.load_from_registry(self.observer, quantization_args=self)
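
A brief usage sketch of the new dynamic flag, based on the fields shown in this diff (illustrative only; assumes the observer registry has been populated by importing the observers package as usual):

    from compressed_tensors.quantization.quant_args import QuantizationArgs

    # static (default): the "minmax" observer is kept and scale/zero point
    # parameters are registered at initialization time
    static_args = QuantizationArgs(num_bits=8, type="int", symmetric=True)

    # dynamic: no scale/zero point parameters are registered, and get_observer()
    # swaps the default "minmax" observer for the stateless "memoryless" one
    dynamic_args = QuantizationArgs(num_bits=8, type="int", symmetric=True, dynamic=True)
    dynamic_args.get_observer()
    assert dynamic_args.observer == "memoryless"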
117 changes: 117 additions & 0 deletions tests/quantization/lifecycle/test_dynamic_lifecycle.py
@@ -0,0 +1,117 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
from compressed_tensors.quantization.lifecycle import (
    apply_quantization_config,
    freeze_module_quantization,
)
from compressed_tensors.quantization.quant_config import QuantizationConfig
from transformers import AutoModelForCausalLM


def test_apply_tinyllama_dynamic_activations():
    quant_config = get_sample_dynamic_tinyllama_quant_config()
    model = get_tinyllama_model()

    # check that model is not already quantized
    for module in model.modules():
        _test_layer_dynamic_quantization_status(module, inputs=False, weights=False)

    # apply quant config to model
    apply_quantization_config(model, quant_config)

    # test linears are dynamically quantized for calibration
    _test_linears_dynamic_quantization_status(model, quant_config, frozen=False)
    # verify forward works w/ dynamic during calibration
    model(torch.zeros((1, 1), dtype=int), torch.zeros((1, 1), dtype=int))

    # freeze and test that only weight observers are deleted
    model.apply(freeze_module_quantization)
    _test_linears_dynamic_quantization_status(model, quant_config, frozen=True)
    # verify forward works w/ dynamic after freeze
    model(torch.zeros((1, 1), dtype=int), torch.zeros((1, 1), dtype=int))


def _test_linears_dynamic_quantization_status(model, quant_config, frozen: bool):
    # check for correct application of quant config
    num_linears = 0
    for name, module in model.named_modules():
        if name in quant_config.ignore:
            continue
        module_type = module.__class__.__name__
        if module_type == "Linear":
            num_linears += 1
            _test_layer_dynamic_quantization_status(
                module, inputs=True, weights=True, frozen=frozen
            )

    # sanity check correct number of layers targeted
    assert num_linears == 154  # 155 Linear layers - 1 that gets ignored


def _test_layer_dynamic_quantization_status(
    module, inputs: bool, weights: bool, frozen: bool = False
):
    # check if quantization is applied at all (true if inputs or weights targeted)
    quantized = inputs or weights
    assert hasattr(module, "quantization_scheme") == quantized
    assert hasattr(module, "quantization_status") == quantized

    # check inputs always have an observer if quantized but never scale/zp
    assert not hasattr(module, "input_scale")
    assert not hasattr(module, "input_zero_point")
    assert hasattr(module, "input_observer") == inputs

    # check weights always have scale/zp and observer only if not frozen
    assert hasattr(module, "weight_scale") == weights
    assert hasattr(module, "weight_zero_point") == weights
    assert hasattr(module, "weight_observer") == (weights and not frozen)


def get_tinyllama_model():
    return AutoModelForCausalLM.from_pretrained(
        "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
    )


def get_sample_dynamic_tinyllama_quant_config():
    config_dict = {
        "quant_method": "sparseml",
        "format": "fakequant",
        "quantization_status": "calibration",
        "global_compression_ratio": None,
        "config_groups": {
            "group_1": {
                "weights": {
                    "num_bits": 8,
                    "type": "int",
                    "symmetric": True,
                    "strategy": "tensor",
                    "dynamic": False,
                },
                "input_activations": {
                    "num_bits": 8,
                    "type": "int",
                    "symmetric": True,
                    "strategy": "tensor",
                    "dynamic": True,
                },
                "targets": ["Linear"],
            },
        },
        "ignore": ["LlamaRotaryEmbedding", "model.layers.1.mlp.down_proj"],
    }
    return QuantizationConfig.parse_obj(config_dict)
