refactor(qbits): remove subdirectory
Also avoid exporting AWQ and TinyGemm classes at the top level.
dacorvo committed Sep 20, 2024
1 parent 4f3af18 commit b011276
Showing 13 changed files with 20 additions and 21 deletions.
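
For downstream code, the visible effect is that the AWQ and TinyGemm classes are no longer re-exported from the top-level optimum.quanto namespace and must now be imported from their weights submodules. A minimal before/after sketch, with the new paths taken from the test diffs below:

# Before this commit (top-level re-exports):
from optimum.quanto import AWQPackedTensor, AWQPacking, TinyGemmPackedTensor

# After this commit (explicit submodule imports):
from optimum.quanto.tensor.weights.awq import AWQPackedTensor, AWQPacking
from optimum.quanto.tensor.weights.tinygemm import TinyGemmPackedTensor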
File renamed without changes.
@@ -17,9 +17,9 @@
 import torch
 from torch.autograd import Function
 
-from ....function import QuantizedLinearFunction
-from ....grouped import group, ungroup
-from ....qtype import qtypes
+from ...function import QuantizedLinearFunction
+from ...grouped import group, ungroup
+from ...qtype import qtypes
 from ..qbits import WeightQBitsTensor
 from .packed import AWQPackedTensor, AWQPacking
 
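Since the flattening moves each module up one package level, every relative import drops one leading dot but still resolves to the same absolute module. A self-contained sketch of that equivalence, runnable without the package installed (the old weights/qbits/awq location is inferred from the four-dot imports above, so treat it as an assumption):

from importlib.util import resolve_name

# Old (assumed) location: optimum/quanto/tensor/weights/qbits/awq/qbits.py
old = resolve_name("....function", package="optimum.quanto.tensor.weights.qbits.awq")
# New location: optimum/quanto/tensor/weights/awq/qbits.py
new = resolve_name("...function", package="optimum.quanto.tensor.weights.awq")
assert old == new == "optimum.quanto.tensor.function"  # same target, one dot fewer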
@@ -18,12 +18,12 @@
 from packaging import version
 from torch.autograd import Function
 
-from ...function import QuantizedLinearFunction
-from ...grouped import grouped_shape
-from ...packed import PackedTensor
-from ...qbits import QBitsTensor
-from ...qtensor import qfallback
-from ...qtype import qint2, qint4, qtype, qtypes
+from ..function import QuantizedLinearFunction
+from ..grouped import grouped_shape
+from ..packed import PackedTensor
+from ..qbits import QBitsTensor
+from ..qtensor import qfallback
+from ..qtype import qint2, qint4, qtype, qtypes
 
 
 __all__ = ["WeightQBitsTensor"]
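The __all__ = ["WeightQBitsTensor"] line above is what keeps optimized subclasses out of wildcard re-exports, matching the commit's goal of not exposing the AWQ and TinyGemm classes at the top level. A minimal illustration with invented names (not the repo's actual modules):

# widget.py (hypothetical module)
__all__ = ["Widget"]  # only names listed here survive `from widget import *`

class Widget: ...
class FastWidget: ...  # public, but absent from __all__, so star-imports skip it

# consumer code:
# from widget import *           -> brings in Widget only
# from widget import FastWidget  -> still works when named explicitly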
3 changes: 0 additions & 3 deletions optimum/quanto/tensor/weights/qbits/__init__.py

This file was deleted.

@@ -17,10 +17,10 @@
 import torch
 from torch.autograd import Function
 
-from ....function import QuantizedLinearFunction
-from ....grouped import group, ungroup
-from ....qtype import qtypes
-from ...qbits import WeightQBitsTensor
+from ...function import QuantizedLinearFunction
+from ...grouped import group, ungroup
+from ...qtype import qtypes
+from ..qbits import WeightQBitsTensor
 from .packed import TinyGemmPackedTensor
 
 
2 changes: 1 addition & 1 deletion test/library/test_mm.py
@@ -17,7 +17,7 @@
 import torch
 from helpers import assert_similar, random_tensor
 
-from optimum.quanto import AWQPackedTensor, AWQPacking
+from optimum.quanto.tensor.weights.awq import AWQPackedTensor, AWQPacking
 from optimum.quanto.tensor.weights.marlin.packed import get_scale_perms, pack_fp8_as_int32
 
 
2 changes: 1 addition & 1 deletion test/tensor/weights/optimized/test_awq_packed_tensor.py
@@ -18,7 +18,7 @@
 import torch
 from helpers import device_eq
 
-from optimum.quanto import AWQPackedTensor, AWQPacking
+from optimum.quanto.tensor.weights.awq import AWQPackedTensor, AWQPacking
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -17,7 +17,8 @@
 from helpers import device_eq, random_weight_qbits_tensor
 
 from optimum.quanto import qint4
-from optimum.quanto.tensor.weights import AWQWeightQBitsTensor, WeightQBitsTensor
+from optimum.quanto.tensor.weights.awq import AWQWeightQBitsTensor
+from optimum.quanto.tensor.weights import WeightQBitsTensor
 
 
 @pytest.mark.skipif(
@@ -19,7 +19,7 @@
 from helpers import device_eq
 from packaging import version
 
-from optimum.quanto import TinyGemmPackedTensor
+from optimum.quanto.tensor.weights.tinygemm import TinyGemmPackedTensor
 
 
 @pytest.mark.skip_device("mps") # Only available with pytorch 2.4
@@ -18,7 +18,8 @@
 from packaging import version
 
 from optimum.quanto import qint4
-from optimum.quanto.tensor.weights import TinyGemmWeightQBitsTensor, WeightQBitsTensor
+from optimum.quanto.tensor.weights.tinygemm import TinyGemmWeightQBitsTensor
+from optimum.quanto.tensor.weights import WeightQBitsTensor
 
 
 @pytest.mark.skip_device("mps") # Only available with pytorch 2.4
