fix: remove extension load on unsupported system
fxmarty committed Jul 29, 2024
1 parent: efdd147 · commit: 07a5652
Showing 6 changed files with 6 additions and 13 deletions.
3 changes: 0 additions & 3 deletions optimum/quanto/library/extensions/extension.py
@@ -23,9 +23,6 @@ def __init__(
         self.build_directory = os.path.join(root_dir, "build")
         self._lib = None
 
-        # There is no reason not to build ahead of runtime.
-        tmp = self.lib  # noqa
-
     @property
     def lib(self):
         if self._lib is None:
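For context, the `lib` property that remains in this class makes the build lazy: with the eager `tmp = self.lib` removed from `__init__`, the C++/CUDA sources are only compiled on first access. A minimal sketch of that lazy-build pattern, assuming the property compiles via `torch.utils.cpp_extension.load` (the real class likely passes additional sources and build flags):

    import os
    from torch.utils.cpp_extension import load

    class Extension:
        def __init__(self, name: str, root_dir: str, sources: list):
            self.name = name
            self.sources = sources
            self.build_directory = os.path.join(root_dir, "build")
            self._lib = None
            # No eager `tmp = self.lib` here: nothing is compiled until `lib` is read.

        @property
        def lib(self):
            # Compile (or reload the cached build of) the extension on first access only.
            if self._lib is None:
                os.makedirs(self.build_directory, exist_ok=True)
                self._lib = load(name=self.name, sources=self.sources,
                                 build_directory=self.build_directory)
            return self._lib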
6 changes: 6 additions & 0 deletions optimum/quanto/library/ops.py
@@ -18,6 +18,12 @@
 import torch
 
 
+if torch.cuda.is_available():
+    from .extensions.cuda import ext
+
+    # This is required to be able to access `torch.ops.quanto_ext.*` members defined in C++ through `TORCH_LIBRARY`.
+    _ = ext.lib
+
 # This file contains the definitions of all operations under torch.ops.quanto
 
 
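This guard is the heart of the fix: on a CPU-only system, importing the package no longer triggers a CUDA extension build, while on a CUDA machine the C++ ops are registered exactly once, at import. Downstream modules (`qlinear.py`, `packed.py`, the test below) can then drop their own side-effect imports. A hypothetical consumer-side sketch — `gemm` is an illustrative op name, not necessarily one the extension actually defines:

    import torch
    import optimum.quanto  # pulls in ops.py, registering the C++ ops on CUDA systems

    def fused_or_fallback(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # Dispatch to the CUDA op only if it was registered at import time;
        # on CPU-only systems the namespace lookup fails and we take the eager path.
        if a.is_cuda and hasattr(torch.ops.quanto_ext, "gemm"):
            return torch.ops.quanto_ext.gemm(a, b)
        return a @ b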
3 changes: 0 additions & 3 deletions optimum/quanto/nn/qlinear.py
@@ -16,9 +16,6 @@
 
 import torch
 
-# This is required to be able to access `torch.ops.quanto_ext.*` members defined in C++ through `TORCH_LIBRARY`.
-from optimum.quanto.library.extensions.cuda import ext  # noqa: F401
-
 from ..tensor import Optimizer, QBytesTensor, qtype
 from ..tensor.qbits.awq.qbits import AWQBitsTensor
 from ..tensor.qbits.tinygemm.qbits import TinyGemmQBitsTensor
1 change: 0 additions & 1 deletion optimum/quanto/tensor/qbytes.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 
-import torch
 from torch.autograd import Function
 
 from .qtensor import QTensor
3 changes: 0 additions & 3 deletions optimum/quanto/tensor/weights/marlin/packed.py
@@ -17,9 +17,6 @@
 import torch
 from torch.utils import _pytree as pytree
 
-# This is required to be able to access `torch.ops.quanto_ext.*` members defined in C++ through `TORCH_LIBRARY`.
-from optimum.quanto.library.extensions.cuda import ext  # noqa: F401
-
 
 def pack_fp8_as_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
     """
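For reference, one plausible pure-PyTorch body for `pack_fp8_as_int32`, reinterpreting each fp8 element as its raw byte and packing four consecutive rows into one little-endian int32 word; this is a sketch consistent with the signature above, not necessarily the library's exact layout:

    import torch

    def pack_fp8_as_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
        # View the 1-byte fp8 values as raw bytes, widen to int32, then OR groups
        # of four consecutive rows into one word (row 4g+i lands at bits 8i..8i+7).
        assert fp8_tensor.dtype == torch.float8_e4m3fn
        assert fp8_tensor.shape[0] % 4 == 0, "rows must pack evenly into int32 words"
        b = fp8_tensor.view(torch.uint8).to(torch.int32)
        return b[0::4] | (b[1::4] << 8) | (b[2::4] << 16) | (b[3::4] << 24)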
3 changes: 0 additions & 3 deletions test/library/test_mm.py
@@ -103,9 +103,6 @@ def test_gemm_fp16_int4(batch_size, tokens, in_features, out_features):
 @pytest.mark.parametrize("in_features, out_features", [(256, 1024), (512, 2048)])
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16], ids=["bf16", "fp16"])
 def test_fp8_marlin(tokens, in_features, out_features, dtype):
-    # This is required to be able to access `torch.ops.quanto_ext.*` members defined in C++ through `TORCH_LIBRARY`.
-    from optimum.quanto.library.extensions.cuda import ext  # noqa: F401
-
     device = torch.device("cuda")
     input_shape = (tokens, in_features)
     inputs = torch.rand(input_shape, dtype=dtype, device=device)
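With the extension now loaded (or skipped) at package import, the test needs no per-test import of the CUDA module; a skip marker is the natural guard on machines without CUDA. An illustrative standalone version, not the file's actual decorator set:

    import pytest
    import torch

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required")
    def test_fp8_marlin_smoke():
        # Minimal smoke shape; the real test parametrizes tokens, features and dtype.
        inputs = torch.rand((16, 256), dtype=torch.float16, device="cuda")
        assert inputs.is_cuda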
