
Update
[ghstack-poisoned]
vkuzo committed Feb 14, 2025
2 parents e596007 + c3bb80e commit d49b604
Showing 20 changed files with 1,092 additions and 619 deletions.
26 changes: 13 additions & 13 deletions README.md
@@ -29,16 +29,16 @@ For inference, we have the option of
```python
from torchao.quantization.quant_api import (
quantize_,
int8_dynamic_activation_int8_weight,
int4_weight_only,
int8_weight_only
Int8DynamicActivationInt8WeightConfig,
Int4WeightOnlyConfig,
Int8WeightOnlyConfig
)
quantize_(m, int4_weight_only())
quantize_(m, Int4WeightOnlyConfig())
```

For gpt-fast, `int4_weight_only()` is the best option at bs=1, as it **doubles the tok/s and reduces VRAM requirements by about 65%** over a torch.compiled baseline.
For gpt-fast, `Int4WeightOnlyConfig()` is the best option at bs=1, as it **doubles the tok/s and reduces VRAM requirements by about 65%** over a torch.compiled baseline.

If you don't have enough VRAM to quantize your entire model on the GPU and CPU quantization is too slow, you can use the device argument, e.g. `quantize_(model, int8_weight_only(), device="cuda")`, which moves each layer to the GPU individually and quantizes it there.
If you don't have enough VRAM to quantize your entire model on the GPU and CPU quantization is too slow, you can use the device argument, e.g. `quantize_(model, Int8WeightOnlyConfig(), device="cuda")`, which moves each layer to the GPU individually and quantizes it there.

If you see slowdowns with any of these techniques or you're unsure which option to use, consider using [autoquant](./torchao/quantization/README.md#autoquantization) which will automatically profile layers and pick the best way to quantize each layer.
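A minimal autoquant sketch (editor's illustration, not part of this diff; the top-level `torchao.autoquant` entry point and the `mode="max-autotune"` argument are assumed from torchao's documented usage):

```python
import torch
import torchao

# Toy model; autoquant profiles each linear layer on the first call and
# picks the fastest way to quantize it (or leaves the layer unchanged).
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().to(torch.bfloat16)
model = torchao.autoquant(torch.compile(model, mode="max-autotune"))

example_input = torch.randn(1, 1024, device="cuda", dtype=torch.bfloat16)
model(example_input)
```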

@@ -63,27 +63,27 @@ Post-training quantization can result in a fast and compact model, but may also
```python
from torchao.quantization import (
quantize_,
int8_dynamic_activation_int4_weight,
Int8DynamicActivationInt4WeightConfig,
)
from torchao.quantization.qat import (
FakeQuantizeConfig,
from_intx_quantization_aware_training,
intx_quantization_aware_training,
FromIntXQuantizationAwareTrainingConfig,
IntXQuantizationAwareTrainingConfig,
)

# Insert fake quantization
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(
my_model,
intx_quantization_aware_training(activation_config, weight_config),
IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
)

# Run training... (not shown)

# Convert fake quantization to actual quantized operations
quantize_(my_model, from_intx_quantization_aware_training())
quantize_(my_model, int8_dynamic_activation_int4_weight(group_size=32))
quantize_(my_model, FromIntXQuantizationAwareTrainingConfig())
quantize_(my_model, Int8DynamicActivationInt4WeightConfig(group_size=32))
```

### Float8
@@ -139,7 +139,7 @@ The best example we have combining the composability of lower bit dtype with com

We've added support for authoring and releasing [custom ops](./torchao/csrc/) that do not graph break with `torch.compile()`. If you love writing kernels but hate packaging them to work across operating systems and CUDA versions, we'd love to accept contributions for your custom ops. We have a few examples you can follow:

1. [fp6](torchao/dtypes/floatx) for 2x faster inference over fp16 with an easy-to-use API: `quantize_(model, fpx_weight_only(3, 2))`
1. [fp6](torchao/dtypes/floatx) for 2x faster inference over fp16 with an easy-to-use API: `quantize_(model, FPXWeightOnlyConfig(3, 2))`
2. [2:4 Sparse Marlin GEMM](https://github.com/pytorch/ao/pull/733) 2x speedups for FP16xINT4 kernels even at batch sizes up to 256
3. [int4 tinygemm unpacker](https://github.com/pytorch/ao/pull/415) which makes it easier to switch quantized backends for inference

40 changes: 33 additions & 7 deletions test/dtypes/test_affine_quantized.py
@@ -8,6 +8,7 @@
run_tests,
)

from torchao.core.config import AOBaseConfig
from torchao.dtypes import CutlassInt4PackedLayout, Int4CPULayout, SemiSparseLayout
from torchao.quantization import (
float8_weight_only,
@@ -16,6 +17,7 @@
int8_dynamic_activation_int4_weight,
int8_dynamic_activation_int8_weight,
int8_weight_only,
quantize_,
)
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
from torchao.utils import (
@@ -82,7 +84,8 @@ def test_tensor_core_layout_transpose(self):
t = linear.weight
shape = t.shape
apply_int4_weight_only_quant = int4_weight_only(group_size=32)
ql = apply_int4_weight_only_quant(linear)
quantize_(linear, apply_int4_weight_only_quant)
ql = linear
aqt = ql.weight
aqt_shape = aqt.shape
self.assertEqual(aqt_shape, shape)
@@ -102,7 +105,12 @@ def test_tensor_core_layout_transpose(self):
)
def test_weights_only(self, apply_quant):
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
ql = apply_quant(linear)
if isinstance(apply_quant, AOBaseConfig):
quantize_(linear, apply_quant)
ql = linear
else:
# TODO(#1690): delete this once config migration is done
ql = apply_quant(linear)
with tempfile.NamedTemporaryFile() as f:
torch.save(ql.state_dict(), f)
f.seek(0)
@@ -115,16 +123,24 @@
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@common_utils.parametrize("apply_quant", get_quantization_functions(False, False))
def test_to_device(self, apply_quant):
def _apply(module, config_or_subclass_inserter):
if isinstance(config_or_subclass_inserter, AOBaseConfig):
quantize_(module, config_or_subclass_inserter)
else:
# TODO(#1690): delete this once config migration is done
module = config_or_subclass_inserter(module)
return module

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = apply_quant(linear)
ql = _apply(linear, apply_quant)
ql.to("cuda")

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = apply_quant(linear)
ql = _apply(linear, apply_quant)
ql.to(device="cuda")

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = apply_quant(linear)
ql = _apply(linear, apply_quant)
ql.cuda()

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -181,7 +197,12 @@ def apply_uint6_weight_only_quant(linear):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_print_quantized_module(self, apply_quant):
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
ql = apply_quant(linear)
if isinstance(apply_quant, AOBaseConfig):
quantize_(linear, apply_quant)
ql = linear
else:
# TODO(#1690): delete this once config migration is done
ql = apply_quant(linear)
assert "AffineQuantizedTensor" in str(ql)


@@ -195,7 +216,12 @@ def test_flatten_unflatten(self, device, dtype):
apply_quant_list = get_quantization_functions(False, True, device)
for apply_quant in apply_quant_list:
linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
ql = apply_quant(linear)
if isinstance(apply_quant, AOBaseConfig):
quantize_(linear, apply_quant)
ql = linear
else:
# TODO(#1690): delete this once config migration is done
ql = apply_quant(linear)
lp_tensor = ql.weight
tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
tensor_data_dict = {
6 changes: 3 additions & 3 deletions test/dtypes/test_uintx.py
@@ -150,7 +150,7 @@ def test_uintx_target_dtype(dtype):

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
# make sure it runs
uintx_weight_only(dtype)(linear)
quantize_(linear, uintx_weight_only(dtype))
linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))


@@ -165,7 +165,7 @@ def test_uintx_target_dtype_compile(dtype):

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
# make sure it runs
uintx_weight_only(dtype)(linear)
quantize_(linear, uintx_weight_only(dtype))
linear = torch.compile(linear)
linear(torch.randn(1, 128, dtype=torch.bfloat16, device="cuda"))

@@ -196,6 +196,6 @@ def test_uintx_model_size(dtype):
)
bf16_size = get_model_size_in_bytes(linear)
# make sure it runs
uintx_weight_only(dtype)(linear[0])
quantize_(linear[0], uintx_weight_only(dtype))
quantized_size = get_model_size_in_bytes(linear)
assert bf16_size * _dtype_to_ratio[dtype] == quantized_size
11 changes: 5 additions & 6 deletions test/hqq/test_hqq_affine.py
@@ -6,6 +6,7 @@
MappingType,
ZeroPointDomain,
int4_weight_only,
quantize_,
uintx_weight_only,
)
from torchao.utils import (
@@ -51,13 +52,11 @@ def _eval_hqq(dtype):
)
dummy_linear.weight.data = W
if dtype == torch.uint4:
q_tensor_hqq = int4_weight_only(group_size=max(block_size), use_hqq=True)(
dummy_linear
).weight
config = int4_weight_only(group_size=max(block_size), use_hqq=True)
else:
q_tensor_hqq = uintx_weight_only(
dtype, group_size=max(block_size), use_hqq=True
)(dummy_linear).weight
config = uintx_weight_only(dtype, group_size=max(block_size), use_hqq=True)
quantize_(dummy_linear, config)
q_tensor_hqq = dummy_linear.weight

quant_linear_layer = torch.nn.Linear(
W.shape[1], W.shape[0], bias=False, device=W.device
2 changes: 1 addition & 1 deletion test/quantization/test_qat.py
@@ -1185,7 +1185,7 @@ def test_qat_prototype_bc(self):
@unittest.skipIf(
not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
)
def test_quantize_api(self):
def test_quantize_api_standalone(self):
"""
Test that the following:
86 changes: 86 additions & 0 deletions test/quantization/test_quant_api.py
@@ -30,24 +30,41 @@
Quantizer,
TwoStepQuantizer,
_replace_with_custom_fn_if_matches_filter,
float8_dynamic_activation_float8_weight,
float8_static_activation_float8_weight,
float8_weight_only,
fpx_weight_only,
gemlite_uintx_weight_only,
int4_dynamic_activation_int4_weight,
int4_weight_only,
int8_dynamic_activation_int4_weight,
int8_dynamic_activation_int8_weight,
int8_weight_only,
uintx_weight_only,
)
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.subclass import (
Int4WeightOnlyQuantizedLinearWeight,
Int8WeightOnlyQuantizedLinearWeight,
)
from torchao.quantization.utils import compute_error
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_3,
TORCH_VERSION_AT_LEAST_2_4,
TORCH_VERSION_AT_LEAST_2_5,
TORCH_VERSION_AT_LEAST_2_6,
is_sm_at_least_89,
is_sm_at_least_90,
unwrap_tensor_subclass,
)

try:
import gemlite # noqa: F401

has_gemlite = True
except ModuleNotFoundError:
has_gemlite = False


def dynamic_quant(model, example_inputs):
m = torch.export.export(model, example_inputs, strict=True).module()
@@ -783,6 +800,75 @@ def test_int4wo_cpu(self, dtype, x_dim):
assert "_weight_int4pack_mm_for_cpu" in code[0]
assert "aten.mm.default" not in code[0]

# TODO(#1690): move to new config names
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@common_utils.parametrize(
"config",
[
int4_weight_only(),
float8_weight_only(),
float8_dynamic_activation_float8_weight(),
float8_static_activation_float8_weight(scale=torch.tensor([1.0])),
int4_dynamic_activation_int4_weight(),
int8_dynamic_activation_int8_weight(),
int8_dynamic_activation_int4_weight(),
int8_weight_only(),
fpx_weight_only(ebits=4, mbits=3),
gemlite_uintx_weight_only(),
uintx_weight_only(dtype=torch.uint4),
],
)
def test_workflow_e2e_numerics(self, config):
"""
Simple end-to-end test of each quantization workflow config, comparing
numerics to a bfloat16 baseline.
"""
if (
isinstance(
config,
(
float8_dynamic_activation_float8_weight,
float8_static_activation_float8_weight,
),
)
and not is_sm_at_least_89()
):
return unittest.skip("requires CUDA capability 8.9 or greater")
elif (
isinstance(config, int4_dynamic_activation_int4_weight)
and is_sm_at_least_90()
):
return unittest.skip("only supported on CUDA capability 8.9, not greater")
elif isinstance(config, gemlite_uintx_weight_only) and not has_gemlite:
return unittest.skip("gemlite not available")

# scale has to be moved to cuda here because the parametrization init
# code happens before gating for cuda availability
if isinstance(config, float8_static_activation_float8_weight):
config.scale = config.scale.to("cuda")

dtype = torch.bfloat16
if isinstance(config, gemlite_uintx_weight_only):
dtype = torch.float16

# set up inputs
x = torch.randn(128, 128, device="cuda", dtype=dtype)
# TODO(future): model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469
# is that expected?
m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().to(dtype)
m_q = copy.deepcopy(m_ref)

# quantize
quantize_(m_q, config)

with torch.no_grad():
y_ref = m_ref(x)
y_q = m_q(x)

sqnr = compute_error(y_ref, y_q)
assert sqnr >= 16.5, f"SQNR {sqnr} is too low"


class TestMultiTensorFlow(TestCase):
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
5 changes: 2 additions & 3 deletions torchao/_models/llama/generate.py
@@ -420,10 +420,9 @@ def ffn_or_attn_only(mod, fqn):
else:
quantize_(model, int8_dynamic_activation_int8_weight())
if "int4wo" in quantization:
use_hqq = False
if "hqq" in quantization:
use_hqq = True
else:
use_hqq = False
group_size = int(quantization.split("-")[1])
assert (
group_size
@@ -434,7 +433,7 @@ def ffn_or_attn_only(mod, fqn):
256,
]
), f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}"
quantize_(model, int4_weight_only(group_size=group_size))
quantize_(model, int4_weight_only(group_size=group_size, use_hqq=use_hqq))
elif "int8adq-int4w-symm" in quantization:
from torchao.dtypes import CutlassInt4PackedLayout

Empty file added torchao/core/__init__.py
29 changes: 29 additions & 0 deletions torchao/core/config.py
@@ -0,0 +1,29 @@
import abc


class AOBaseConfig(abc.ABC):
"""
If a workflow config inherits from this then `quantize_` knows
how to apply it to a model. For example::
# user facing code
class WorkflowFooConfig(AOBaseConfig): ...
# configuration for workflow `Foo` is defined here
bar = 'baz'
# non user facing code
@register_quantize_module_handler(WorkflowFooConfig)
def _transform(
mod: torch.nn.Module,
config: WorkflowFooConfig,
) -> torch.nn.Module:
# the transform is implemented here, usually a tensor subclass
# weight swap or a module swap
...
# then, the user calls `quantize_` with a config, and `_transform` is called
# under the hood by `quantize_`.
"""

pass
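
For reference, a minimal sketch of the config-plus-handler pattern this docstring describes (editor's illustration, not part of this commit; the `register_quantize_module_handler` import path and the config fields are assumptions):

```python
import torch

from torchao.core.config import AOBaseConfig
# NOTE: the import path below is an assumption for illustration purposes.
from torchao.quantization.transform_module import register_quantize_module_handler


class WorkflowFooConfig(AOBaseConfig):
    """User-facing configuration for a hypothetical workflow `Foo`."""

    def __init__(self, bar: str = "baz"):
        self.bar = bar


@register_quantize_module_handler(WorkflowFooConfig)
def _workflow_foo_transform(
    module: torch.nn.Module,
    config: WorkflowFooConfig,
) -> torch.nn.Module:
    # The transform is implemented here, usually a tensor-subclass weight
    # swap or a module swap; this sketch returns the module unchanged.
    return module
```

A user would then call `quantize_(model, WorkflowFooConfig())`, and `quantize_` would dispatch to `_workflow_foo_transform` under the hood.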
1 change: 1 addition & 0 deletions torchao/dtypes/floatx/float8_layout.py
@@ -253,6 +253,7 @@ def _linear_fp8_act_fp8_weight_impl(
):
"""Implements matmul between FP8 input and FP8 weight with compute using _scaled_mm"""
scaled_mm_config = weight_tensor._layout.mm_config
assert scaled_mm_config is not None
out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)

# Weight tensor preprocessing
