From c23f15503e72cbc79d0d5fce8a4288af5eb2d3dd Mon Sep 17 00:00:00 2001
From: David Corvoysier
Date: Fri, 20 Sep 2024 07:43:32 +0000
Subject: [PATCH] refactor(marlin): prepare the introduction of int4 kernel

---
 optimum/quanto/tensor/weights/marlin/__init__.py         | 2 +-
 optimum/quanto/tensor/weights/marlin/fp8/__init__.py     | 1 +
 optimum/quanto/tensor/weights/marlin/{ => fp8}/packed.py | 0
 .../tensor/weights/marlin/{marlin.py => fp8/qbits.py}    | 9 ++++++---
 4 files changed, 8 insertions(+), 4 deletions(-)
 create mode 100644 optimum/quanto/tensor/weights/marlin/fp8/__init__.py
 rename optimum/quanto/tensor/weights/marlin/{ => fp8}/packed.py (100%)
 rename optimum/quanto/tensor/weights/marlin/{marlin.py => fp8/qbits.py} (97%)

diff --git a/optimum/quanto/tensor/weights/marlin/__init__.py b/optimum/quanto/tensor/weights/marlin/__init__.py
index 26612d14..e4db126c 100644
--- a/optimum/quanto/tensor/weights/marlin/__init__.py
+++ b/optimum/quanto/tensor/weights/marlin/__init__.py
@@ -1 +1 @@
-from .marlin import MarlinF8QBytesTensor
+from .fp8 import *
diff --git a/optimum/quanto/tensor/weights/marlin/fp8/__init__.py b/optimum/quanto/tensor/weights/marlin/fp8/__init__.py
new file mode 100644
index 00000000..af55ec31
--- /dev/null
+++ b/optimum/quanto/tensor/weights/marlin/fp8/__init__.py
@@ -0,0 +1 @@
+from .qbits import *
diff --git a/optimum/quanto/tensor/weights/marlin/packed.py b/optimum/quanto/tensor/weights/marlin/fp8/packed.py
similarity index 100%
rename from optimum/quanto/tensor/weights/marlin/packed.py
rename to optimum/quanto/tensor/weights/marlin/fp8/packed.py
diff --git a/optimum/quanto/tensor/weights/marlin/marlin.py b/optimum/quanto/tensor/weights/marlin/fp8/qbits.py
similarity index 97%
rename from optimum/quanto/tensor/weights/marlin/marlin.py
rename to optimum/quanto/tensor/weights/marlin/fp8/qbits.py
index 82ef9068..666b9721 100644
--- a/optimum/quanto/tensor/weights/marlin/marlin.py
+++ b/optimum/quanto/tensor/weights/marlin/fp8/qbits.py
@@ -16,12 +16,15 @@
 
 import torch
 
-from ...function import QuantizedLinearFunction
-from ...qtype import qfloat8_e4m3fn, qtypes
-from ..qbytes import WeightQBytesTensor
+from ....function import QuantizedLinearFunction
+from ....qtype import qfloat8_e4m3fn, qtypes
+from ...qbytes import WeightQBytesTensor
 from .packed import MarlinF8PackedTensor, get_scale_perms
 
+__all__ = ["MarlinF8QBytesTensor"]
+
+
 class MarlinF8QBytesLinearFunction(QuantizedLinearFunction):
     @staticmethod
     def forward(ctx, input, other, bias=None):
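
A quick way to sanity-check this refactor is an import smoke test: the patch moves the fp8 code into a subpackage but keeps the public name re-exported from the old location. A minimal sketch, assuming optimum-quanto is installed at this commit (the alias Fp8QBits is illustrative, not part of the patch):

    # The public class must remain importable from its historical location,
    # because marlin/__init__.py now re-exports everything from the new
    # fp8 subpackage via "from .fp8 import *".
    from optimum.quanto.tensor.weights.marlin import MarlinF8QBytesTensor

    # The same class is also reachable at its new, kernel-specific path.
    from optimum.quanto.tensor.weights.marlin.fp8 import MarlinF8QBytesTensor as Fp8QBits

    # Both paths should resolve to the very same class object.
    assert MarlinF8QBytesTensor is Fp8QBits

The added __all__ = ["MarlinF8QBytesTensor"] is what makes the chained star imports in the two __init__.py files export only the tensor class, keeping the MarlinF8QBytesLinearFunction helper out of the public namespace.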