feat(library): add quantize_symmetric op
dacorvo committed Feb 9, 2024
1 parent 1258c58 commit 06b8e33
Showing 4 changed files with 49 additions and 0 deletions.
1 change: 1 addition & 0 deletions quanto/library/ops.py
@@ -53,4 +53,5 @@ def impl(*args, **kwargs):
return getattr(torch.ops.quanto_py, name)(*args, **kwargs)


define("quantize_symmetric", "(Tensor self, Tensor scale, ScalarType dtype) -> Tensor")
define("unpack", "(Tensor self, int bits) -> Tensor")
1 change: 1 addition & 0 deletions quanto/library/python/__init__.py
@@ -1 +1,2 @@
from .quantize import *
from .unpack import *
15 changes: 15 additions & 0 deletions quanto/library/python/quantize.py
@@ -0,0 +1,15 @@
import torch


def dtype_info(dtype):
info = torch.finfo if dtype.is_floating_point else torch.iinfo
return info(dtype)


@torch.library.impl("quanto_py::quantize_symmetric", "default")
def quantize_symmetric(t: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype):
info = dtype_info(dtype)
data = t / scale
if not dtype.is_floating_point:
data = torch.round(data)
return torch.clamp(data, min=info.min, max=info.max).to(dtype)
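
A minimal usage sketch of the new op, assuming importing the quanto package registers it (the exact import is an assumption): the input is divided by the scale, rounded for integer destination types, then clamped to the destination dtype's representable range.

import torch

import quanto  # assumed to trigger the op registrations

t = torch.tensor([0.5, -1.0, 2.0])
scale = torch.tensor(0.1)
qdata = torch.ops.quanto.quantize_symmetric(t, scale, torch.int8)
# t / scale, rounded and clamped to the int8 range [-128, 127]:
# tensor([  5, -10,  20], dtype=torch.int8)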
32 changes: 32 additions & 0 deletions test/library/test_quantize.py
@@ -0,0 +1,32 @@
import pytest
import torch
from helpers import random_tensor


@pytest.mark.parametrize("shape", [(12,), (32, 32)], ids=["vector", "matrix"])
@pytest.mark.parametrize("src_dtype", [torch.float32, torch.float16], ids=["fp32", "fp16"])
@pytest.mark.parametrize("dst_dtype", [torch.int8, torch.float8_e4m3fn], ids=["int8", "float8"])
@pytest.mark.parametrize("per_axis", [True, False], ids=["per-axis", "per-tensor"])
def test_quantize_symmetric(shape, src_dtype, dst_dtype, per_axis, device):
if device.type == "mps" and dst_dtype != torch.int8:
pytest.skip("float8 types are not supported on MPS device")
# Craft manually data and scale
if dst_dtype.is_floating_point:
data = random_tensor(shape, torch.float16).to(dst_dtype).to(device)
else:
data = torch.randint(-127, 127, shape, dtype=dst_dtype).to(device)
if per_axis:
scale_shape = (shape[0],) + (1,) * (len(shape) - 1)
else:
scale_shape = ()
scale = torch.rand(scale_shape, dtype=src_dtype).to(device)
# Dequantize to obtain a float tensor
t = data.to(src_dtype) * scale
qdata = torch.ops.quanto.quantize_symmetric(t, scale, dst_dtype)
assert qdata.dtype == dst_dtype
assert qdata.shape == shape
# float8 tensors direct comparison is not supported yet on CPU
if dst_dtype.is_floating_point:
assert torch.equal(qdata.to(torch.float16), data.to(torch.float16))
else:
assert torch.equal(qdata, data)
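
For the per-axis case exercised above, the scale carries one entry per leading dimension and broadcasts over the remaining dimensions inside the op. A minimal sketch of that shape convention (illustrative only, with the same assumed import as above):

import torch

import quanto  # assumed to trigger the op registrations

shape = (32, 32)
scale = torch.rand((shape[0], 1), dtype=torch.float32)  # one scale per row
t = torch.randn(shape, dtype=torch.float32)
qdata = torch.ops.quanto.quantize_symmetric(t, scale, torch.int8)
assert qdata.shape == shape  # the scale broadcasts against t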
