From 6512a222f2ef643dfa16fe45dab6e96c42f17158 Mon Sep 17 00:00:00 2001 From: David Corvoysier Date: Fri, 16 Feb 2024 14:23:05 +0100 Subject: [PATCH] feat(quanto): introduce qtype This implies a lot of modifications but is functionally equivalent. --- bench/generation/benchmark.py | 6 +- .../sst2/quantize_sst2_model.py | 4 +- .../quantize_causal_lm_model.py | 4 +- .../mnist/quantize_mnist_model.py | 4 +- quanto/nn/qlinear.py | 16 +- quanto/nn/qmodule.py | 4 +- quanto/tensor/__init__.py | 1 + quanto/tensor/core.py | 106 ++++++------ quanto/tensor/ops.py | 97 +++++++---- quanto/tensor/qtype.py | 34 ++++ test/helpers.py | 8 +- test/library/test_unpack.py | 4 +- test/model/test_quantize_mlp.py | 14 +- test/nn/test_calibrate.py | 18 +-- test/nn/test_custom_qmodule.py | 153 ------------------ test/nn/test_qattention.py | 16 +- test/nn/test_qlayernorm.py | 18 +-- test/nn/test_qlinear.py | 38 ++--- test/nn/test_qmodule.py | 6 +- test/tensor/ops/test_linear_dispatch.py | 4 +- test/tensor/ops/test_quantized_dispatch.py | 4 +- test/tensor/test_absmax.py | 12 +- test/tensor/test_qbitstensor.py | 38 ++--- test/tensor/test_qtensor.py | 40 ++--- test/test_serialization.py | 26 +-- 25 files changed, 286 insertions(+), 389 deletions(-) create mode 100644 quanto/tensor/qtype.py delete mode 100644 test/nn/test_custom_qmodule.py diff --git a/bench/generation/benchmark.py b/bench/generation/benchmark.py index 3aa34e60..d18765f1 100644 --- a/bench/generation/benchmark.py +++ b/bench/generation/benchmark.py @@ -7,7 +7,7 @@ from tqdm.auto import tqdm from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig -from quanto import Calibration, freeze, quantize +from quanto import Calibration, freeze, qint8, quantize CALIBRATION_PROMPT = "It was a bright cold day in April, and the clocks were striking thirteen." 
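For context, a minimal usage sketch (an editorial example, not part of the patch) of the qtype-based API this change introduces; the model name and calibration input below are placeholders:

import torch
from transformers import AutoModelForCausalLM
from quanto import Calibration, freeze, qint8, quantize

# Placeholder model: any module supported by quanto would work the same way
model = AutoModelForCausalLM.from_pretrained("gpt2")
# Weights and activations are now selected with quanto qtypes instead of torch dtypes
quantize(model, weights=qint8, activations=qint8)
# Quantized activations need a calibration pass to set the scales
with torch.no_grad(), Calibration():
    model(input_ids=torch.tensor([[1, 2, 3]]))
freeze(model)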
@@ -165,8 +165,8 @@ def main(): if args.quantization in ("w8a8", "w8a16"): print("quantizing") start = time.time() - weights = torch.int8 - activations = None if "a16" in args.quantization else torch.int8 + weights = qint8 + activations = None if "a16" in args.quantization else qint8 quantize(model, weights=weights, activations=activations) if activations is not None: print("Calibrating") diff --git a/examples/nlp/text-classification/sst2/quantize_sst2_model.py b/examples/nlp/text-classification/sst2/quantize_sst2_model.py index 577e2152..fdfc7620 100644 --- a/examples/nlp/text-classification/sst2/quantize_sst2_model.py +++ b/examples/nlp/text-classification/sst2/quantize_sst2_model.py @@ -8,7 +8,7 @@ from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline from transformers.pipelines.pt_utils import KeyDataset -from quanto import Calibration, freeze, quantize +from quanto import Calibration, freeze, qint8, quantize def evaluate_model(model, tokenizer, dataset, device, batch_size): @@ -22,7 +22,7 @@ def evaluate_model(model, tokenizer, dataset, device, batch_size): def keyword_to_itype(k): - return {"none": None, "int8": torch.int8}[k] + return {"none": None, "int8": qint8}[k] def main(): diff --git a/examples/nlp/text-generation/quantize_causal_lm_model.py b/examples/nlp/text-generation/quantize_causal_lm_model.py index 668bacc6..f6e4d81e 100644 --- a/examples/nlp/text-generation/quantize_causal_lm_model.py +++ b/examples/nlp/text-generation/quantize_causal_lm_model.py @@ -5,7 +5,7 @@ from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer -from quanto import Calibration, freeze, quantize +from quanto import Calibration, freeze, qfloat8_e4m3fn, qfloat8_e5m2, qint8, quantize @torch.no_grad() @@ -51,7 +51,7 @@ def evaluate_model(model, tokenizer, dataset, device, batch_size, samples=None, def keyword_to_itype(k): - return {"none": None, "int8": torch.int8, "fp8_e5m2": torch.float8_e5m2, "fp8_e4m3": torch.float8_e4m3fn}[k] + return {"none": None, "int8": qint8, "fp8_e5m2": qfloat8_e5m2, "fp8_e4m3": qfloat8_e4m3fn}[k] def main(): diff --git a/examples/vision/image-classification/mnist/quantize_mnist_model.py b/examples/vision/image-classification/mnist/quantize_mnist_model.py index 0a550919..0d891f1f 100644 --- a/examples/vision/image-classification/mnist/quantize_mnist_model.py +++ b/examples/vision/image-classification/mnist/quantize_mnist_model.py @@ -7,7 +7,7 @@ from torchvision import datasets, transforms from transformers import AutoModel -from quanto import Calibration, QTensor, freeze, int4, quantize +from quanto import Calibration, QTensor, freeze, qint4, qint8, quantize def test(model, device, test_loader): @@ -60,7 +60,7 @@ def train(log_interval, model, device, train_loader, optimizer, epoch): def keyword_to_itype(k): - return {"none": None, "int4": int4, "int8": torch.int8}[k] + return {"none": None, "int4": qint4, "int8": qint8}[k] def main(): diff --git a/quanto/nn/qlinear.py b/quanto/nn/qlinear.py index bd83f3c1..f4898bcf 100644 --- a/quanto/nn/qlinear.py +++ b/quanto/nn/qlinear.py @@ -2,7 +2,7 @@ import torch -from ..tensor import QBitsTensor, QTensor, absmax_scale, qbitsdtype +from ..tensor import QBitsTensor, QTensor, absmax_scale, qint2, qint4, qint8, qtype from .qmodule import QModuleMixin, register_qmodule @@ -11,12 +11,12 @@ @register_qmodule(torch.nn.Linear) class QLinear(QModuleMixin, torch.nn.Linear): - def __init__(self, *args, weights: torch.dtype = torch.int8, **kwargs): + def __init__(self, *args, 
weights: qtype = qint8, **kwargs): super().__init__(*args, **kwargs) self.weights = weights @classmethod - def from_module(cls, module, weights=torch.int8, activations: Optional[torch.dtype] = None): + def from_module(cls, module, weights=qint8, activations: Optional[qtype] = None): qmodule = cls( module.in_features, module.out_features, @@ -36,12 +36,12 @@ def qweight(self): if isinstance(self.weight, QTensor): return self.weight # Quantize the weights per-axis - if isinstance(self.weights, torch.dtype): + if self.weights == qint8: wscale = absmax_scale(self.weight, axis=0) - return QTensor.quantize(self.weight, itype=self.weights, scale=wscale) - elif isinstance(self.weights, qbitsdtype): - return QBitsTensor.quantize(self.weight, itype=self.weights, axis=0) - raise ValueError("Invalid quantized weights type") + return QTensor.quantize(self.weight, qtype=self.weights, scale=wscale) + elif self.weights in (qint2, qint4): + return QBitsTensor.quantize(self.weight, qtype=self.weights, axis=0) + raise ValueError(f"Invalid quantized weights type {self.weights}") def qforward(self, input: torch.Tensor) -> torch.Tensor: if self.activations is not None and not isinstance(input, QTensor): diff --git a/quanto/nn/qmodule.py b/quanto/nn/qmodule.py index f1cf4fb3..62ee603a 100644 --- a/quanto/nn/qmodule.py +++ b/quanto/nn/qmodule.py @@ -94,9 +94,9 @@ def qforward(self, input: torch.Tensor) -> torch.Tensor: def forward(self, input: torch.Tensor) -> torch.Tensor: def maybe_requantize(t, scale): - if t.itype == self.activations and t.axis is None: + if t.qtype == self.activations and t.axis is None: return t - return QTensor.quantize(t.dequantize(), itype=self.activations, scale=scale) + return QTensor.quantize(t.dequantize(), qtype=self.activations, scale=scale) if self.activations is not None and isinstance(input, QTensor): input = maybe_requantize(input, self.input_scale) diff --git a/quanto/tensor/__init__.py b/quanto/tensor/__init__.py index bb67a43f..af95387a 100644 --- a/quanto/tensor/__init__.py +++ b/quanto/tensor/__init__.py @@ -1 +1,2 @@ from .core import * +from .qtype import * diff --git a/quanto/tensor/core.py b/quanto/tensor/core.py index 1a12d27e..b834dcf5 100644 --- a/quanto/tensor/core.py +++ b/quanto/tensor/core.py @@ -1,25 +1,13 @@ -from dataclasses import dataclass from typing import Optional import torch from torch.autograd import Function from torch.utils import _pytree as pytree +from .qtype import qint2, qint4, qint8, qtype -__all__ = ["absmax_scale", "int2", "int4", "qbitsdtype", "qfallback", "dtype_info", "QBitsTensor", "QTensor"] - -@dataclass -class qbitsdtype: - """A dtype class mimicking torch dtype""" - - is_complex: bool - is_floating_point: bool - bits: int - - -int2 = qbitsdtype(is_complex=False, is_floating_point=False, bits=2) -int4 = qbitsdtype(is_complex=False, is_floating_point=False, bits=4) +__all__ = ["absmax_scale", "qfallback", "dtype_info", "QBitsTensor", "QTensor"] def dtype_info(dtype): @@ -36,7 +24,7 @@ def axis_to_dim(t, axis): return dim -def pack_weights(intweights: torch.Tensor, bitsdtype: qbitsdtype) -> torch.Tensor: +def pack_weights(intweights: torch.Tensor, qtype: qtype) -> torch.Tensor: """ Pack int4 / int2 weights in a unint8 tensor @@ -54,10 +42,10 @@ def pack_weights(intweights: torch.Tensor, bitsdtype: qbitsdtype) -> torch.Tenso Args: intweights (`torch.Tensor`): The un-packed tensor in `torch.int8` precision - bitsdtype (`quanto.qbitsdtype`): - The desired `bitsdtype` - can be `quanto.int2`, `quanto.int4` + qtype (`quanto.qtype`): + The 
desired `qtype` - can be `quanto.int2`, `quanto.int4` """ - bits = bitsdtype.bits + bits = qtype.bits original_shape = intweights.shape values_per_item = 8 // bits if original_shape[0] % values_per_item != 0: @@ -85,7 +73,7 @@ def lshift(t: torch.Tensor, bits: int): return packed -def unpack_weights(uint8weights: torch.Tensor, bitsdtype: qbitsdtype) -> torch.Tensor: +def unpack_weights(uint8weights: torch.Tensor, qtype: qtype) -> torch.Tensor: """ Un-Pack int4 / int2 weights (packed in a uint8) into a torch.int8 tensor What un-packing means? Assume we have packed 4 2-bit values in 8-bit @@ -100,10 +88,10 @@ def unpack_weights(uint8weights: torch.Tensor, bitsdtype: qbitsdtype) -> torch.T Args: uint8weights (`torch.Tensor`): The packed tensor in `torch.uint8` precision - bitsdtype (`quanto.qbitsdtype`): - The desired `bitsdtype` - can be `quanto.int2`, `quanto.int4` + qtype (`quanto.qbitsdtype`): + The desired `qtype` - can be `quanto.int2`, `quanto.int4` """ - bits = bitsdtype.bits + bits = qtype.bits unpacked = [] values_per_item = 8 // bits @@ -121,9 +109,7 @@ def rshift(t: torch.Tensor, bits: int): return torch.cat(unpacked).to(torch.int8) -def absmax_scale( - base: torch.Tensor, itype: torch.Tensor.dtype = torch.int8, axis: Optional[int] = None -) -> torch.Tensor: +def absmax_scale(base: torch.Tensor, qtype: qtype = qint8, axis: Optional[int] = None) -> torch.Tensor: """Evaluate the quantization scale using the absmax algorithm. The Absolute Maximum quantization algorithm is a symmetrical quantization @@ -133,7 +119,7 @@ def absmax_scale( Args: base (`torch.Tensor`): the base tensor on which the scale will be applied. - itype (`torch.Tensor.dtype`): the target internal dtype for quantization. + qtype (`quanto.qtype`): the target qtype for quantization. axis (`int`): the index of the axis to preserve, or -1 for the last one. Defaults to None to reduce all axis. @@ -146,7 +132,7 @@ def absmax_scale( else: dim = axis_to_dim(abs_base, axis) qranges = torch.amax(torch.abs(base), dim=dim, keepdim=True) - info = dtype_info(itype) + info = dtype_info(qtype.dtype) return qranges / info.max @@ -168,10 +154,10 @@ class Quantizer(Function): """ @staticmethod - def forward(ctx, base, itype: torch.Tensor.dtype = torch.int8, scale=None): - info = dtype_info(itype) + def forward(ctx, base, qtype: qtype = qint8, scale=None): + info = dtype_info(qtype.dtype) if scale is None: - scale = absmax_scale(base, itype) + scale = absmax_scale(base, qtype) elif scale.ndim > 0: if torch.squeeze(scale).ndim > 1: raise ValueError("Quantizing along multiple axis is not supported") @@ -180,12 +166,12 @@ def forward(ctx, base, itype: torch.Tensor.dtype = torch.int8, scale=None): "When quantizing per-axis, the scale must be broadcastable to the base (Tip: try to add missing dims of length zero)." ) data = base / scale - if not itype.is_floating_point: + if not qtype.is_floating_point: data = torch.round(data) - data = torch.clamp(data, min=info.min, max=info.max).to(itype) + data = torch.clamp(data, min=info.min, max=info.max).to(qtype.dtype) # The instantiation of the quantized tensor must happen within the context of the Function # for the autograd magic to work. 
- return QTensor(data, scale) + return QTensor(qtype, data, scale) @staticmethod def backward(ctx, gO): @@ -196,11 +182,12 @@ def backward(ctx, gO): class Dequantizer(Function): @staticmethod def forward(ctx, t): - if t.itype == torch.int32: + # FIXME + if t.qtype == torch.int32: # The dequantization operation requires data to be cast to the scale float type before multiplication # by the scale, but this might actually overflow for float16/bfloat16 return (t._scale.to(torch.float32) * t._data).to(t._scale.dtype) - elif t.itype.is_floating_point: + elif t.qtype.is_floating_point: # Upcast explicitly to the scale dtype return t._scale * t._data.to(t._scale.dtype) return t._scale * t._data @@ -213,7 +200,7 @@ def backward(ctx, gO): class QTensor(torch.Tensor): @staticmethod - def __new__(cls, data, scale, requires_grad=False): + def __new__(cls, qtype, data, scale, requires_grad=False): assert data.device == scale.device # This constructor can ONLY create leaf Tensors wrt autograd. # Use QTensor.from_tensor(t) to get a non-leaf Tensor wrt autograd. @@ -221,7 +208,8 @@ def __new__(cls, data, scale, requires_grad=False): cls, data.size(), strides=data.stride(), dtype=scale.dtype, device=data.device, requires_grad=requires_grad ) - def __init__(self, data, scale, requires_grad=False): + def __init__(self, qtype, data, scale, requires_grad=False): + self._qtype = qtype self._axis = None if scale.ndim > 0: if torch.squeeze(scale).ndim > 1: @@ -250,9 +238,9 @@ def __repr__(self): # Zero out missing values for printing return f"QTensor({self._data}, scale={self._scale}, public_dtype={self.dtype}{autograd_info})" @classmethod - def quantize(cls, base, itype=torch.int8, scale=None): + def quantize(cls, base, qtype=qint8, scale=None): """Differentiable quantization function""" - return Quantizer.apply(base, itype, scale) + return Quantizer.apply(base, qtype, scale) def dequantize(self): """Differentiable dequantization function""" @@ -263,8 +251,8 @@ def axis(self): return self._axis @property - def itype(self): - return self._data.dtype + def qtype(self): + return self._qtype def __tensor_flatten__(self): return ["_data", "_scale"], None @@ -318,9 +306,9 @@ class AffineQuantizer(Function): """A standard affine quantizer.""" @staticmethod - def forward(ctx, base, itype, axis=None, pack=False): - assert isinstance(itype, qbitsdtype) - bits = itype.bits + def forward(ctx, base, qtype: qtype, axis=None, pack=False): + assert qtype in (qint2, qint4) + bits = qtype.bits if axis is None: rmin = torch.min(base) rmax = torch.max(base) @@ -335,11 +323,9 @@ def forward(ctx, base, itype, axis=None, pack=False): data = torch.clamp(torch.round((base - rmin) / scale), min=0, max=2**bits - 1).to(torch.int8) if pack and data.dtype == torch.int8: - data = pack_weights(data, itype) - - data.itype = itype + data = pack_weights(data, qtype) - return QBitsTensor(data, scale, zeropoint) + return QBitsTensor(qtype, data, scale, zeropoint) @staticmethod def backward(ctx, gO): @@ -351,11 +337,11 @@ class QBitsToQTensor(Function): @staticmethod def forward(ctx, t): if t.packed: - unpacked_data = unpack_weights(t._data, t.itype) + unpacked_data = unpack_weights(t._data, t.qtype) else: unpacked_data = t._data int8_data = unpacked_data.to(torch.int8) - t._zeropoint.to(torch.int8) - return QTensor(int8_data, t._scale) + return QTensor(qint8, int8_data, t._scale) @staticmethod def backward(ctx, gO): @@ -364,20 +350,20 @@ def backward(ctx, gO): class QBitsTensor(QTensor): @staticmethod - def __new__(cls, data, scale, zeropoint, 
requires_grad=False): + def __new__(cls, qtype, data, scale, zeropoint, requires_grad=False): assert data.device == scale.device assert data.device == zeropoint.device packed = data.dtype == torch.uint8 size = data.size() if packed: # Fixme: create a PackedIntTensor subclass to store the packed / shape info - size = (size[0] * (8 // data.itype.bits), *size[1:]) + size = (size[0] * (8 // qtype.bits), *size[1:]) return torch.Tensor._make_wrapper_subclass( cls, size, strides=data.stride(), dtype=scale.dtype, device=data.device, requires_grad=requires_grad ) - def __init__(self, data, scale, zeropoint, requires_grad=False): - super().__init__(data, scale, requires_grad=requires_grad) + def __init__(self, qtype, data, scale, zeropoint, requires_grad=False): + super().__init__(qtype, data, scale, requires_grad=requires_grad) self._zeropoint = zeropoint self.packed = data.dtype == torch.uint8 @@ -388,9 +374,9 @@ def __repr__(self): return f"QBitsTensor({self._data}, scale={self._scale}, zeropoint={self._zeropoint}, dtype={self.dtype}{autograd_info})" @classmethod - def quantize(cls, base, itype=int4, axis=None, pack=True): + def quantize(cls, base, qtype=qint4, axis=None, pack=True): """Differentiable quantization function""" - return AffineQuantizer.apply(base, itype, axis, pack) + return AffineQuantizer.apply(base, qtype, axis, pack) def qtensor(self): return QBitsToQTensor.apply(self) @@ -399,8 +385,8 @@ def dequantize(self): return self.qtensor().dequantize() @property - def itype(self): - return self._data.itype + def qtype(self): + return self._qtype def __tensor_flatten__(self): return ["_data", "_scale", "_zeropoint"], None @@ -417,10 +403,8 @@ def __torch_dispatch__(cls, op, types, args, kwargs=None): if op.overloadpacket is torch.ops.aten.detach: t = args[0] data = op(t._data) - # Fixme: we should not do this manually, and use a dedicated subclass - data.itype = t._data.itype scale = op(t._scale) zeropoint = op(t._zeropoint) - return QBitsTensor(data, scale, zeropoint) + return QBitsTensor(t._qtype, data, scale, zeropoint) args, kwargs = pytree.tree_map_only(QBitsTensor, lambda x: x.qtensor(), (args, kwargs or {})) return op(*args, **kwargs) diff --git a/quanto/tensor/ops.py b/quanto/tensor/ops.py index a4abac99..900b8da9 100644 --- a/quanto/tensor/ops.py +++ b/quanto/tensor/ops.py @@ -6,6 +6,7 @@ import torch from . 
import QTensor, dtype_info, qfallback +from .qtype import qint8, qint32 __all__ = ["get_qtensor_op_dispatch", "register_qtensor_op"] @@ -72,10 +73,10 @@ def is_scalar(t): @register_qtensor_op([torch.ops.aten._to_copy]) def _to_copy(op, t, dtype=None, **kwargs): # For data, ignore dtype and use the inner type instead - out_data = op(t._data, dtype=t.itype, **kwargs) + out_data = op(t._data, dtype=t._data.dtype, **kwargs) # Apply the new dtype on the scale only out_scale = op(t._scale, dtype=dtype, **kwargs) - return QTensor(out_data, out_scale) + return QTensor(t.qtype, out_data, out_scale) @register_qtensor_op([torch.ops.aten.detach]) @@ -83,20 +84,25 @@ def detach(op, t): # Detach both data and scale out_data = op(t._data) out_scale = op(t._scale) - return QTensor(out_data, out_scale) + return QTensor(t.qtype, out_data, out_scale) @register_qtensor_op([torch.ops.aten.cat]) def cat(op, inputs, dim=0): if len(inputs) == 2: t1, t2 = inputs - if isinstance(t1, QTensor) and isinstance(t2, QTensor) and torch.equal(t1._scale, t2._scale): - if t1.itype.is_floating_point or t2.itype.is_floating_point: + if ( + isinstance(t1, QTensor) + and isinstance(t2, QTensor) + and torch.equal(t1._scale, t2._scale) + and t1.qtype == t2.qtype + ): + if t1.qtype.is_floating_point or t2.qtype.is_floating_point: # Cat is not supported for float8 return qfallback(op, inputs, dim) # Only quantized tensors with identical scales can be concatenated out_data = op([t1._data, t2._data], dim) - return QTensor(out_data, t1._scale) + return QTensor(t1.qtype, out_data, t1._scale) return qfallback(op, inputs, dim) @@ -113,7 +119,7 @@ def lt(op, input, other): qargs=[QArg(index=0, axis=[None, -1]), QArg(index=1, axis=[None]), QArg(index=2, axis=[None, -1])], ) def addmm(op, input, mat1, mat2, beta=1, alpha=1): - if alpha != 1 or beta != 1: + if alpha != 1 or beta != 1 or mat1.qtype != qint8 or mat2.qtype != qint8: return qfallback(op, input, mat1, mat2, beta=beta, alpha=alpha) # Do the operation with data cast to float32 out_data = op( @@ -124,18 +130,19 @@ def addmm(op, input, mat1, mat2, beta=1, alpha=1): alpha=alpha, ) out_scale = mat1._scale * mat2._scale - return QTensor(out_data.to(torch.int32), out_scale) + return QTensor(qint32, out_data.to(torch.int32), out_scale) @register_qtensor_op([torch.ops.aten.clone]) def clone(op, t, memory_format=torch.preserve_format): out_data = op(t._data, memory_format=memory_format) out_scale = op(t._scale, memory_format=memory_format) - return QTensor(out_data, out_scale) + return QTensor(t.qtype, out_data, out_scale) @register_qtensor_op([torch.ops.aten.copy_]) def copy_(op, dest, src): + assert dest.qtype == src.qtype dest._data = op(dest._data, src._data) dest._scale = op(dest._scale, src._scale) return dest @@ -146,24 +153,26 @@ def div(op, input, other): if not is_scalar(other): return op(input.dequantize(), other) # We just divide the scale - return QTensor(input._data, op(input._scale, other)) + return QTensor(input.qtype, input._data, op(input._scale, other)) @register_qtensor_op([torch.ops.aten.dot], qargs=[QArg(index=0, axis=[None]), QArg(index=1, axis=[None])]) def dot(op, input, other): + if input.qtype != qint8 or other.qtype != qint8: + return qfallback(op, input, other) # Cast data to float32 and do the operation out_data = op(input._data.to(torch.float32), other._data.to(torch.float32)) out_scale = input._scale * other._scale - return QTensor(out_data.to(torch.int32), out_scale) + return QTensor(qint32, out_data.to(torch.int32), out_scale) 
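As a rough illustration of the convention used by the integer matmul ops in this file (a hedged sketch, assuming torch.mm on two per-tensor qint8 QTensors dispatches to the quantized mm op registered above):

import torch
from quanto import QTensor, absmax_scale, qint8

x = torch.randn(16, 32)
w = torch.randn(32, 32)
qx = QTensor.quantize(x, qtype=qint8, scale=absmax_scale(x, qint8))
qw = QTensor.quantize(w, qtype=qint8, scale=absmax_scale(w, qint8))
# int8 x int8 products are accumulated as int32 data; the scales multiply
qy = torch.mm(qx, qw)
print(qy.qtype)  # expected: quanto.qint32, with qy._scale == qx._scale * qw._scale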
@register_qtensor_op([torch.ops.aten.neg]) def neg(op, input, *args, **kwargs): - if input.itype.is_floating_point: + if input.qtype.is_floating_point: # Neg is not supported for float8 return op(input.dequantize(), *args, **kwargs) out_data = op(input._data, *args, **kwargs) - return QTensor(out_data, input._scale) + return QTensor(input.qtype, out_data, input._scale) @register_qtensor_op( @@ -180,7 +189,7 @@ def unary_type_agnostic_op(op, input, *args, **kwargs): # When quantization is per-tensor, rhese operations can be transparently applied # without modifying the scale. out_data = op(input._data, *args, **kwargs) - return QTensor(out_data, input._scale) + return QTensor(input.qtype, out_data, input._scale) @register_qtensor_op([torch.ops.aten.is_same_size]) @@ -196,6 +205,8 @@ def linear(op, input, weight, bias=None): not isinstance(input, QTensor) or input.axis is not None or not isinstance(weight, QTensor) + or input.qtype != qint8 + or weight.qtype != qint8 or (bias is not None and not isinstance(bias, QTensor)) ): return qfallback(op, input, weight, bias=bias) @@ -207,25 +218,29 @@ def linear(op, input, weight, bias=None): # Weights are actually transposed inside the operation weight_scale = weight._scale.t() out_scale = input_scale * weight_scale - return QTensor(out_data.to(torch.int32), out_scale) + return QTensor(qint32, out_data.to(torch.int32), out_scale) @register_qtensor_op([torch.ops.aten.bmm], qargs=[QArg(index=0, axis=[None]), QArg(index=1, axis=[None, -1])]) def bmm(op, input, other): + if input.qtype != qint8 or other.qtype != qint8: + return qfallback(op, input, other) # Cast data to float32 and do the operation out_data = op(input._data.to(torch.float32), other._data.to(torch.float32)) out_scale = input._scale * other._scale - return QTensor(out_data.to(torch.int32), out_scale) + return QTensor(qint32, out_data.to(torch.int32), out_scale) @register_qtensor_op([torch.ops.aten.mm], qargs=[QArg(index=0, axis=[None]), QArg(index=1, axis=[None, -1])]) def mm(op, input, other): + if input.qtype != qint8 or other.qtype != qint8: + return qfallback(op, input, other) n, m = input.shape p = other.shape[-1] if ( input.device.type == "cuda" - and input.itype == torch.int8 - and other.itype == torch.int8 + and input.qtype == qint8 + and other.qtype == qint8 and n > 16 and n % 8 == 0 and m % 8 == 0 @@ -237,58 +252,70 @@ def mm(op, input, other): # Cast data to float32 and do the operation out_data = op(input._data.to(torch.float32), other._data.to(torch.float32)) out_scale = input._scale * other._scale - return QTensor(out_data.to(torch.int32), out_scale) + return QTensor(qint32, out_data.to(torch.int32), out_scale) @register_qtensor_op([torch.ops.aten.mul]) def mul(op, input, other): # If one of the multiplicands is a scalar, just multiply the scale if is_scalar(input): - return QTensor(other._data, input * other._scale) + return QTensor(other.qtype, other._data, input * other._scale) if is_scalar(other): - return QTensor(input._data, other * input._scale) - if not isinstance(input, QTensor) or not isinstance(other, QTensor): + return QTensor(input.qtype, input._data, other * input._scale) + if ( + not isinstance(input, QTensor) + or not isinstance(other, QTensor) + or input.qtype != qint8 + or other.qtype != qint8 + ): return qfallback(op, input, other) # Cast int8 data to int32 and do the operation out_data = op(input._data.to(torch.int32), other._data.to(torch.int32)) out_scale = input._scale * other._scale - return QTensor(out_data, out_scale) + return QTensor(qint32, 
out_data, out_scale) @register_qtensor_op([torch.ops.aten.relu]) def relu(op, input): - if input.itype.is_floating_point: + if input.qtype.is_floating_point: # Relu is not supported for float8 types return qfallback(op, input) out_data = op(input._data) - return QTensor(out_data, input._scale) + return QTensor(input.qtype, out_data, input._scale) @register_qtensor_op([torch.ops.aten._softmax]) def _softmax(op, input, dim, half_to_float): # Softmax must be performed in float - out_data = op(input.dequantize(), dim, half_to_float) + float_data = op(input.dequantize(), dim, half_to_float) # Since softmax is normalized, we know the optimal scale - out_scale = torch.tensor(1 / dtype_info(input.itype).max, dtype=input.dtype).to(input.device) - return QTensor.quantize(out_data, input.itype, out_scale) + out_scale = torch.tensor(1 / dtype_info(input.qtype.dtype).max, dtype=input._scale.dtype).to(input.device) + return QTensor.quantize(float_data, input.qtype, out_scale) @register_qtensor_op([torch.ops.aten.stack]) def stack(op, inputs, dim=0): if len(inputs) == 2: t1, t2 = inputs - if isinstance(t1, QTensor) and isinstance(t2, QTensor) and torch.equal(t1._scale, t2._scale): + if ( + isinstance(t1, QTensor) + and isinstance(t2, QTensor) + and torch.equal(t1._scale, t2._scale) + and t1.qtype == t2.qtype + ): # Only quantized tensors with identical scales can be stacked out_data = op([t1._data, t2._data], dim) - return QTensor(out_data, t1._scale) + return QTensor(t1.qtype, out_data, t1._scale) return qfallback(inputs, dim) @register_qtensor_op([torch.ops.aten.split]) def split(op, input, *args, **kwargs): + if input.axis is not None: + return qfallback(op, input, *args, **kwargs) out_datas = op(input._data, *args, **kwargs) - return [QTensor(out_data, input._scale) for out_data in out_datas] + return [QTensor(input.qtype, out_data, input._scale) for out_data in out_datas] @register_qtensor_op([torch.ops.aten.transpose, torch.ops.aten.t]) @@ -298,7 +325,7 @@ def transpose(op, input, *args): if input.axis is not None: # We need to transpose also the scale out_scale = op(out_scale, *args) - return QTensor(out_data, out_scale) + return QTensor(input.qtype, out_data, out_scale) @register_qtensor_op([torch.ops.aten.view, torch.ops.aten._unsafe_view], qargs=[QArg(index=0, axis=[None, -1])]) @@ -306,14 +333,14 @@ def view(op, input, *shape): out_data = op(input._data, *shape) if input.axis is None: # The view is transparent for QTensor with scalar scales - return QTensor(out_data, input._scale) + return QTensor(input.qtype, out_data, input._scale) # The tensor is quantized along the last axis assert input.axis == -1 # We can only perform the view if the last axis is not modified if input._scale.shape[-1] == out_data.shape[-1]: out_scale_shape = (1,) * (out_data.ndim - 1) + (input._scale.shape[-1],) out_scale = input._scale.view(out_scale_shape) - return QTensor(out_data, out_scale) + return QTensor(input.qtype, out_data, out_scale) return qfallback(op, input, *shape) @@ -323,4 +350,4 @@ def where(op, condition, input, other): raise NotImplementedError float_data = op(condition, input.dequantize(), other) # We requantize with the input scale - return QTensor.quantize(float_data, input.itype, input._scale) + return QTensor.quantize(float_data, input.qtype, input._scale) diff --git a/quanto/tensor/qtype.py b/quanto/tensor/qtype.py new file mode 100644 index 00000000..cd9e8805 --- /dev/null +++ b/quanto/tensor/qtype.py @@ -0,0 +1,34 @@ +from dataclasses import dataclass + +import torch + + +__all__ = ["qtype", 
"qint2", "qint4", "qint8", "qint16", "qint32", "qfloat8", "qfloat8_e4m3fn", "qfloat8_e5m2"] + + +@dataclass +class qtype: + """A quantized type class mimicking torch dtype""" + + name: str + is_floating_point: bool + bits: int + # This defines the storage dtype + dtype: torch.dtype + + def __str__(self): + return f"quanto.{self.name}" + + def __hash__(self): + return hash(str(self)) + + +qint2 = qtype("qint2", is_floating_point=False, bits=2, dtype=torch.int8) +qint4 = qtype("qint4", is_floating_point=False, bits=4, dtype=torch.int8) +qint8 = qtype("qint8", is_floating_point=False, bits=8, dtype=torch.int8) +qint16 = qtype("qint16", is_floating_point=False, bits=16, dtype=torch.int16) +qint32 = qtype("qint32", is_floating_point=False, bits=32, dtype=torch.int32) +# Alias the float8 representation that has the better support and inference efficiency +qfloat8 = qtype("qfloat8", is_floating_point=True, bits=8, dtype=torch.float8_e4m3fn) +qfloat8_e4m3fn = qtype("qfloat8_e4m3fn", is_floating_point=True, bits=8, dtype=torch.float8_e4m3fn) +qfloat8_e5m2 = qtype("qfloat8_e5m2", is_floating_point=True, bits=8, dtype=torch.float8_e5m2) diff --git a/test/helpers.py b/test/helpers.py index 52cf8d73..f6d7be61 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -4,7 +4,7 @@ import torch from packaging import version -from quanto import QTensor, absmax_scale +from quanto import QTensor, absmax_scale, qint8 def torch_min_version(v): @@ -33,10 +33,10 @@ def random_tensor(shape, dtype=torch.float32): return torch.rand(shape, dtype=dtype) * 2 - 1 -def random_qtensor(shape, itype=torch.int8, dtype=torch.float32, axis=None): +def random_qtensor(shape, qtype=qint8, dtype=torch.float32, axis=None): t = random_tensor(shape, dtype) - scale = absmax_scale(t, itype=itype, axis=axis) - return QTensor.quantize(t, itype=itype, scale=scale) + scale = absmax_scale(t, qtype=qtype, axis=axis) + return QTensor.quantize(t, qtype=qtype, scale=scale) def q_assert_close(x: torch.Tensor, xq: QTensor, atol: float = None, rtol: float = None): diff --git a/test/library/test_unpack.py b/test/library/test_unpack.py index 9c5484cd..cc66146c 100644 --- a/test/library/test_unpack.py +++ b/test/library/test_unpack.py @@ -4,7 +4,7 @@ import torch from quanto.library import disable_extensions -from quanto.tensor.core import int2, int4, pack_weights +from quanto.tensor.core import pack_weights, qint2, qint4 @pytest.mark.parametrize("bits", [2, 4], ids=["int2", "int4"]) @@ -13,7 +13,7 @@ def test_unpack(bits, shape, use_ext, device): qmax = 2**bits a = torch.randint(0, qmax, shape, dtype=torch.uint8).to(device) - bitsdtype = int2 if bits == 2 else int4 + bitsdtype = qint2 if bits == 2 else qint4 packed_a = pack_weights(a, bitsdtype) context = nullcontext() if use_ext else disable_extensions() with context: diff --git a/test/model/test_quantize_mlp.py b/test/model/test_quantize_mlp.py index 1791a8c4..0ac2f3fb 100644 --- a/test/model/test_quantize_mlp.py +++ b/test/model/test_quantize_mlp.py @@ -2,7 +2,7 @@ import torch from helpers import assert_similar, random_qtensor -from quanto import Calibration, QLinear, QTensor, freeze, quantize +from quanto import Calibration, QLinear, QTensor, freeze, qfloat8_e4m3fn, qfloat8_e5m2, qint8, quantize class MLP(torch.nn.Module): @@ -48,24 +48,24 @@ def _test_quantize_mlp(weights, activations, frozen, device): assert_similar(output, qoutput, atol=1e-2) -@pytest.mark.parametrize("weights", [torch.int8], ids=["w-int8"]) +@pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"]) 
@pytest.mark.parametrize("frozen", [True, False], ids=["frozen", "non-frozen"]) def test_quantize_mlp_weights_only(weights, frozen, device): _test_quantize_mlp(weights, None, frozen, device) -@pytest.mark.parametrize("weights", [torch.int8], ids=["w-int8"]) +@pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"]) @pytest.mark.parametrize("frozen", [True, False], ids=["frozen", "non-frozen"]) @pytest.mark.skip_device("mps") def test_quantize_mlp_int8_activations(weights, frozen, device): - _test_quantize_mlp(weights, torch.int8, frozen, device) + _test_quantize_mlp(weights, qint8, frozen, device) -@pytest.mark.parametrize("weights", [torch.int8], ids=["w-int8"]) +@pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"]) @pytest.mark.parametrize( "activations", - [None, torch.int8, torch.float8_e5m2, torch.float8_e4m3fn], - ids=["a-float", "a-int8", "a-float8-e5m2", "a-float8-e4m3"], + [None, qint8, qfloat8_e5m2, qfloat8_e4m3fn], + ids=["a-float", "a-qint8", "a-qfloat8-e5m2", "a-qfloat8-e4m3"], ) @pytest.mark.parametrize("frozen", [True, False], ids=["frozen", "non-frozen"]) @pytest.mark.skip_device("mps") diff --git a/test/nn/test_calibrate.py b/test/nn/test_calibrate.py index 93f7cf6d..49f8ecf7 100644 --- a/test/nn/test_calibrate.py +++ b/test/nn/test_calibrate.py @@ -2,7 +2,7 @@ import torch from helpers import random_qtensor -from quanto import Calibration +from quanto import Calibration, qfloat8_e4m3fn, qfloat8_e5m2, qint8 from quanto.nn import QLinear @@ -18,7 +18,7 @@ def _test_calibrate_qlinear(batch_size, tokens, embeddings, use_bias, activation # Calibrate to adjust input and output scales and set the correct dtype with torch.no_grad(), Calibration(): qout = qlinear(qinputs) - assert qout.itype == activations + assert qout.qtype == activations assert torch.any(qlinear.input_scale != 1) assert torch.any(qlinear.output_scale != 1) @@ -27,7 +27,7 @@ def _test_calibrate_qlinear(batch_size, tokens, embeddings, use_bias, activation @pytest.mark.parametrize("tokens, embeddings", [(32, 32), (10, 32)]) @pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"]) def test_calibrate_qlinear_activations_int8(batch_size, tokens, embeddings, use_bias, device): - _test_calibrate_qlinear(batch_size, tokens, embeddings, use_bias, torch.int8, device) + _test_calibrate_qlinear(batch_size, tokens, embeddings, use_bias, qint8, device) @pytest.mark.parametrize("batch_size", [1, 10]) @@ -35,8 +35,8 @@ def test_calibrate_qlinear_activations_int8(batch_size, tokens, embeddings, use_ @pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"]) @pytest.mark.parametrize( "activations", - [torch.float8_e5m2, torch.float8_e4m3fn], - ids=["a-float8-e5m2", "a-float8-e4m3"], + [qfloat8_e5m2, qfloat8_e4m3fn], + ids=["a-qfloat8-e5m2", "a-qfloat8-e4m3"], ) @pytest.mark.skip_device("mps") def test_calibrate_qlinear_activations_float8(batch_size, tokens, embeddings, use_bias, activations, device): @@ -66,17 +66,17 @@ def forward(self, input): assert torch.any(model.linear1.output_scale != 1) assert torch.any(model.linear2.input_scale != 1) assert torch.any(model.linear2.output_scale != 1) - assert qout.itype == activations + assert qout.qtype == activations def test_calibrate_custom_module_activations_int8(device): - _test_calibrate_custom_module(torch.int8, device) + _test_calibrate_custom_module(qint8, device) @pytest.mark.parametrize( "activations", - [torch.float8_e5m2, torch.float8_e4m3fn], - ids=["a-float8-e5m2", "a-float8-e4m3"], + [qfloat8_e5m2, qfloat8_e4m3fn], + 
ids=["a-qfloat8-e5m2", "a-qfloat8-e4m3"], ) @pytest.mark.skip_device("mps") def test_calibrate_custom_module_activations_float8(activations, device): diff --git a/test/nn/test_custom_qmodule.py b/test/nn/test_custom_qmodule.py deleted file mode 100644 index 94e6f0bf..00000000 --- a/test/nn/test_custom_qmodule.py +++ /dev/null @@ -1,153 +0,0 @@ -import os -from tempfile import TemporaryDirectory - -import pytest -import torch -from helpers import q_assert_close, random_qtensor - -from quanto import Calibration, QModuleMixin, QTensor, freeze, register_qmodule - - -class Conv1D(torch.nn.Module): - """ - 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2). - - Basically works like a linear layer but the weights are transposed. - - Args: - nf (`int`): The number of output features. - nx (`int`): The number of input features. - """ - - def __init__(self, nf, nx): - super().__init__() - self.nf = nf - self.weight = torch.nn.Parameter(torch.empty(nx, nf)) - self.bias = torch.nn.Parameter(torch.zeros(nf)) - torch.nn.init.normal_(self.weight, std=0.02) - - def forward(self, x): - size_out = x.size()[:-1] + (self.nf,) - x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) - x = x.view(size_out) - return x - - -@register_qmodule(Conv1D) -class QConv1D(QModuleMixin, Conv1D): - @classmethod - def from_module(cls, module): - nx, nf = module.weight.size() - qmodule = cls(nf, nx) - with torch.no_grad(): - qmodule.weight.copy_(module.weight) - qmodule.bias.copy_(module.bias) - return qmodule.to(module.weight.device) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - # If needed, quantize inputs, weights and bias - if isinstance(input, QTensor): - if input.itype == torch.int32: - # Reduce input bitwidth - input = input.rescale(torch.int8, self.in_scale) - else: - input = QTensor.quantize(input, torch.int8, self.in_scale) - weight = self.weight - if not isinstance(weight, QTensor): - weight = QTensor.quantize(weight) - bias = self.bias - bias_scale = self.in_scale * weight._scale - if isinstance(bias, QTensor): - if bias._scale != bias_scale: - # This should only happen if we calibrate again a frozen module - bias = QTensor.rescale(torch.int32, bias_scale) - else: - bias = QTensor.quantize(bias, torch.int32, bias_scale) - # Operate on quantized tensors - size_out = input.size()[:-1] + (self.nf,) - out_int32 = torch.addmm(bias, input.view(-1, input.size(-1)), weight) - out_int32 = out_int32.view(size_out) - # Downscale - return out_int32.rescale(torch.int8, self.out_scale) - - def freeze(self): - # Replace float weights by quantized weights - self.weight = torch.nn.Parameter(QTensor.quantize(self.weight).to(self.weight.device)) - bias_scale = self.in_scale * self.weight._scale - self.bias = torch.nn.Parameter(QTensor.quantize(self.bias, torch.int32, bias_scale)) - - -@pytest.mark.skip("QConv1D does not work") -@pytest.mark.parametrize("batch_size", [1, 10]) -@pytest.mark.parametrize("tokens, embeddings", [(32, 32), (10, 32)]) -def test_quantize_conv1d(batch_size, tokens, embeddings, device): - conv = Conv1D(embeddings, embeddings).to(device) - qconv = QConv1D.from_module(conv) - qinputs = random_qtensor((batch_size,) + (tokens, embeddings), dtype=torch.float32).to(device) - # Calibrate and obtain quantized outputs - with torch.no_grad(), Calibration(): - qout = qconv(qinputs) - # Freeze to set quantized weights - freeze(qconv) - # Align conv weights with quantized conv weights for comparison - conv.weight = 
torch.nn.Parameter(qconv.weight.dequantize()) - conv.bias = torch.nn.Parameter(qconv.bias.dequantize()) - out = conv(qinputs.dequantize()) - q_assert_close(out, qout) - # Now run an inference without calibrating - with torch.no_grad(): - int_qout = qconv(qinputs) - assert qout._scale == int_qout._scale - # There may be a slight difference, but of at most one quantization interval - assert torch.max(torch.abs(qout._data - int_qout._data)) <= 1 - - -@pytest.mark.skip("QConv1D does not work") -def test_qconv1d_serialization(): - tokens = 10 - embeddings = 32 - conv = Conv1D(embeddings, embeddings) - qconv = QConv1D.from_module(conv) - qinputs = random_qtensor((1,) + (tokens, embeddings), dtype=torch.float32) - # Calibrate and obtain quantized outputs - with torch.no_grad(), Calibration(): - qconv(qinputs) - # Freeze conv to store quantized weights and biases - qconv.freeze() - with TemporaryDirectory() as tmpdir: - qconv_file = os.path.join(tmpdir, "qconv.pt") - torch.save(qconv.state_dict(), qconv_file) - qconv_reloaded = QConv1D(embeddings, embeddings) - # When reloading we must assign instead of copying to force quantized tensors assignment - qconv_reloaded.load_state_dict(torch.load(qconv_file), assign=True) - for attr in ["weight", "bias"]: - t = getattr(qconv, attr) - if t is not None: - t_reloaded = getattr(qconv_reloaded, attr) - assert torch.equal(t._data, t_reloaded._data) - assert torch.equal(t._scale, t_reloaded._scale) - for attr in ["in_scale", "out_scale"]: - v = getattr(qconv, attr) - v_reloaded = getattr(qconv_reloaded, attr) - assert torch.equal(v, v_reloaded) - - -@pytest.mark.skip("QConv1D does not work") -@pytest.mark.parametrize("tokens, embeddings", [(32, 32), (10, 32)]) -def test_qconv1d_gradient(tokens, embeddings, device): - # We use a batch size of 1 to simplify gradient manual calculations - batch_size = 1 - conv = Conv1D(embeddings, embeddings).to(device) - qconv = QConv1D.from_module(conv) - assert qconv.weight.requires_grad is True - assert qconv.bias.requires_grad is True - qinputs = random_qtensor((batch_size,) + (tokens, embeddings), dtype=torch.float32).to(device) - qout = qconv(qinputs) - gradient = torch.randn(qout.size()).to(device) - qout.backward(gradient) - # Compute gradients manually and compare - bias_gradient = torch.sum(gradient, axis=[0, 1]) - assert torch.allclose(qconv.bias.grad, bias_gradient) - # FIXME: gradient calculation is wrong because of the transposed weights - weight_gradient = torch.matmul(gradient.squeeze().t(), qinputs.dequantize().squeeze()) - assert torch.allclose(qconv.weight.grad, weight_gradient) diff --git a/test/nn/test_qattention.py b/test/nn/test_qattention.py index a8426498..b48ccc67 100644 --- a/test/nn/test_qattention.py +++ b/test/nn/test_qattention.py @@ -7,7 +7,7 @@ from helpers import assert_similar, random_tensor from torch import nn -from quanto import Calibration, quantize +from quanto import Calibration, qfloat8_e4m3fn, qfloat8_e5m2, qint8, quantize class RotaryEmbedding(nn.Module): @@ -157,7 +157,7 @@ def forward( return self.o_proj(attn_output) -def _test_quantize_attention(device, dtype=torch.float32, weights=torch.int8, activations=None): +def _test_quantize_attention(device, dtype=torch.float32, weights=qint8, activations=None): att = Attention().to(dtype).to(device) batch_size = 10 seq_len = 64 @@ -172,24 +172,24 @@ def _test_quantize_attention(device, dtype=torch.float32, weights=torch.int8, ac else: with torch.no_grad(), Calibration(): qoutputs = att(inputs) - atol = {None: 1e-4, torch.int8: 1e-3, 
torch.float8_e5m2: 1e-2, torch.float8_e4m3fn: 1e-2}[activations] + atol = {None: 1e-4, qint8: 1e-3, qfloat8_e5m2: 1e-2, qfloat8_e4m3fn: 1e-2}[activations] assert_similar(outputs, qoutputs, atol=atol) -@pytest.mark.parametrize("weights", [torch.int8], ids=["w-int8"]) +@pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"]) def test_quantize_attention_weights_only(weights, device): _test_quantize_attention(device, weights=weights) -@pytest.mark.parametrize("weights", [torch.int8], ids=["w-int8"]) +@pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"]) def test_quantize_attention_activations_int8(weights, device): - _test_quantize_attention(device, weights=weights, activations=torch.int8) + _test_quantize_attention(device, weights=weights, activations=qint8) -@pytest.mark.parametrize("weights", [torch.int8], ids=["w-int8"]) +@pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"]) @pytest.mark.parametrize( "activations", - [torch.float8_e5m2, torch.float8_e4m3fn], + [qfloat8_e5m2, qfloat8_e4m3fn], ids=["a-float8-e5m2", "a-float8-e4m3"], ) @pytest.mark.skip_device("mps") diff --git a/test/nn/test_qlayernorm.py b/test/nn/test_qlayernorm.py index 97808680..ea3e3854 100644 --- a/test/nn/test_qlayernorm.py +++ b/test/nn/test_qlayernorm.py @@ -2,7 +2,7 @@ import torch from helpers import assert_similar, random_qtensor -from quanto import Calibration, QTensor +from quanto import Calibration, QTensor, qfloat8_e4m3fn, qfloat8_e5m2, qint8 from quanto.nn import QLayerNorm @@ -10,20 +10,20 @@ def _test_quantize_layernorm(batch_size, tokens, embeddings, dtype, activations, # Instantiate a normalization layer norm = torch.nn.LayerNorm(embeddings).to(dtype).to(device) qnorm = QLayerNorm.from_module(norm, activations=activations) - qinputs = random_qtensor((batch_size,) + (tokens, embeddings), itype=activations, dtype=dtype).to(device) + qinputs = random_qtensor((batch_size,) + (tokens, embeddings), qtype=activations, dtype=dtype).to(device) # Calibrate to avoid clipping and to set the correct dtype with torch.no_grad(), Calibration(): qout = qnorm(qinputs) qout = qnorm(qinputs) assert isinstance(qout, QTensor) assert qout.dtype == dtype - assert qout.itype == activations + assert qout.qtype == activations # Compare with the float results out = norm(qinputs.dequantize()) # We need to increase atol for float16 dtype dtype_atol = {torch.float32: 1e-4, torch.float16: 1e-3}[dtype] - # We also need to increase atol for float8 itypes - atol = {torch.int8: dtype_atol, torch.float8_e5m2: 5e-3, torch.float8_e4m3fn: 5e-3}[activations] + # We also need to increase atol for float8 qtypes + atol = {qint8: dtype_atol, qfloat8_e5m2: 5e-3, qfloat8_e4m3fn: 5e-3}[activations] assert_similar(out, qout, atol=atol) @@ -31,20 +31,20 @@ def _test_quantize_layernorm(batch_size, tokens, embeddings, dtype, activations, @pytest.mark.parametrize("tokens, embeddings", [(32, 32), (10, 32)]) @pytest.mark.skip_device("cpu") def test_quantize_layernorm_float16_activations_int8(batch_size, tokens, embeddings, device): - _test_quantize_layernorm(batch_size, tokens, embeddings, torch.float16, torch.int8, device) + _test_quantize_layernorm(batch_size, tokens, embeddings, torch.float16, qint8, device) @pytest.mark.parametrize("batch_size", [1, 10]) @pytest.mark.parametrize("tokens, embeddings", [(32, 32), (10, 32)]) def test_quantize_layernorm_float32_activations_int8(batch_size, tokens, embeddings, device): - _test_quantize_layernorm(batch_size, tokens, embeddings, torch.float32, torch.int8, device) + 
_test_quantize_layernorm(batch_size, tokens, embeddings, torch.float32, qint8, device) @pytest.mark.parametrize("batch_size", [1, 10]) @pytest.mark.parametrize("tokens, embeddings", [(32, 32), (10, 32)]) @pytest.mark.parametrize( "activations", - [torch.float8_e5m2, torch.float8_e4m3fn], + [qfloat8_e5m2, qfloat8_e4m3fn], ids=["a-float8-e5m2", "a-float8-e4m3"], ) @pytest.mark.skip_device("cpu") @@ -57,7 +57,7 @@ def test_quantize_layernorm_float16_activations_float8(batch_size, tokens, embed @pytest.mark.parametrize("tokens, embeddings", [(32, 32), (10, 32)]) @pytest.mark.parametrize( "activations", - [torch.float8_e5m2, torch.float8_e4m3fn], + [qfloat8_e5m2, qfloat8_e4m3fn], ids=["a-float8-e5m2", "a-float8-e4m3"], ) @pytest.mark.skip_device("cpu") diff --git a/test/nn/test_qlinear.py b/test/nn/test_qlinear.py index de96d994..80b0c96a 100644 --- a/test/nn/test_qlinear.py +++ b/test/nn/test_qlinear.py @@ -2,56 +2,56 @@ import torch from helpers import assert_similar, random_qtensor -from quanto import Calibration, QTensor, int4 +from quanto import Calibration, QTensor, qfloat8_e4m3fn, qfloat8_e5m2, qint4, qint8 from quanto.nn import QLinear def _test_quantize_linear(batch_size, tokens, embeddings, use_bias, weights, activations, dtype, device): linear = torch.nn.Linear(embeddings, embeddings, bias=use_bias).to(dtype).to(device) qlinear = QLinear.from_module(linear, weights=weights, activations=activations) - assert qlinear.qweight().itype == weights + assert qlinear.qweight().qtype == weights qinputs = random_qtensor((batch_size,) + (tokens, embeddings), dtype=dtype).to(device) # Run an inference with Calibration to get the correct output dtype with torch.no_grad(), Calibration(): qout = qlinear(qinputs) if activations is not None: assert isinstance(qout, QTensor) - assert qout.itype == activations + assert qout.qtype == activations # Align linear weights with quantized linear weights for comparison linear.weight = torch.nn.Parameter(qlinear.qweight().dequantize()) out = linear(qinputs.dequantize()) # We need to increase atol for float16 dtype dtype_atol = {torch.float32: 1e-4, torch.float16: 1e-3}[dtype] - # We also need to increase atol for float8 itypes - atol = {None: dtype_atol, torch.int8: dtype_atol, torch.float8_e5m2: 5e-3, torch.float8_e4m3fn: 5e-3}[activations] + # We also need to increase atol for float8 qtypes + atol = {None: dtype_atol, qint8: dtype_atol, qfloat8_e5m2: 5e-3, qfloat8_e4m3fn: 5e-3}[activations] assert_similar(out, qout, atol=atol) @pytest.mark.parametrize("batch_size", [1, 10]) @pytest.mark.parametrize("tokens, embeddings", [(32, 32), (10, 32)]) @pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"]) -@pytest.mark.parametrize("weights", [int4, torch.int8], ids=["w-int4", "w-int8"]) +@pytest.mark.parametrize("weights", [qint4, qint8], ids=["w-qint4", "w-qint8"]) @pytest.mark.skip_device("cpu") def test_quantize_linear_float16_activations_int8(batch_size, tokens, embeddings, use_bias, weights, device): - _test_quantize_linear(batch_size, tokens, embeddings, use_bias, weights, torch.int8, torch.float16, device) + _test_quantize_linear(batch_size, tokens, embeddings, use_bias, weights, qint8, torch.float16, device) @pytest.mark.parametrize("batch_size", [1, 10]) @pytest.mark.parametrize("tokens, embeddings", [(32, 32), (10, 32)]) @pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"]) -@pytest.mark.parametrize("weights", [int4, torch.int8], ids=["w-int4", "w-int8"]) +@pytest.mark.parametrize("weights", [qint4, qint8], 
ids=["w-qint4", "w-qint8"]) def test_quantize_linear_float32_activations_int8(batch_size, tokens, embeddings, use_bias, weights, device): - _test_quantize_linear(batch_size, tokens, embeddings, use_bias, weights, torch.int8, torch.float32, device) + _test_quantize_linear(batch_size, tokens, embeddings, use_bias, weights, qint8, torch.float32, device) @pytest.mark.parametrize("batch_size", [1, 10]) @pytest.mark.parametrize("tokens, embeddings", [(32, 32), (10, 32)]) @pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"]) -@pytest.mark.parametrize("weights", [int4, torch.int8], ids=["w-int4", "w-int8"]) +@pytest.mark.parametrize("weights", [qint4, qint8], ids=["w-qint4", "w-qint8"]) @pytest.mark.parametrize( "activations", - [torch.float8_e5m2, torch.float8_e4m3fn], - ids=["a-float8-e5m2", "a-float8-e4m3"], + [qfloat8_e5m2, qfloat8_e4m3fn], + ids=["a-qfloat8-e5m2", "a-qfloat8-e4m3"], ) @pytest.mark.skip_device("cpu") @pytest.mark.skip_device("mps") @@ -64,11 +64,11 @@ def test_quantize_linear_float16_activations_float8( @pytest.mark.parametrize("batch_size", [1, 10]) @pytest.mark.parametrize("tokens, embeddings", [(32, 32), (10, 32)]) @pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"]) -@pytest.mark.parametrize("weights", [int4, torch.int8], ids=["w-int4", "w-int8"]) +@pytest.mark.parametrize("weights", [qint4, qint8], ids=["w-qint4", "w-qint8"]) @pytest.mark.parametrize( "activations", - [torch.float8_e5m2, torch.float8_e4m3fn], - ids=["a-float8-e5m2", "a-float8-e4m3"], + [qfloat8_e5m2, qfloat8_e4m3fn], + ids=["a-qfloat8-e5m2", "a-qfloat8-e4m3"], ) @pytest.mark.skip_device("mps") def test_quantize_linear_float32_activations_float8( @@ -80,7 +80,7 @@ def test_quantize_linear_float32_activations_float8( @pytest.mark.parametrize("batch_size", [1, 10]) @pytest.mark.parametrize("tokens, embeddings", [(32, 32), (10, 32)]) @pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"]) -@pytest.mark.parametrize("weights", [int4, torch.int8], ids=["w-int4", "w-int8"]) +@pytest.mark.parametrize("weights", [qint4, qint8], ids=["w-qint4", "w-qint8"]) @pytest.mark.skip_device("cpu") def test_quantize_linear_float16_weight_only(batch_size, tokens, embeddings, use_bias, weights, device): _test_quantize_linear(batch_size, tokens, embeddings, use_bias, weights, None, torch.float16, device) @@ -89,14 +89,14 @@ def test_quantize_linear_float16_weight_only(batch_size, tokens, embeddings, use @pytest.mark.parametrize("batch_size", [1, 10]) @pytest.mark.parametrize("tokens, embeddings", [(32, 32), (10, 32)]) @pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"]) -@pytest.mark.parametrize("weights", [int4, torch.int8], ids=["w-int4", "w-int8"]) +@pytest.mark.parametrize("weights", [qint4, qint8], ids=["w-qint4", "w-qint8"]) def test_quantize_linear_float32_weight_only(batch_size, tokens, embeddings, use_bias, weights, device): _test_quantize_linear(batch_size, tokens, embeddings, use_bias, weights, None, torch.float32, device) @pytest.mark.parametrize("tokens, embeddings", [(32, 32), (10, 32)]) -@pytest.mark.parametrize("activations", [None, torch.int8], ids=["a-float", "a-int8"]) -@pytest.mark.parametrize("weights", [int4, torch.int8], ids=["w-int4", "w-int8"]) +@pytest.mark.parametrize("activations", [None, qint8], ids=["a-float", "a-qint8"]) +@pytest.mark.parametrize("weights", [qint4, qint8], ids=["w-qint4", "w-qint8"]) def test_qlinear_gradient(tokens, embeddings, activations, weights, device): # We use a batch size of 1 to 
simplify gradient manual calculations batch_size = 1 diff --git a/test/nn/test_qmodule.py b/test/nn/test_qmodule.py index e47e9b8a..1f0144db 100644 --- a/test/nn/test_qmodule.py +++ b/test/nn/test_qmodule.py @@ -1,7 +1,7 @@ import pytest import torch -from quanto import QTensor +from quanto import QTensor, qint8 from quanto.nn import QLinear @@ -20,12 +20,12 @@ def test_qmodule_freeze(in_features, out_features, use_bias, dtype): qweight = qlinear.qweight() assert isinstance(qweight, QTensor) assert qweight.dtype == dtype - assert qweight.itype == torch.int8 + assert qweight.qtype == qint8 qlinear.freeze() assert qlinear.frozen assert isinstance(qlinear.weight, QTensor) assert qlinear.weight.dtype == dtype - assert qlinear.weight.itype == torch.int8 + assert qlinear.weight.qtype == qint8 if use_bias: assert not isinstance(qlinear.bias, QTensor) assert qlinear.bias.dtype == dtype diff --git a/test/tensor/ops/test_linear_dispatch.py b/test/tensor/ops/test_linear_dispatch.py index 0ad3c69e..c10495e1 100644 --- a/test/tensor/ops/test_linear_dispatch.py +++ b/test/tensor/ops/test_linear_dispatch.py @@ -2,7 +2,7 @@ import torch from helpers import q_assert_close, random_qtensor, random_tensor -from quanto import QTensor +from quanto import QTensor, qint16 @pytest.mark.parametrize("batch_size", [1, 10]) @@ -19,7 +19,7 @@ def test_linear(batch_size, tokens, embeddings, use_bias, dtype, weight_axis, de bias = random_tensor((embeddings,), dtype=dtype).to(device) # Bias must be quantized to int16 with the same scale as the product of the two int8 prod_scale = torch.squeeze(qinputs._scale * qweight._scale) - qbias = QTensor.quantize(bias, torch.int16, prod_scale) + qbias = QTensor.quantize(bias, qint16, prod_scale) else: qbias = None out = torch.nn.functional.linear( diff --git a/test/tensor/ops/test_quantized_dispatch.py b/test/tensor/ops/test_quantized_dispatch.py index 34c82b64..db6ad3a8 100644 --- a/test/tensor/ops/test_quantized_dispatch.py +++ b/test/tensor/ops/test_quantized_dispatch.py @@ -81,11 +81,11 @@ def test_cat(input_shape, device): qinputs = random_qtensor(input_shape, dtype=torch.float32).to(device) other = random_tensor(input_shape, dtype=torch.float32).to(device) # First, quantize other with the same scale - qother = QTensor.quantize(other, qinputs.itype, qinputs._scale) + qother = QTensor.quantize(other, qinputs.qtype, qinputs._scale) qcat = torch.cat([qinputs, qother]) assert isinstance(qcat, QTensor) q_assert_close(torch.cat([qinputs.dequantize(), qother.dequantize()]), qcat) # Now, verify that with different scales, the output is dequantized - qother = QTensor.quantize(other, qinputs.itype) + qother = QTensor.quantize(other, qinputs.qtype) qcat = torch.cat([qinputs, qother]) assert not isinstance(qcat, QTensor) diff --git a/test/tensor/test_absmax.py b/test/tensor/test_absmax.py index 1375c1d4..2174cc65 100644 --- a/test/tensor/test_absmax.py +++ b/test/tensor/test_absmax.py @@ -2,20 +2,18 @@ import torch from helpers import random_tensor -from quanto import absmax_scale +from quanto import absmax_scale, qfloat8_e4m3fn, qfloat8_e5m2, qint8 @pytest.mark.parametrize("input_shape", [(10,), (1, 10), (2, 10), (10, 32, 32)]) -@pytest.mark.parametrize( - "itype", [torch.int8, torch.float8_e5m2, torch.float8_e4m3fn], ids=["int8", "float8_e5m2", "float8_e4m3"] -) +@pytest.mark.parametrize("qtype", [qint8, qfloat8_e5m2, qfloat8_e4m3fn], ids=["qint8", "qfloat8_e5m2", "qfloat8_e4m3"]) @pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"]) 
 @pytest.mark.parametrize("axis", [None, 0, -1], ids=["per-tensor", "first-axis", "last-axis"])
-def test_absmax_scale(input_shape, axis, dtype, itype, device):
-    if device.type == "mps" and itype.is_floating_point:
+def test_absmax_scale(input_shape, axis, dtype, qtype, device):
+    if device.type == "mps" and qtype.is_floating_point:
         pytest.skip("Float8 are not supported on MPS device")
     a = random_tensor(input_shape, dtype=dtype).to(device)
-    scale = absmax_scale(a, itype, axis)
+    scale = absmax_scale(a, qtype, axis)
     assert scale.dtype == dtype
     if axis is None:
         assert scale.ndim == 0
diff --git a/test/tensor/test_qbitstensor.py b/test/tensor/test_qbitstensor.py
index 767134e5..b1171f91 100644
--- a/test/tensor/test_qbitstensor.py
+++ b/test/tensor/test_qbitstensor.py
@@ -4,69 +4,69 @@
 import torch
 from helpers import device_eq, q_assert_close, random_tensor
 
-from quanto import QBitsTensor, int2, int4
+from quanto import QBitsTensor, qint2, qint4
 
 
 @pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"])
-@pytest.mark.parametrize("itype", [int2, int4], ids=["int2", "int4"])
+@pytest.mark.parametrize("qtype", [qint2, qint4], ids=["qint2", "qint4"])
 @pytest.mark.parametrize("pack", [True, False], ids=["pack", "not-packed"])
-def test_quantize_integer_tensor(dtype, itype, device, pack):
+def test_quantize_integer_tensor(dtype, qtype, device, pack):
     """This test verifies that an integer tensor in the correct range is preserved."""
-    bits = itype.bits
+    bits = qtype.bits
     qmin = -(2 ** (bits - 1))
     qmax = 2 ** (bits - 1) - 1
     a = torch.tensor(range(qmin, qmax + 1), dtype=dtype).to(device)
-    qa = QBitsTensor.quantize(a, itype=itype, pack=pack)
+    qa = QBitsTensor.quantize(a, qtype=qtype, pack=pack)
     assert qa._data.dtype == torch.uint8 if pack else torch.int8
     assert isinstance(qa, QBitsTensor)
     assert qa.dtype == dtype
-    assert qa.itype == itype
+    assert qa.qtype == qtype
     assert device_eq(qa.device, device)
     assert torch.equal(a, qa.dequantize())
 
 
 @pytest.mark.parametrize("input_shape", [(10,), (12,), (10, 10), (12, 10), (32, 32)])
-@pytest.mark.parametrize("itype", [int2, int4], ids=["int2", "int4"])
+@pytest.mark.parametrize("qtype", [qint2, qint4], ids=["qint2", "qint4"])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"])
 @pytest.mark.parametrize("zp", [-1, 0, 1], ids=["neg", "centered", "pos"])
-def test_quantize_per_tensor(input_shape, itype, dtype, zp, device):
+def test_quantize_per_tensor(input_shape, qtype, dtype, zp, device):
     a = random_tensor(input_shape, dtype=dtype).to(device) + zp
-    qa = QBitsTensor.quantize(a, itype=itype)
+    qa = QBitsTensor.quantize(a, qtype=qtype)
     assert isinstance(qa, QBitsTensor)
     assert qa.dtype == dtype
-    assert qa.itype == itype
+    assert qa.qtype == qtype
     assert device_eq(qa.device, device)
-    if input_shape[0] % (8 // itype.bits) == 0:
+    if input_shape[0] % (8 // qtype.bits) == 0:
         assert qa.packed
     q_assert_close(a, qa)
 
 
 @pytest.mark.parametrize("axis", [0, 1, -1], ids=["first-axis", "second-axis", "last-axis"])
-@pytest.mark.parametrize("itype", [int2, int4], ids=["int2", "int4"])
+@pytest.mark.parametrize("qtype", [qint2, qint4], ids=["qint2", "qint4"])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"])
 @pytest.mark.parametrize("zp", [-1, 0, 1], ids=["neg", "centered", "pos"])
-def test_quantize_per_axis(axis, itype, dtype, zp, device):
+def test_quantize_per_axis(axis, qtype, dtype, zp, device):
     a = random_tensor((32, 32), dtype=dtype).to(device) + zp
-    qa = QBitsTensor.quantize(a, itype=itype, axis=axis)
+    qa = QBitsTensor.quantize(a, qtype=qtype, axis=axis)
     assert isinstance(qa, QBitsTensor)
     assert qa.dtype == dtype
-    assert qa.itype == itype
+    assert qa.qtype == qtype
     assert device_eq(qa.device, device)
     q_assert_close(a, qa)
 
 
-@pytest.mark.parametrize("itype", [int2, int4], ids=["int2", "int4"])
+@pytest.mark.parametrize("qtype", [qint2, qint4], ids=["int2", "int4"])
 @pytest.mark.parametrize("axis", [0, None, -1], ids=["first-axis", "per-tensor", "last-axis"])
-def test_qbitstensor_serialization(itype, axis):
+def test_qbitstensor_serialization(qtype, axis):
     a = random_tensor((5, 5), dtype=torch.float32)
-    qa = QBitsTensor.quantize(a, itype=itype, axis=axis)
+    qa = QBitsTensor.quantize(a, qtype=qtype, axis=axis)
     b = io.BytesIO()
     torch.save(qa, b)
     b.seek(0)
     qa_reloaded = torch.load(b)
     assert isinstance(qa_reloaded, QBitsTensor)
-    assert qa_reloaded.itype == qa.itype
+    assert qa_reloaded.qtype == qa.qtype
     assert qa_reloaded.dtype == qa.dtype
     assert torch.equal(qa_reloaded._data, qa._data)
     assert torch.equal(qa_reloaded._scale, qa._scale)
diff --git a/test/tensor/test_qtensor.py b/test/tensor/test_qtensor.py
index e40fbcda..62b791e0 100644
--- a/test/tensor/test_qtensor.py
+++ b/test/tensor/test_qtensor.py
@@ -5,44 +5,44 @@
 import torch
 from helpers import assert_similar, device_eq, q_assert_close, random_qtensor, random_tensor
 
-from quanto import QTensor, absmax_scale
+from quanto import QTensor, absmax_scale, qfloat8_e4m3fn, qfloat8_e5m2, qint8, qint16, qint32
 
 
 @pytest.mark.parametrize("input_shape", [(10,), (1, 10), (10, 32, 32)])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"])
-@pytest.mark.parametrize("itype", [torch.int8], ids=["int8"])
-def test_quantize_integer(input_shape, dtype, itype, device):
+@pytest.mark.parametrize("qtype", [qint8], ids=["qint8"])
+def test_quantize_integer(input_shape, dtype, qtype, device):
     a = random_tensor(input_shape, dtype=dtype).to(device)
-    qa = QTensor.quantize(a, itype)
+    qa = QTensor.quantize(a, qtype)
     assert isinstance(qa, QTensor)
     assert qa.dtype == dtype
-    assert qa.itype == itype
+    assert qa.qtype == qtype
     assert device_eq(qa.device, device)
     q_assert_close(a, qa)
 
 
 @pytest.mark.parametrize("input_shape", [(10,), (1, 10), (10, 32, 32)])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"])
-@pytest.mark.parametrize("itype", [torch.float8_e5m2, torch.float8_e4m3fn], ids=["float8_e5m2", "float8_e4m3"])
+@pytest.mark.parametrize("qtype", [qfloat8_e5m2, qfloat8_e4m3fn], ids=["qfloat8_e5m2", "qfloat8_e4m3"])
 @pytest.mark.skip_device("mps")
-def test_quantize_float8(input_shape, dtype, itype, device):
+def test_quantize_float8(input_shape, dtype, qtype, device):
     a = random_tensor(input_shape, dtype=dtype).to(device)
-    qa = QTensor.quantize(a, itype)
+    qa = QTensor.quantize(a, qtype)
     assert isinstance(qa, QTensor)
     assert qa.dtype == dtype
-    assert qa.itype == itype
+    assert qa.qtype == qtype
     assert device_eq(qa.device, device)
     assert_similar(a, qa, atol=5e-3)
 
 
 @pytest.mark.parametrize("input_shape", [(10,), (1, 10), (2, 10), (10, 32, 32)])
-@pytest.mark.parametrize("itype", [torch.int8], ids=["int8"])
+@pytest.mark.parametrize("qtype", [qint8], ids=["qint8"])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"])
 @pytest.mark.parametrize("axis", [None, 0, -1], ids=["per-tensor", "first-axis", "last-axis"])
-def test_quantize_scale(input_shape, axis, dtype, itype, device):
+def test_quantize_scale(input_shape, axis, dtype, qtype, device):
     a = random_tensor(input_shape, dtype=dtype).to(device)
-    scale = absmax_scale(a, itype, axis)
-    qa = QTensor.quantize(a, itype, scale)
+    scale = absmax_scale(a, qtype, axis)
+    qa = QTensor.quantize(a, qtype, scale)
     if axis is not None:
         if a.ndim == 1:
             # Quantization is actually per-tensor since the input tensor is a vector
@@ -53,7 +53,7 @@ def test_quantize_scale(input_shape, axis, dtype, itype, device):
         else:
             assert qa.axis == axis
     assert isinstance(qa, QTensor)
-    assert qa.itype == itype
+    assert qa.qtype == qtype
     assert qa._scale.dtype == dtype
     assert device_eq(qa.device, device)
     q_assert_close(a, qa)
@@ -61,14 +61,14 @@ def test_quantize_scale(input_shape, axis, dtype, itype, device):
 
 @pytest.mark.parametrize("input_shape", [(10,), (1, 10), (10, 32, 32)])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"])
-@pytest.mark.parametrize("itype", [torch.int8, torch.int16, torch.int32], ids=["int8", "int16", "int32"])
-def test_instantiate(input_shape, dtype, itype, device):
-    max_value = min(1024, torch.iinfo(itype).max)
-    data = torch.randint(-max_value, max_value, input_shape, dtype=itype)
-    qa = QTensor(data, scale=torch.tensor(1.0 / max_value, dtype=dtype)).to(device)
+@pytest.mark.parametrize("qtype", [qint8, qint16, qint32], ids=["qint8", "qint16", "qint32"])
+def test_instantiate(input_shape, dtype, qtype, device):
+    max_value = min(1024, torch.iinfo(qtype.dtype).max)
+    data = torch.randint(-max_value, max_value, input_shape, dtype=qtype.dtype)
+    qa = QTensor(qtype, data, scale=torch.tensor(1.0 / max_value, dtype=dtype)).to(device)
     assert torch.max(torch.abs(qa.dequantize())) <= 1
     assert qa.dtype == dtype
-    assert qa.itype == itype
+    assert qa.qtype == qtype
 
 
 def test_quantized_tensor_serialization():
diff --git a/test/test_serialization.py b/test/test_serialization.py
index 2b3957f2..16d26f34 100644
--- a/test/test_serialization.py
+++ b/test/test_serialization.py
@@ -4,32 +4,37 @@
 import torch
 from helpers import random_qtensor, random_tensor
 
-from quanto import Calibration, QTensor, absmax_scale, freeze, quantize
+from quanto import Calibration, QTensor, absmax_scale, freeze, qfloat8, qint8, quantize
 from quanto.nn import QLinear, QModuleMixin
 
 
 @pytest.mark.parametrize("input_shape", [(10,), (1, 10), (2, 10), (10, 32, 32)])
-@pytest.mark.parametrize("itype", [torch.int8], ids=["int8"])
+@pytest.mark.parametrize("qtype", [qint8, qfloat8], ids=["qint8", "qfloat8"])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"])
 @pytest.mark.parametrize("axis", [None, 0, -1], ids=["per-tensor", "first-axis", "last-axis"])
-def test_quantized_tensor_serialization(input_shape, itype, dtype, axis):
+def test_quantized_tensor_serialization(input_shape, qtype, dtype, axis):
     inputs = random_tensor(input_shape, dtype=dtype)
-    scale = absmax_scale(inputs, itype, axis)
-    qinputs = QTensor.quantize(inputs, itype, scale)
+    scale = absmax_scale(inputs, qtype, axis)
+    qinputs = QTensor.quantize(inputs, qtype, scale)
     b = io.BytesIO()
     torch.save(qinputs, b)
     b.seek(0)
     qinputs_reloaded = torch.load(b)
-    assert torch.equal(qinputs_reloaded._data, qinputs._data)
+    assert qinputs_reloaded.qtype == qtype
     assert torch.equal(qinputs_reloaded._scale, qinputs._scale)
-    # We cannot test dtype directly, as it is not set correctly by torch.load
+    if qtype.is_floating_point:
+        # Equality is not supported for float8
+        assert torch.equal(qinputs_reloaded._data.to(torch.float32), qinputs._data.to(torch.float32))
+    else:
+        assert torch.equal(qinputs_reloaded._data, qinputs._data)
+    # We cannot test dtype directly as it is not correctly set by torch.load
     assert qinputs_reloaded._scale.dtype == dtype
     assert qinputs_reloaded.axis == qinputs.axis
 
 
 @pytest.mark.parametrize("use_bias", [True, False], ids=["bias", "no-bias"])
-@pytest.mark.parametrize("weights", [torch.int8], ids=["w-int8"])
-@pytest.mark.parametrize("activations", [None, torch.int8], ids=["a-float", "a-int8"])
+@pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"])
+@pytest.mark.parametrize("activations", [None, qint8], ids=["a-float", "a-qint8"])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"])
 def test_qlinear_serialization(use_bias, activations, weights, dtype, device):
     if dtype == torch.float16 and device.type == "cpu":
@@ -51,6 +56,7 @@ def test_qlinear_serialization(use_bias, activations, weights, dtype, device):
     qlinear_reloaded.load_state_dict(state_dict, assign=True)
     w = qlinear.weight
     w_reloaded = qlinear_reloaded.weight
+    assert w.qtype == w_reloaded.qtype
     assert torch.equal(w._data, w_reloaded._data)
     assert torch.equal(w._scale, w_reloaded._scale)
     assert w_reloaded.dtype == dtype
@@ -75,7 +81,7 @@ def forward(self, inputs):
         return torch.nn.functional.softmax(self.output_layer(x), dim=-1)
 
 
-@pytest.mark.parametrize("weights", [torch.int8], ids=["w-int8"])
+@pytest.mark.parametrize("weights", [qint8], ids=["w-qint8"])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.float32], ids=["fp16", "fp32"])
 def test_serialize_quantized_mlp(weights, dtype, device):
     if dtype == torch.float16 and device.type == "cpu":