redo folder structure
Sara Adkins committed May 28, 2024
1 parent 63a369f commit d83fb2e
Showing 35 changed files with 470 additions and 470 deletions.
2 changes: 1 addition & 1 deletion src/sparseml/modifiers/__init__.py
@@ -18,5 +18,5 @@
from .logarithmic_equalization import *
from .obcq import *
from .pruning import *
from .quantization import *
from .quantization_legacy import *
from .smoothquant import *
107 changes: 26 additions & 81 deletions src/sparseml/modifiers/quantization/base.py
@@ -12,106 +12,51 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Optional
from typing import Dict, List, Optional

from pydantic import Field

from compressed_tensors.quantization import (
QuantizationConfig,
QuantizationScheme,
QuantizationStatus,
)
from sparseml.core import Event, Modifier


__all__ = ["LegacyQuantizationModifier"]
__all__ = ["QuantizationModifier"]


class LegacyQuantizationModifier(Modifier):
class QuantizationModifier(Modifier):
"""
Enables quantization aware training (QAT) for a given module or its submodules.
After the start epoch, the specified module(s) forward pass will emulate
quantized execution and the modifier will be enabled until training is completed.
| Sample yaml:
| LegacyQuantizationModifier:
| start: 0.0
| scheme:
| input_activations:
| num_bits: 8
| symmetric: False
| weights:
| num_bits: 8
| symmetric: True
| scheme_overrides:
| feature_extractor: "default"
| classifier:
| input_activations:
| num_bits: 8
| symmetric: False
| weights: null
| Conv2d:
| input_activations:
| num_bits: 8
| symmetric: True
| ignore: ["ReLU", "input"]
| disable_quantization_observer_epoch: 2.0
| freeze_bn_stats_epoch: 3.0
| model_fuse_fn_name: 'fuse_module'
| strict: True
:param ignore: optional list of module class names or submodule names
to not quantize. Default is None
Enables post training quantization (PTQ) and quantization aware training (QAT) for a
given module or its submodules. After calibration (PTQ) or the start epoch (QAT),
the specified module(s) forward pass will emulate quantized execution and the
modifier will be enabled until training is completed.
:param config_groups: dictionary specifying quantization schemes to apply to target
modules. Modules not matching a scheme target will NOT be quantized.
:param ignore: optional list of module class names or submodule names to not
quantize even if they match a target in config_groups. Defaults to empty list.
:param disable_quantization_observer_epoch: Epoch to disable updates to the module
quantization observers. At this point, quantized weights and zero points will
not be updated. Leave None to not disable observers during QAT. Default is None
:param freeze_bn_stats_epoch: Epoch to stop the tracking of batch norm stats. Leave
None to not stop tracking batch norm stats during QAT. Default is None
:param model_fuse_fn_name: Name of model function to fuse the model in place prior
to performing QAT. Set as None or 'no_fuse' to skip module fusing. Set as
'conv_bn_relus' to use `sparseml.pytorch.utils.fuse_module_conv_bn_relus`.
Default is None
:param model_fuse_fn_kwargs: dictionary of keyword argument values to be passed
to the model fusing function
:param num_calibration_steps: Number of steps to run post training calibration for.
When None, the entire calibration_dataloader is used
:param strict: if True, will raise an error if any module types or submodules in
scheme_overrides or ignore are not found in a given module. Default True
"""

ignore: Optional[List[str]] = None
config_groups: Dict[str, QuantizationScheme]
ignore: List[str] = Field(default_factory=list)
disable_quantization_observer_epoch: Optional[float] = None
freeze_bn_stats_epoch: Optional[float] = None
model_fuse_fn_name: Optional[str] = None
model_fuse_fn_kwargs: Optional[Dict[str, Any]] = None
num_calibration_steps: Optional[int] = None
post_oneshot_calibration: Optional[bool] = False
strict: bool = True

def __init__(self, **kwargs):
super().__init__(**kwargs)
if self.model_fuse_fn_kwargs is None:
self.model_fuse_fn_kwargs = {}
if self.ignore is None:
self.ignore = []

def calculate_freeze_bn_stats_epoch(self) -> float:
"""
Get the epoch at which we want to stop updating batch normalization stats
:return: freeze_bn_stats_epoch if set, else -1
"""
return (
self.freeze_bn_stats_epoch if self.freeze_bn_stats_epoch is not None else -1
)
def create_init_config(self) -> QuantizationConfig:
return QuantizationConfig(
config_groups=self.config_groups,
quantization_status=QuantizationStatus.INITIALIZED,
ignore=self.ignore,
)

def check_should_freeze_bn_stats(self, event: Event) -> bool:
"""
Given the current index, determine if we should freeze batch normalization stats
:param event: Event to get index from
:return: True if stats should be frozen, False otherwise
"""
freeze_epoch = self.calculate_freeze_bn_stats_epoch()
if freeze_epoch == -1:
return False
if event.current_index >= freeze_epoch:
return True
return False

def calculate_disable_observer_epoch(self) -> float:
"""
Get the epoch at which we want to disable the quantization observer
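To make the new interface concrete, here is a minimal sketch of building the rewritten modifier with an explicit config_groups mapping and turning it into an initial QuantizationConfig via create_init_config(). QuantizationScheme is the compressed-tensors class imported above; QuantizationArgs and the exact field names (targets, weights, input_activations, num_bits, symmetric) are assumptions about the compressed-tensors API of this era rather than code from this diff, and the module names are placeholders.

from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

from sparseml.modifiers.quantization.base import QuantizationModifier

# One scheme group: 8-bit symmetric weights and 8-bit asymmetric input
# activations, applied to every Linear module not listed in `ignore`.
config_groups = {
    "group_0": QuantizationScheme(
        targets=["Linear"],
        weights=QuantizationArgs(num_bits=8, symmetric=True),
        input_activations=QuantizationArgs(num_bits=8, symmetric=False),
    )
}

modifier = QuantizationModifier(config_groups=config_groups, ignore=["LayerNorm"])

# create_init_config() packs config_groups and ignore into a QuantizationConfig
# with status INITIALIZED, as shown in base.py above.
init_config = modifier.create_init_config()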
174 changes: 40 additions & 134 deletions src/sparseml/modifiers/quantization/pytorch.py
@@ -13,77 +13,52 @@
# limitations under the License.

import logging
from typing import Any, Dict, Optional
from typing import Any

import torch
from torch.nn import Module

from sparseml.core import Event, EventType, State
from sparseml.modifiers.quantization.base import LegacyQuantizationModifier
from sparseml.modifiers.quantization.modification import modify_model
from sparseml.modifiers.quantization.utils.helpers import (
configure_module_bn_wrappers,
freeze_bn_stats,
fuse_module_conv_bn_relus,
)
from sparseml.modifiers.quantization.utils.quantization_scheme import (
QuantizationScheme,
QuantizationSchemeLoadable,
)
from sparseml.modifiers.quantization.utils.quantize import (
convert_module_qat_from_schemes,
raise_if_torch_quantization_not_available,
set_quantization_schemes,
)
from compressed_tensors.quantization import (
apply_quantization_config,
freeze_module_quantization,
set_module_for_calibration,
)
from sparseml.core import Event, EventType, State
from sparseml.modifiers.quantization.base import QuantizationModifier
from sparseml.modifiers.utils.pytorch_helpers import run_calibration_forward
from sparseml.utils.fsdp.context import summon_full_params_context


_LOGGER = logging.getLogger(__name__)


class LegacyQuantizationModifierPyTorch(LegacyQuantizationModifier):
class QuantizationModifierPyTorch(QuantizationModifier):
"""
Pytorch-specific implementation of quantization modifier
:param scheme: Default QuantizationScheme to use when enabling quantization
in a module. May also be a dictionary to be loaded into the QuantizationScheme
class. A string alias may also be used, supported aliases:
['default', 'deepsparse', 'tensorrt'].
If None, the default scheme (`QuantizationScheme()`) will be used.
Default is None
:param scheme_overrides: optional mapping of module type names or submodule type
names to quantization schemes to override them with. If a scheme is mapped to
'default', then it will use the scheme set in the modifier scheme property
PyTorch-specific implementation of QuantizationModifier.
Enables post training quantization (PTQ) and quantization aware training (QAT) for a
given module or its submodules. After calibration (PTQ) or the start epoch (QAT),
the specified module(s) forward pass will emulate quantized execution and the
modifier will be enabled until training is completed.
:param config_groups: dictionary specifying quantization schemes to apply to target
modules. Modules not matching a scheme target will NOT be quantized.
:param ignore: optional list of module class names or submodule names to not
quantize even if they match a target in config_groups. Defaults to empty list.
:param disable_quantization_observer_epoch: Epoch to disable updates to the module
quantization observers. At this point, quantized weights and zero points will
not be updated. Leave None to not disable observers during QAT. Default is None
:param num_calibration_steps: Number of steps to run post training calibration for.
When None, the entire calibration_dataloader is used
"""

scheme: Optional[QuantizationSchemeLoadable] = None
scheme_overrides: Optional[Dict[str, QuantizationSchemeLoadable]] = None
calibration_dataloader_: Any = None
calibration_function_: Any = None
qat_enabled_: bool = False
quantization_observer_disabled_: bool = False
bn_stats_frozen_: bool = False

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.scheme = QuantizationScheme.load(self.scheme)
self.scheme_overrides = _load_quantization_schemes_dict(
self.scheme_overrides, self.scheme
)

def on_initialize_structure(self, state: State, **kwargs):
module = state.model.model
# before the structure is modified to support quantization,
# we need to potentially modify the model architecture
module = modify_model(module)
self._enable_module_qat(module)
state.model.model.apply(torch.quantization.disable_observer)
self._apply_modifier_to_model(module)
module.apply(freeze_module_quantization)

def on_initialize(self, state: State, **kwargs) -> bool:
raise_if_torch_quantization_not_available()
module = state.model.model
module = modify_model(module)
if self.end and self.end != -1:
raise ValueError(
"end_epoch is disabled for QuantizationModifier and can only be set to"
@@ -93,85 +93,39 @@ def on_initialize(self, state: State, **kwargs) -> bool:
self.calibration_dataloader_ = state.data.calib
module = state.model.model

# initialize quantization in appropriate modules
self._apply_modifier_to_model(module)

if self.calculate_start() == -1: # one-shot
self._enable_module_qat(module)
module.apply(set_module_for_calibration)
self._calibrate_if_possible(module)
self._disable_quantization_observer(module)
module.apply(freeze_module_quantization)

return True

def on_finalize(self, state: State, **kwargs) -> bool:
if self.post_oneshot_calibration:
state.model.model.apply(torch.quantization.enable_observer)
self._calibrate_if_possible(state.model.model)
self._disable_quantization_observer(state.model.model)
return True

def on_start(self, state: State, event: Event, **kwargs):
if not self.qat_enabled_:
self._enable_module_qat(state.model.model)
module = state.model.model
module.apply(set_module_for_calibration)

def on_update(self, state: State, event: Event, **kwargs):
if event.type_ == EventType.BATCH_START:
if self.check_should_freeze_bn_stats(event):
self._freeze_bn_stats(state.model.model)
if self.check_should_disable_observer(event):
self._disable_quantization_observer(state.model.model)
module = state.model.model
module.apply(freeze_module_quantization)

def on_end(self, state: State, event: Event, **kwargs):
self._disable_quantization_observer(state.model.model)
module = state.model.model
module.apply(freeze_module_quantization)

def on_event(self, state: State, event: Event, **kwargs):
pass

def _freeze_bn_stats(self, model: Module):
model.apply(freeze_bn_stats)
self.bn_stats_frozen_ = True

def _disable_quantization_observer(self, model: Module):
model.apply(torch.quantization.disable_observer)
self.quantization_observer_disabled_ = True

def _enable_module_qat(self, module: Module):
module.apply(torch.quantization.enable_observer)

if not self.qat_enabled_:
with summon_full_params_context(module):
# fuse conv-bn-relu blocks prior to quantization emulation
self._fuse(module)

# add quantization_schemes to target submodules
set_quantization_schemes(
module,
scheme=self.scheme,
scheme_overrides=self.scheme_overrides,
ignore=self.ignore,
strict=self.strict,
)

# fix for freezing batchnorm statistics when not fusing BN with convs.
# pytorch only supports freezing batchnorm statistics for fused modules.
# this fix wraps BN modules adding with a new module class that supports
# methods related to freezing/unfreezing BN statistics.
configure_module_bn_wrappers(module)

# convert target qconfig layers to QAT modules with FakeQuantize
convert_module_qat_from_schemes(module)

self.qat_enabled_ = True

def _fuse(self, module: Module):
if self.model_fuse_fn_name in [None, "conv_bn_relus"]:
self.model_fuse_fn_kwargs["inplace"] = True
fuse_module_conv_bn_relus(module, **self.model_fuse_fn_kwargs)
elif self.model_fuse_fn_name != "no_fuse":
module_fuse_fn = getattr(module, self.model_fuse_fn_name, None)
if module_fuse_fn is None or not callable(module_fuse_fn):
raise ValueError(
"Invalid model_fuse_fn_name. "
"Module has no callable function {}".format(self.model_fuse_fn_name)
)
module_fuse_fn(**self.model_fuse_fn_kwargs)
def _apply_modifier_to_model(self, model: Module):
modifier_as_config = self.create_init_config()
apply_quantization_config(model, modifier_as_config)

def _calibrate_if_possible(self, module: Module):
if self.num_calibration_steps == 0 and self.calibration_dataloader_:
@@ -210,26 +139,3 @@ def _calibrate(self, module: Module):

if module_training:
module.train()
else:
self._disable_quantization_observer(module)


class _QuantizationSchemesDict(dict):
# wrapper class for dict to override the __str__ method for yaml serialization

def __str__(self):
return str({submodule: scheme.dict() for submodule, scheme in self.items()})


def _load_quantization_schemes_dict(
schemes_dict: Optional[Dict[str, QuantizationSchemeLoadable]],
default_scheme: QuantizationScheme,
) -> Dict[str, QuantizationScheme]:
if schemes_dict is None:
return {}
return _QuantizationSchemesDict(
{
submodule: QuantizationScheme.load(scheme, default=default_scheme)
for submodule, scheme in schemes_dict.items()
}
)
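Putting the pytorch.py changes together, the one-shot (post-training) path after this commit boils down to roughly the sequence below. This is a condensed sketch assembled from the calls visible in the diff above; oneshot_quantize and run_calibration are hypothetical placeholders, and the real modifier additionally handles the QAT start/end events, epoch-based observer disabling, and calibration step counting.

from torch.nn import Module

from compressed_tensors.quantization import (
    apply_quantization_config,
    freeze_module_quantization,
    set_module_for_calibration,
)


def oneshot_quantize(modifier, model: Module, run_calibration) -> None:
    # Translate the modifier's config_groups / ignore lists into a
    # QuantizationConfig and attach observers to the matching submodules.
    apply_quantization_config(model, modifier.create_init_config())

    # Switch modules into calibration mode and push calibration data through.
    model.apply(set_module_for_calibration)
    run_calibration(model)

    # Freeze scales and zero points so later forward passes emulate quantized
    # execution without further observer updates.
    model.apply(freeze_module_quantization)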
