test(compile): still not working with pt 2.3.0
dacorvo committed May 3, 2024
1 parent 19d3e88 commit 6e44e96
Showing 1 changed file with 12 additions and 13 deletions.
test/tensor/test_compile.py: 12 additions & 13 deletions
@@ -16,7 +16,7 @@
 import torch
 from helpers import random_tensor, torch_min_version

-from quanto import QBytesTensor, qint8, quantize_activation
+from quanto import QBytesTensor, absmax_scale, qint8, quantize_activation


 def compile_for_device(f, device):
@@ -27,7 +27,7 @@ def compile_for_device(f, device):
     return torch.compile(f, backend=backend)


-@torch_min_version("2.3.0")
+@torch_min_version("2.4.0")
 @pytest.mark.parametrize("input_shape", [(2, 10), (10, 32, 32)])
 @pytest.mark.parametrize("qtype", [qint8], ids=["qint8"])
 @pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=["fp32", "fp16", "bf16"])
@@ -37,7 +37,8 @@ def test_compile_quantize_tensor(input_shape, qtype, dtype, device):
     a = random_tensor(input_shape, dtype=dtype).to(device)

     def f(x, qtype):
-        return quantize_activation(x, qtype=qtype)
+        scale = absmax_scale(x)
+        return quantize_activation(x, qtype=qtype, scale=scale)

     compiled_f = compile_for_device(f, device)
     qa = compiled_f(a, qtype)
@@ -48,20 +49,18 @@ def f(x, qtype):


 @torch_min_version("2.3.0")
-@pytest.mark.parametrize("qtensor_input", [True, False], ids=["qtensor-input", "tensor-input"])
-def test_compile_qtensor_to(qtensor_input, device):
+def test_compile_qtensor_to(device):
     input_shape = (10, 32, 32)
     a = random_tensor(input_shape).to(device)

     def f(x, dtype):
-        qx = x if isinstance(x, QBytesTensor) else quantize_activation(x)
-        return qx.to(dtype)
+        return x.to(dtype)

     compiled_f = compile_for_device(f, device)

-    if qtensor_input:
-        a = quantize_activation(a)
-    qa = compiled_f(a, torch.float16)
-    assert isinstance(qa, QBytesTensor)
-    assert qa.qtype == qint8
-    assert qa._scale.dtype == torch.float16
+    scale = absmax_scale(a)
+    qa = quantize_activation(a, qtype=qint8, scale=scale)
+    cqa = compiled_f(qa, torch.float16)
+    assert isinstance(cqa, QBytesTensor)
+    assert cqa.qtype == qint8
+    assert cqa._scale.dtype == torch.float16
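
For context, the end-to-end pattern the updated tests exercise looks roughly like the sketch below, outside of pytest. It assumes the quanto API exactly as used in the diff (absmax_scale producing a per-tensor scale, quantize_activation taking an explicit scale); the dequantize() round-trip check at the end is an illustrative assumption, not part of the tests.

    import torch
    from quanto import QBytesTensor, absmax_scale, qint8, quantize_activation

    def quantize(x):
        # Compute a per-tensor absmax scale, then quantize the activation to int8.
        scale = absmax_scale(x)
        return quantize_activation(x, qtype=qint8, scale=scale)

    x = torch.randn(10, 32, 32)

    # Eager reference path.
    qx = quantize(x)
    assert isinstance(qx, QBytesTensor)

    # Compiled path: the commit bumps this test's minimum to PyTorch 2.4.0,
    # since it was still failing with 2.3.0.
    compiled_quantize = torch.compile(quantize)
    cqx = compiled_quantize(x)
    assert isinstance(cqx, QBytesTensor)
    # Assumed check: eager and compiled results dequantize to the same values.
    assert torch.allclose(qx.dequantize(), cqx.dequantize())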
