Refactor Quantization Modifier and Reloading (#2246)
* initial commit

* update setup.py

* Update setup.py

* fix setup.py

* move all config to sparsetensors

* cleanup class name and comments

* initial implementation untested

* fixing issues

* add test script

* update perplexity test

* refactor to compressed-tensors

* rename sparsetensors

* update setup

* Sa/model reload (#2250)

* working reload

* sparsegpt

* cleanup

* refactor tests

* only run oneshot once

* all tests passing

* remove unused config

* reset models on each parameterize

* style

* bring back SparsityConfigMetadata

* Update setup.py

Co-authored-by: Rahul Tuli <[email protected]>

* add more comparisons, tighten threshold

* use wikitext for perplexity

* update setup

* fix import problem

* fix clearml test

* compressed-tensors are transformers dep

* address PR comments

* can't repeat freeze

* UX pr comments

* quality

* shape consistency

* address PR comments

---------

Co-authored-by: dbogunowicz <[email protected]>
Co-authored-by: dbogunowicz <[email protected]>
Co-authored-by: Rahul Tuli <[email protected]>
Co-authored-by: George Ohashi <[email protected]>
5 people authored May 6, 2024
1 parent 1bad1fb commit f7cb678
Showing 16 changed files with 732 additions and 24 deletions.
@@ -0,0 +1,15 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: 'NO'
enable_cpu_affinity: false
gpu_ids: 0
machine_rank: 0
main_training_function: main
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
23 changes: 23 additions & 0 deletions src/sparseml/modifiers/obcq/utils/sgpt_wrapper.py
@@ -171,6 +171,29 @@ def fasterprune(
                     else:
                         q = torch.quantize_per_channel(q, scale, zero_point, 0, dtype)
                         q = torch.dequantize(q)
+                elif hasattr(self.layer, "quantization_scheme"):
+                    if self.layer.quantization_scheme.weights is not None:
+                        scale = self.layer.weight_scale
+                        zero_point = self.layer.weight_zero_point
+                        from compressed_tensors.quantization.lifecycle.forward import (
+                            fake_quantize,
+                        )
+
+                        while scale.ndim < 2:
+                            scale = scale.unsqueeze(1)
+                            zero_point = zero_point.unsqueeze(1)
+
+                        while q.ndim < 2:
+                            q = q.unsqueeze(1)
+                        q = fake_quantize(
+                            q,
+                            scale[:, i],
+                            zero_point[:, i],
+                            self.layer.quantization_scheme.weights,
+                        )
+
+                        while q.ndim != 1:
+                            q = q.squeeze()
+
                 Q1[:, i] = q
                 Losses1[:, i] = (w - q) ** 2 / d**2
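For readers unfamiliar with `fake_quantize` from compressed-tensors, the rough quantize-dequantize round trip it performs on a weight column looks like the sketch below. This is an illustration of the math only, not the library's implementation, and the symmetric clamp range is an assumption:

```python
import torch


def fake_quantize_sketch(x: torch.Tensor, scale: torch.Tensor,
                         zero_point: torch.Tensor, num_bits: int = 8) -> torch.Tensor:
    # Illustrative quantize-dequantize round trip (not the compressed-tensors code):
    # map to an integer grid, clamp to the representable range, then map back to floats.
    q_min = -(2 ** (num_bits - 1))
    q_max = 2 ** (num_bits - 1) - 1
    q = torch.clamp(torch.round(x / scale) + zero_point, q_min, q_max)
    return (q - zero_point) * scale
```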
17 changes: 17 additions & 0 deletions src/sparseml/modifiers/quantization_vllm/__init__.py
@@ -0,0 +1,17 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa

from .base import *
83 changes: 83 additions & 0 deletions src/sparseml/modifiers/quantization_vllm/base.py
@@ -0,0 +1,83 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional

from pydantic import Field

from compressed_tensors.quantization import (
    QuantizationConfig,
    QuantizationScheme,
    QuantizationStatus,
)
from sparseml.core import Event, Modifier


__all__ = ["vLLMQuantizationModifier"]


class vLLMQuantizationModifier(Modifier):
"""
Enables post training quantization (PTQ) and quantization aware training (QAT) for a
given module or its submodules. After calibration (PTQ) or the start epoch (QAT),
the specified module(s) forward pass will emulate quantized execution and the
modifier will be enabled until training is completed.
:param config_groups: dictionary specifying quantization schemes to apply to target
modules. Modules not matching a scheme target will NOT be quantized.
:param ignore: optional list of module class names or submodule names to not
quantize even if they match a target in config_groups. Defaults to empty list.
:param disable_quantization_observer_epoch: Epoch to disable updates to the module
quantization observers. At this point, quantized weights and zero points will
not be updated. Leave None to not disable observers during QAT. Default is None
:param num_calibration_steps: Number of steps to run post training calibration for.
When None, the entire calibration_dataloader is used
"""

config_groups: Dict[str, QuantizationScheme]
ignore: List[str] = Field(default_factory=list)
disable_quantization_observer_epoch: Optional[float] = None
num_calibration_steps: Optional[int] = None

    def create_init_config(self) -> QuantizationConfig:
        return QuantizationConfig(
            config_groups=self.config_groups,
            quantization_status=QuantizationStatus.INITIALIZED,
            ignore=self.ignore,
        )

    def calculate_disable_observer_epoch(self) -> float:
        """
        Get the epoch at which we want to disable the quantization observer
        :return: epoch to disable at, or -1 if it is not set
        """
        return (
            self.disable_quantization_observer_epoch
            if self.disable_quantization_observer_epoch is not None
            else -1
        )

    def check_should_disable_observer(self, event: Event) -> bool:
        """
        Given the current index, determine if we should disable the observer
        :param event: Event to get index from
        :return: True if observer should be disabled, False otherwise
        """
        disable_epoch = self.calculate_disable_observer_epoch()
        if disable_epoch == -1:
            return False
        if event.current_index >= disable_epoch:
            return True
        return False
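For orientation, a minimal sketch of how this modifier might be constructed in Python. The `QuantizationScheme`/`QuantizationArgs` field names used here (`targets`, `weights`, `input_activations`, `num_bits`, `symmetric`) are assumptions about the compressed-tensors API of this era rather than something shown in the diff; only `config_groups` and `ignore` come from the modifier above:

```python
# Hypothetical W8A8 setup for all Linear layers except lm_head (illustrative only).
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

from sparseml.modifiers.quantization_vllm import vLLMQuantizationModifier

scheme = QuantizationScheme(
    targets=["Linear"],  # module types this group applies to (assumed field name)
    weights=QuantizationArgs(num_bits=8, symmetric=True),
    input_activations=QuantizationArgs(num_bits=8, symmetric=False),
)

modifier = vLLMQuantizationModifier(
    config_groups={"group_0": scheme},
    ignore=["lm_head"],
)

# create_init_config() wraps the groups in a QuantizationConfig with
# status INITIALIZED, ready to be applied to a model.
init_config = modifier.create_init_config()
```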
141 changes: 141 additions & 0 deletions src/sparseml/modifiers/quantization_vllm/pytorch.py
@@ -0,0 +1,141 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Any

from torch.nn import Module

from compressed_tensors.quantization import (
    apply_quantization_config,
    freeze_module_quantization,
    set_module_for_calibration,
)
from sparseml.core import Event, EventType, State
from sparseml.modifiers.quantization_vllm.base import vLLMQuantizationModifier
from sparseml.modifiers.utils.pytorch_helpers import run_calibration_forward


_LOGGER = logging.getLogger(__name__)


class vLLMQuantizationModifierPyTorch(vLLMQuantizationModifier):
"""
PyTorch specific implementation of vLLMQuantizationModifier
Enables post training quantization (PTQ) and quantization aware training (QAT) for a
given module or its submodules. After calibration (PTQ) or the start epoch (QAT),
the specified module(s) forward pass will emulate quantized execution and the
modifier will be enabled until training is completed.
:param config_groups: dictionary specifying quantization schemes to apply to target
modules. Modules not matching a scheme target will NOT be quantized.
:param ignore: optional list of module class names or submodule names to not
quantize even if they match a target in config_groups. Defaults to empty list.
:param disable_quantization_observer_epoch: Epoch to disable updates to the module
quantization observers. At this point, quantized weights and zero points will
not be updated. Leave None to not disable observers during QAT. Default is None
:param num_calibration_steps: Number of steps to run post training calibration for.
When None, the entire calibration_dataloader is used
"""

calibration_dataloader_: Any = None
calibration_function_: Any = None

    def on_initialize_structure(self, state: State, **kwargs):
        module = state.model.model
        self._apply_modifier_to_model(module)
        module.apply(freeze_module_quantization)

    def on_initialize(self, state: State, **kwargs) -> bool:
        if self.end and self.end != -1:
            raise ValueError(
                "end_epoch is disabled for QuantizationModifier and can only be set to"
                " -1 or None. Given {}".format(self.end)
            )

        self.calibration_dataloader_ = state.data.calib
        module = state.model.model

        # initialize quantization in appropriate modules
        self._apply_modifier_to_model(module)

        if self.calculate_start() == -1:  # one-shot
            module.apply(set_module_for_calibration)
            self._calibrate_if_possible(module)
            module.apply(freeze_module_quantization)

        return True

    def on_finalize(self, state: State, **kwargs) -> bool:
        return True

    def on_start(self, state: State, event: Event, **kwargs):
        module = state.model.model
        module.apply(set_module_for_calibration)

    def on_update(self, state: State, event: Event, **kwargs):
        if event.type_ == EventType.BATCH_START:
            if self.check_should_disable_observer(event):
                module = state.model.model
                module.apply(freeze_module_quantization)

    def on_end(self, state: State, event: Event, **kwargs):
        module = state.model.model
        module.apply(freeze_module_quantization)

    def on_event(self, state: State, event: Event, **kwargs):
        pass

    def _apply_modifier_to_model(self, model: Module):
        modifier_as_config = self.create_init_config()
        apply_quantization_config(model, modifier_as_config)

    def _calibrate_if_possible(self, module: Module):
        if self.num_calibration_steps == 0 and self.calibration_dataloader_:
            _LOGGER.warning(
                f"num_calibration_steps is {self.num_calibration_steps}. "
                "Calibration data loader will not be used."
            )
        elif self.num_calibration_steps and not self.calibration_dataloader_:
            raise ValueError(
                f"num_calibration_steps is {self.num_calibration_steps}. "
                "Calibration data loader is not set. Pass a "
                "calibration_data_loader with initialize(...) method."
            )

        elif not self.calibration_dataloader_:
            return

        self._calibrate(module)

    def _calibrate(self, module: Module):
        class_name = self.__class__.__name__.replace("PyTorch", "")
        _LOGGER.info(
            f"Running {class_name} calibration with "
            f"{len(self.calibration_dataloader_)} samples..."
        )

        module_training = module.training
        module.eval()

        run_calibration_forward(
            module,
            self.calibration_dataloader_,
            self.num_calibration_steps,
            self.calibration_function_,
        )

        if module_training:
            module.train()
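The one-shot path in `on_initialize` above condenses to the flow below. This is an illustrative sketch that assumes batches are dicts of tensors accepted by `model(**batch)`, and it uses only the compressed-tensors helpers already imported in this file (sparseml's `run_calibration_forward` handles the forward passes more carefully):

```python
# Condensed sketch of the one-shot PTQ flow in vLLMQuantizationModifierPyTorch.
import torch
from compressed_tensors.quantization import (
    apply_quantization_config,
    freeze_module_quantization,
    set_module_for_calibration,
)


def one_shot_quantize(model, quant_config, dataloader, num_steps=None):
    apply_quantization_config(model, quant_config)  # attach schemes and observers
    model.apply(set_module_for_calibration)         # observers start collecting stats
    model.eval()
    with torch.no_grad():
        for step, batch in enumerate(dataloader):
            if num_steps is not None and step >= num_steps:
                break
            model(**batch)                          # forward passes update scales/zero points
    model.apply(freeze_module_quantization)         # lock in the quantization parameters
    return model
```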
2 changes: 1 addition & 1 deletion src/sparseml/transformers/compression/sparsity_config.py
@@ -68,7 +68,7 @@ def infer_sparsity_structure() -> str:
         return sparsity_structure

     @staticmethod
-    def infer_config_from_model(
+    def from_pretrained(
         model: Module,
         state_dict: Optional[Dict[str, Tensor]] = None,
         compress: bool = False,
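The renamed entry point is what the save path added later in this commit calls; a minimal usage sketch, where `model` is a placeholder for an already-sparsified transformers model (not taken from the diff):

```python
from sparseml.transformers.compression.sparsity_config import SparsityConfigMetadata

# model: a sparsified transformers model loaded elsewhere (placeholder)
sparsity_config = SparsityConfigMetadata.from_pretrained(
    model, compress=True  # compress mirrors the save_compressed flag in the save path below
)
```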
@@ -22,7 +22,14 @@
from transformers import PreTrainedModel
from transformers.file_utils import CONFIG_NAME

-from compressed_tensors import SPARSITY_CONFIG_NAME, CompressionConfig, ModelCompressor
+from compressed_tensors import (
+    QUANTIZATION_CONFIG_NAME,
+    SPARSITY_CONFIG_NAME,
+    CompressionConfig,
+    ModelCompressor,
+    QuantizationConfig,
+)
+from compressed_tensors.quantization.utils import is_model_quantized
from sparseml.transformers.compression.sparsity_config import SparsityConfigMetadata
from sparseml.utils.pytorch import qat_active

@@ -76,16 +83,45 @@ def save_pretrained_wrapper(
             # state_dict gets passed in as a kwarg for FSDP models
             state_dict = kwargs.get("state_dict", None)

-            if qat_active(model):
+            # check if we are in the old quantization framework
+            if qat_active(model) and not is_model_quantized(model):
                 _LOGGER.info(
-                    "Compression for quantized models is not yet supported. Save will "
-                    "be run without compression and no sparsity statistics will be "
-                    "calculated."
+                    "Compression for models quantized with QuantizationModifier is not "
+                    "supported. Save will be run without compression and no sparsity "
+                    "statistics will be calculated. To save a quantized model in a "
+                    "compressed state please use vLLMQuantizationModifier instead."
                 )
-                return original_save_pretrained.__get__(model, model_class)(
+
+                original_save_pretrained.__get__(model, model_class)(
                     save_directory, **kwargs
                 )

+                return
+
+            elif qat_active(model):  # quantized in new framework
+                _LOGGER.info(
+                    "Sparsity compression for quantized models is not yet supported. "
+                    "No sparsity statistics will be calculated and no sparsity config "
+                    "will be saved."
+                )
+
+                original_save_pretrained.__get__(model, model_class)(
+                    save_directory, **kwargs
+                )
+
+                quant_config = QuantizationConfig.from_pretrained(model)
+                quant_config_data = quant_config.model_dump(exclude_unset=True)
+                config_file_path = os.path.join(save_directory, CONFIG_NAME)
+
+                # add the quantization config to the model's config file
+                with open(config_file_path, "r") as config_file:
+                    config_data = json.load(config_file)
+                config_data[QUANTIZATION_CONFIG_NAME] = quant_config_data
+                with open(config_file_path, "w") as config_file:
+                    json.dump(config_data, config_file, indent=2, sort_keys=True)
+
+                return
+
             if sparsity_config is not None:
                 sparsity_config.global_sparsity = (
                     SparsityConfigMetadata.infer_global_sparsity(
@@ -104,7 +140,7 @@ def save_pretrained_wrapper(
"calculation of compression statistics set "
"skip_compression_stats=True"
)
sparsity_config = SparsityConfigMetadata.infer_config_from_model(
sparsity_config = SparsityConfigMetadata.from_pretrained(
model, state_dict=state_dict, compress=save_compressed
)
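After saving through this wrapper, the quantization settings end up inside the model's `config.json` under the compressed-tensors `QUANTIZATION_CONFIG_NAME` key; a small sketch of reading them back, with a hypothetical output directory:

```python
import json
import os

from compressed_tensors import QUANTIZATION_CONFIG_NAME

save_directory = "./quantized-model"  # hypothetical path passed to save_pretrained
with open(os.path.join(save_directory, "config.json")) as config_file:
    config_data = json.load(config_file)

# the section written by the wrapper above
print(json.dumps(config_data[QUANTIZATION_CONFIG_NAME], indent=2))
```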
