Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce qtype and release 0.0.12 #87

Merged: 3 commits merged on Feb 16, 2024
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
6 changes: 3 additions & 3 deletions bench/generation/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig

from quanto import Calibration, freeze, quantize
from quanto import Calibration, freeze, qint8, quantize


CALIBRATION_PROMPT = "It was a bright cold day in April, and the clocks were striking thirteen."
Expand Down Expand Up @@ -165,8 +165,8 @@ def main():
if args.quantization in ("w8a8", "w8a16"):
print("quantizing")
start = time.time()
weights = torch.int8
activations = None if "a16" in args.quantization else torch.int8
weights = qint8
activations = None if "a16" in args.quantization else qint8
quantize(model, weights=weights, activations=activations)
if activations is not None:
print("Calibrating")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
from transformers.pipelines.pt_utils import KeyDataset

from quanto import Calibration, freeze, quantize
from quanto import Calibration, freeze, qint8, quantize


def evaluate_model(model, tokenizer, dataset, device, batch_size):
Expand All @@ -22,7 +22,7 @@ def evaluate_model(model, tokenizer, dataset, device, batch_size):


def keyword_to_itype(k):
return {"none": None, "int8": torch.int8}[k]
return {"none": None, "int8": qint8}[k]


def main():
Expand Down
4 changes: 2 additions & 2 deletions examples/nlp/text-generation/quantize_causal_lm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from quanto import Calibration, freeze, quantize
from quanto import Calibration, freeze, qfloat8_e4m3fn, qfloat8_e5m2, qint8, quantize


@torch.no_grad()
Expand Down Expand Up @@ -51,7 +51,7 @@ def evaluate_model(model, tokenizer, dataset, device, batch_size, samples=None,


def keyword_to_itype(k):
return {"none": None, "int8": torch.int8, "fp8_e5m2": torch.float8_e5m2, "fp8_e4m3": torch.float8_e4m3fn}[k]
return {"none": None, "int8": qint8, "fp8_e5m2": qfloat8_e5m2, "fp8_e4m3": qfloat8_e4m3fn}[k]


def main():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torchvision import datasets, transforms
from transformers import AutoModel

from quanto import Calibration, QTensor, freeze, int4, quantize
from quanto import Calibration, QTensor, freeze, qint4, qint8, quantize


def test(model, device, test_loader):
Expand Down Expand Up @@ -60,7 +60,7 @@ def train(log_interval, model, device, train_loader, optimizer, epoch):


def keyword_to_itype(k):
return {"none": None, "int4": int4, "int8": torch.int8}[k]
return {"none": None, "int4": qint4, "int8": qint8}[k]


def main():
Expand Down
2 changes: 1 addition & 1 deletion quanto/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.0.11"
__version__ = "0.0.13dev"

from .calibrate import *
from .library import *
Expand Down
16 changes: 8 additions & 8 deletions quanto/nn/qlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch

from ..tensor import QBitsTensor, QTensor, absmax_scale, qbitsdtype
from ..tensor import QBitsTensor, QTensor, absmax_scale, qint2, qint4, qint8, qtype
from .qmodule import QModuleMixin, register_qmodule


Expand All @@ -11,12 +11,12 @@

@register_qmodule(torch.nn.Linear)
class QLinear(QModuleMixin, torch.nn.Linear):
def __init__(self, *args, weights: torch.dtype = torch.int8, **kwargs):
def __init__(self, *args, weights: qtype = qint8, **kwargs):
super().__init__(*args, **kwargs)
self.weights = weights

@classmethod
def from_module(cls, module, weights=torch.int8, activations: Optional[torch.dtype] = None):
def from_module(cls, module, weights=qint8, activations: Optional[qtype] = None):
qmodule = cls(
module.in_features,
module.out_features,
Expand All @@ -36,12 +36,12 @@ def qweight(self):
if isinstance(self.weight, QTensor):
return self.weight
# Quantize the weights per-axis
if isinstance(self.weights, torch.dtype):
if self.weights == qint8:
wscale = absmax_scale(self.weight, axis=0)
return QTensor.quantize(self.weight, itype=self.weights, scale=wscale)
elif isinstance(self.weights, qbitsdtype):
return QBitsTensor.quantize(self.weight, itype=self.weights, axis=0)
raise ValueError("Invalid quantized weights type")
return QTensor.quantize(self.weight, qtype=self.weights, scale=wscale)
elif self.weights in (qint2, qint4):
return QBitsTensor.quantize(self.weight, qtype=self.weights, axis=0)
raise ValueError(f"Invalid quantized weights type {self.weights}")

def qforward(self, input: torch.Tensor) -> torch.Tensor:
if self.activations is not None and not isinstance(input, QTensor):
Expand Down
4 changes: 2 additions & 2 deletions quanto/nn/qmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,9 @@ def qforward(self, input: torch.Tensor) -> torch.Tensor:

def forward(self, input: torch.Tensor) -> torch.Tensor:
def maybe_requantize(t, scale):
if t.itype == self.activations and t.axis is None:
if t.qtype == self.activations and t.axis is None:
return t
return QTensor.quantize(t.dequantize(), itype=self.activations, scale=scale)
return QTensor.quantize(t.dequantize(), qtype=self.activations, scale=scale)

if self.activations is not None and isinstance(input, QTensor):
input = maybe_requantize(input, self.input_scale)
Expand Down
1 change: 1 addition & 0 deletions quanto/tensor/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .core import *
from .qtype import *
Loading
Loading