feat(quanto): introduce qtype
This implies a lot of modifications but is functionally equivalent.

dacorvo committed Feb 16, 2024
1 parent 1b66ea9 commit 6512a22
Showing 25 changed files with 286 additions and 389 deletions.
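Across the diff, raw `torch.dtype` values (and the former `qbitsdtype`) are replaced by dedicated `qtype` descriptors such as `qint8`, `qint4`, or `qfloat8_e5m2`; unlike `torch.dtype`, these can also describe sub-byte types. Purely as an illustration, one plausible shape for such a descriptor (field names here are assumptions, not the actual quanto definition):

# Rough sketch of a qtype descriptor; field names are assumptions,
# not the actual quanto definition.
from dataclasses import dataclass

import torch


@dataclass(frozen=True)
class qtype:
    """A quantized type descriptor, decoupled from torch.dtype.

    Unlike torch.dtype, it can describe sub-byte types such as int2
    and int4 that have no native torch equivalent.
    """

    name: str
    bits: int
    dtype: torch.dtype  # dtype used to store the (possibly packed) values


qint2 = qtype("qint2", bits=2, dtype=torch.int8)
qint4 = qtype("qint4", bits=4, dtype=torch.int8)
qint8 = qtype("qint8", bits=8, dtype=torch.int8)
qfloat8_e4m3fn = qtype("qfloat8_e4m3fn", bits=8, dtype=torch.float8_e4m3fn)
qfloat8_e5m2 = qtype("qfloat8_e5m2", bits=8, dtype=torch.float8_e5m2)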
bench/generation/benchmark.py (6 changes: 3 additions & 3 deletions)

@@ -7,7 +7,7 @@
 from tqdm.auto import tqdm
 from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig

-from quanto import Calibration, freeze, quantize
+from quanto import Calibration, freeze, qint8, quantize


 CALIBRATION_PROMPT = "It was a bright cold day in April, and the clocks were striking thirteen."
@@ -165,8 +165,8 @@ def main():
     if args.quantization in ("w8a8", "w8a16"):
         print("quantizing")
         start = time.time()
-        weights = torch.int8
-        activations = None if "a16" in args.quantization else torch.int8
+        weights = qint8
+        activations = None if "a16" in args.quantization else qint8
         quantize(model, weights=weights, activations=activations)
         if activations is not None:
            print("Calibrating")
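For reference, here is the flow this hunk feeds: w8a16 quantizes weights only, while w8a8 also quantizes activations and therefore calibrates. A minimal sketch using the imports above (the checkpoint name is a placeholder, not the benchmark's actual model):

# Minimal sketch of the benchmark's w8a16 path; the checkpoint is a
# placeholder and error handling is omitted.
from transformers import AutoModelForCausalLM

from quanto import freeze, qint8, quantize

model = AutoModelForCausalLM.from_pretrained("gpt2")

# w8a16: qint8 weights, full-precision activations, so no calibration.
quantize(model, weights=qint8, activations=None)
freeze(model)  # replace float weights with their quantized versions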
examples/nlp/text-classification/sst2/quantize_sst2_model.py (4 changes: 2 additions & 2 deletions)

@@ -8,7 +8,7 @@
 from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
 from transformers.pipelines.pt_utils import KeyDataset

-from quanto import Calibration, freeze, quantize
+from quanto import Calibration, freeze, qint8, quantize


 def evaluate_model(model, tokenizer, dataset, device, batch_size):
@@ -22,7 +22,7 @@ def evaluate_model(model, tokenizer, dataset, device, batch_size):


 def keyword_to_itype(k):
-    return {"none": None, "int8": torch.int8}[k]
+    return {"none": None, "int8": qint8}[k]


 def main():
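When activations are quantized to qint8, the model has to see representative inputs inside a Calibration scope so activation scales are recorded before freezing. A hedged sketch of that sequence (the checkpoint and the sample sentence are placeholders):

# Sketch of the calibration sequence; checkpoint and input are placeholders.
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from quanto import Calibration, freeze, qint8, quantize

name = "distilbert-base-uncased-finetuned-sst-2-english"
model = AutoModelForSequenceClassification.from_pretrained(name)
tokenizer = AutoTokenizer.from_pretrained(name)

quantize(model, weights=qint8, activations=qint8)
with Calibration():
    # Forward passes inside this scope record activation scales.
    inputs = tokenizer("a gorgeous, witty, seductive movie.", return_tensors="pt")
    model(**inputs)
freeze(model)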
examples/nlp/text-generation/quantize_causal_lm_model.py (4 changes: 2 additions & 2 deletions)

@@ -5,7 +5,7 @@
 from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer

-from quanto import Calibration, freeze, quantize
+from quanto import Calibration, freeze, qfloat8_e4m3fn, qfloat8_e5m2, qint8, quantize


 @torch.no_grad()
@@ -51,7 +51,7 @@ def evaluate_model(model, tokenizer, dataset, device, batch_size, samples=None,


 def keyword_to_itype(k):
-    return {"none": None, "int8": torch.int8, "fp8_e5m2": torch.float8_e5m2, "fp8_e4m3": torch.float8_e4m3fn}[k]
+    return {"none": None, "int8": qint8, "fp8_e5m2": qfloat8_e5m2, "fp8_e4m3": qfloat8_e4m3fn}[k]


 def main():
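The mapping now resolves the script's CLI keywords straight to qtype singletons, including both float8 variants (e5m2 favors range, e4m3 favors precision). For example:

# Example resolutions through the keyword_to_itype function above.
weights = keyword_to_itype("int8")          # -> qint8
activations = keyword_to_itype("fp8_e5m2")  # -> qfloat8_e5m2
unquantized = keyword_to_itype("none")      # -> None (keep full precision)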
(changed file path not shown in this view)

@@ -7,7 +7,7 @@
 from torchvision import datasets, transforms
 from transformers import AutoModel

-from quanto import Calibration, QTensor, freeze, int4, quantize
+from quanto import Calibration, QTensor, freeze, qint4, qint8, quantize


 def test(model, device, test_loader):
@@ -60,7 +60,7 @@ def train(log_interval, model, device, train_loader, optimizer, epoch):


 def keyword_to_itype(k):
-    return {"none": None, "int4": int4, "int8": torch.int8}[k]
+    return {"none": None, "int4": qint4, "int8": qint8}[k]


 def main():
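This torchvision-based example is the one exercising the sub-byte qint4 path, which a plain torch.dtype could not express (there is no torch.int4). A minimal sketch of 4-bit weight-only quantization on a stand-in model:

# Sketch: 4-bit weight-only quantization; the model is a stand-in.
import torch

from quanto import freeze, qint4, quantize

model = torch.nn.Sequential(
    torch.nn.Linear(784, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 10),
)
quantize(model, weights=qint4)  # sub-byte weights, stored bit-packed
freeze(model)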
quanto/nn/qlinear.py (16 changes: 8 additions & 8 deletions)

@@ -2,7 +2,7 @@

 import torch

-from ..tensor import QBitsTensor, QTensor, absmax_scale, qbitsdtype
+from ..tensor import QBitsTensor, QTensor, absmax_scale, qint2, qint4, qint8, qtype
 from .qmodule import QModuleMixin, register_qmodule


@@ -11,12 +11,12 @@

 @register_qmodule(torch.nn.Linear)
 class QLinear(QModuleMixin, torch.nn.Linear):
-    def __init__(self, *args, weights: torch.dtype = torch.int8, **kwargs):
+    def __init__(self, *args, weights: qtype = qint8, **kwargs):
         super().__init__(*args, **kwargs)
         self.weights = weights

     @classmethod
-    def from_module(cls, module, weights=torch.int8, activations: Optional[torch.dtype] = None):
+    def from_module(cls, module, weights=qint8, activations: Optional[qtype] = None):
         qmodule = cls(
             module.in_features,
             module.out_features,
@@ -36,12 +36,12 @@ def qweight(self):
         if isinstance(self.weight, QTensor):
             return self.weight
         # Quantize the weights per-axis
-        if isinstance(self.weights, torch.dtype):
+        if self.weights == qint8:
             wscale = absmax_scale(self.weight, axis=0)
-            return QTensor.quantize(self.weight, itype=self.weights, scale=wscale)
-        elif isinstance(self.weights, qbitsdtype):
-            return QBitsTensor.quantize(self.weight, itype=self.weights, axis=0)
-        raise ValueError("Invalid quantized weights type")
+            return QTensor.quantize(self.weight, qtype=self.weights, scale=wscale)
+        elif self.weights in (qint2, qint4):
+            return QBitsTensor.quantize(self.weight, qtype=self.weights, axis=0)
+        raise ValueError(f"Invalid quantized weights type {self.weights}")

     def qforward(self, input: torch.Tensor) -> torch.Tensor:
         if self.activations is not None and not isinstance(input, QTensor):
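qweight now dispatches on qtype values rather than on Python types: qint8 weights take the QTensor path with a per-axis absmax scale, while the sub-byte qint2/qint4 types take the bit-packed QBitsTensor path. The same logic as a standalone sketch (it assumes these symbols are re-exported at the package root, as the examples above suggest):

# Standalone sketch of the qweight dispatch; mirrors the diff above
# and is not guaranteed beyond this commit.
from quanto import QBitsTensor, QTensor, absmax_scale, qint2, qint4, qint8


def quantize_weight(weight, weights_qtype):
    if weights_qtype == qint8:
        # Per-axis scales (axis=0, one per output feature) track each
        # row's dynamic range better than a single per-tensor scale.
        wscale = absmax_scale(weight, axis=0)
        return QTensor.quantize(weight, qtype=weights_qtype, scale=wscale)
    elif weights_qtype in (qint2, qint4):
        # Sub-byte types are stored bit-packed in a QBitsTensor.
        return QBitsTensor.quantize(weight, qtype=weights_qtype, axis=0)
    raise ValueError(f"Invalid quantized weights type {weights_qtype}")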
quanto/nn/qmodule.py (4 changes: 2 additions & 2 deletions)

@@ -94,9 +94,9 @@ def qforward(self, input: torch.Tensor) -> torch.Tensor:

     def forward(self, input: torch.Tensor) -> torch.Tensor:
         def maybe_requantize(t, scale):
-            if t.itype == self.activations and t.axis is None:
+            if t.qtype == self.activations and t.axis is None:
                 return t
-            return QTensor.quantize(t.dequantize(), itype=self.activations, scale=scale)
+            return QTensor.quantize(t.dequantize(), qtype=self.activations, scale=scale)

         if self.activations is not None and isinstance(input, QTensor):
             input = maybe_requantize(input, self.input_scale)
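Beyond the rename of itype to qtype, the guard is unchanged: a tensor passes through untouched only when it already carries the module's activation qtype and a per-tensor scale (axis is None); anything else is dequantized and requantized with the calibrated scale. As a standalone helper (a sketch; activations_qtype and scale stand in for the module attributes):

# Sketch of maybe_requantize outside the module; mirrors the hunk above.
from quanto import QTensor


def maybe_requantize(t, activations_qtype, scale):
    # Reuse the tensor only if it already matches the target qtype and
    # is quantized per-tensor (axis is None).
    if t.qtype == activations_qtype and t.axis is None:
        return t
    # Otherwise round-trip through float using the calibrated scale.
    return QTensor.quantize(t.dequantize(), qtype=activations_qtype, scale=scale)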
quanto/tensor/__init__.py (1 change: 1 addition & 0 deletions)

@@ -1 +1,2 @@
 from .core import *
+from .qtype import *
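With this re-export in place, the new qtype symbols surface through the package namespace, which is exactly what the import changes in the files above rely on:

# The scripts above import the new symbols from the package root:
from quanto import qfloat8_e5m2, qint4, qint8, qtype  # noqa: F401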