From 5fc63e680d5350cfbffdc650b068324671e8c14b Mon Sep 17 00:00:00 2001
From: David Corvoysier
Date: Thu, 3 Oct 2024 10:09:40 +0200
Subject: [PATCH] refactor(library): remove unpack indirection

---
 optimum/quanto/library/__init__.py             |  3 +-
 optimum/quanto/library/extensions/README.md    | 42 ++++++++---
 .../quanto/library/extensions/cpp/__init__.py  |  2 +-
 .../library/extensions/cuda/__init__.py        |  2 +-
 .../quanto/library/extensions/mps/__init__.py  |  2 +-
 optimum/quanto/library/ops.py                  | 70 -------------------
 optimum/quanto/library/python/README.md        | 18 -----
 optimum/quanto/library/python/__init__.py      | 15 ----
 optimum/quanto/library/{python => }/unpack.py  |  5 +-
 test/library/test_unpack.py                    |  9 +--
 10 files changed, 44 insertions(+), 124 deletions(-)
 delete mode 100644 optimum/quanto/library/ops.py
 delete mode 100644 optimum/quanto/library/python/README.md
 delete mode 100644 optimum/quanto/library/python/__init__.py
 rename optimum/quanto/library/{python => }/unpack.py (93%)

diff --git a/optimum/quanto/library/__init__.py b/optimum/quanto/library/__init__.py
index 0f650b5d..d457d139 100644
--- a/optimum/quanto/library/__init__.py
+++ b/optimum/quanto/library/__init__.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 from .extensions import *
-from .ops import *
-from .python import *
 from .qbytes_mm import *
 from .quantize import *
+from .unpack import *
diff --git a/optimum/quanto/library/extensions/README.md b/optimum/quanto/library/extensions/README.md
index f2b59a97..f3107052 100644
--- a/optimum/quanto/library/extensions/README.md
+++ b/optimum/quanto/library/extensions/README.md
@@ -1,8 +1,6 @@
 # Quanto library extensions
 
-This folder contains the implementations of all `quanto_ext::` operations.
-
-This namespace corresponds to the device-specifc optimized implementations of quanto operations.
+This folder contains device-specific `quanto::` operations.
 
 Implementations can be provided as part of:
 
@@ -10,14 +8,42 @@ Implementations can be provided as part of:
 - the CUDA extension under `cuda`,
 - the Metal Performance Shader extension under `mps`.
 
-The operations are defined in `library/ops.py`.
-To provide an implementation for specific device types, use the following syntax:
+To provide a device-specific implementation of an operation that already has a default implementation (such as unpack), use the following syntax:
 
 ```python
-@torch.library.impl("quanto_ext::unpack", ["CPU", "CUDA"])
+@torch.library.impl("quanto::unpack", ["CPU", "CUDA"])
 def unpack(packed: torch.Tensor, bits: int) -> torch.Tensor:
-    return ext().unpack(t, bits)
+    return ext.unpack(packed, bits)
+```
+
+To declare a new device-specific operation, you need to add it to the library:
+
+```python
+torch.library.define(
+    "quanto::gemm_f16i4",
+    "(Tensor input,"
+    " Tensor other,"
+    " Tensor other_scale,"
+    " Tensor other_shift,"
+    " int group_size)"
+    " -> Tensor",
+)
 ```
 
-Please refer to each extension folder to see how to add the actual implementation.
+Then you can provide its implementation:
+
+```python
+@torch.library.impl("quanto::gemm_f16i4", ["CUDA"])
+def gemm_f16i4(
+    input: torch.Tensor,
+    other: torch.Tensor,
+    scales: torch.Tensor,
+    shift: torch.Tensor,
+    group_size: int,
+) -> torch.Tensor:
+    ...
+```
+
+
+Please refer to each extension folder for examples.
 
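For reference, the registration pattern described by the updated README can be sketched end to end in a few lines. This is an illustrative sketch only, not part of the patch: the operator name `quanto::scale_i8` and both kernels are invented for the example; only the `torch.library.define` / `torch.library.impl` usage mirrors what the patch relies on.

```python
import torch

# Hypothetical operator, for illustration only: declare the schema once...
torch.library.define("quanto::scale_i8", "(Tensor self, float scale) -> Tensor")


# ...provide a default implementation built from plain PyTorch ops (works on any device)...
@torch.library.impl("quanto::scale_i8", "default")
def scale_i8_default(t: torch.Tensor, scale: float) -> torch.Tensor:
    return t.to(torch.float32) * scale


# ...and optionally override it for a specific device type. A real extension
# would call into a compiled kernel here instead of plain PyTorch ops.
@torch.library.impl("quanto::scale_i8", ["CUDA"])
def scale_i8_cuda(t: torch.Tensor, scale: float) -> torch.Tensor:
    return t.to(torch.float32) * scale


# Dispatch picks the CUDA kernel for CUDA tensors, the default one otherwise.
x = torch.randint(-128, 127, (8,), dtype=torch.int8)
y = torch.ops.quanto.scale_i8(x, 0.05)
```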
diff --git a/optimum/quanto/library/extensions/cpp/__init__.py b/optimum/quanto/library/extensions/cpp/__init__.py
index 4a0d6c04..f4325494 100644
--- a/optimum/quanto/library/extensions/cpp/__init__.py
+++ b/optimum/quanto/library/extensions/cpp/__init__.py
@@ -30,6 +30,6 @@
 )
 
 
-@torch.library.impl("quanto_ext::unpack", ["CPU"])
+@torch.library.impl("quanto::unpack", ["CPU"])
 def unpack_cpp(t: torch.Tensor, bits: int):
     return ext.lib.unpack(t, bits)
diff --git a/optimum/quanto/library/extensions/cuda/__init__.py b/optimum/quanto/library/extensions/cuda/__init__.py
index d44684db..7113d661 100644
--- a/optimum/quanto/library/extensions/cuda/__init__.py
+++ b/optimum/quanto/library/extensions/cuda/__init__.py
@@ -71,7 +71,7 @@ def get_max_cuda_arch():
 )
 
 
-@torch.library.impl("quanto_ext::unpack", ["CUDA"])
+@torch.library.impl("quanto::unpack", ["CUDA"])
 def unpack_cuda(t: torch.Tensor, bits: int):
     return ext.lib.unpack(t, bits)
 
diff --git a/optimum/quanto/library/extensions/mps/__init__.py b/optimum/quanto/library/extensions/mps/__init__.py
index f2aaa71a..46bee696 100644
--- a/optimum/quanto/library/extensions/mps/__init__.py
+++ b/optimum/quanto/library/extensions/mps/__init__.py
@@ -30,6 +30,6 @@
 )
 
 
-@torch.library.impl("quanto_ext::unpack", "MPS")
+@torch.library.impl("quanto::unpack", "MPS")
 def unpack_mps(t: torch.Tensor, bits: int):
     return ext.lib.unpack(t, bits)
diff --git a/optimum/quanto/library/ops.py b/optimum/quanto/library/ops.py
deleted file mode 100644
index b0051f7a..00000000
--- a/optimum/quanto/library/ops.py
+++ /dev/null
@@ -1,70 +0,0 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import warnings
-from contextlib import contextmanager
-
-import torch
-
-
-# This file contains the definitions of all operations under torch.ops.quanto
-
-
-_ext_enabled = True
-
-
-@contextmanager
-def disable_extensions():
-    """Disable quanto extensions (debug)"""
-    try:
-        global _ext_enabled
-        _ext_enabled = False
-        yield
-    finally:
-        _ext_enabled = True
-
-
-def define(name, schema):
-    """Define a new quanto operation.
-
-    The operation will actually be defined in three libraries:
-    - the top-level quanto library as quanto::<name>,
-    - the quanto python library as quanto_py::<name>,
-    - the quanto extension library as quanto_ext::<name>.
-
-    Only the implementations for the python and extension library need
-    to be provided: the top-level implementation for the operation is
-    provided when calling this method and simply routes the calls towards
-    either the python or extension implementations based on the selected
-    mode.
-    """
-    for libname in ["quanto", "quanto_py", "quanto_ext"]:
-        torch.library.define(f"{libname}::{name}", schema)
-
-    # Provide the inplementation for all dispatch keys in the main library
-    @torch.library.impl(f"quanto::{name}", "default")
-    def impl(*args, **kwargs):
-        if _ext_enabled:
-            try:
-                return getattr(torch.ops.quanto_ext, name)(*args, **kwargs)
-            except Exception as e:
-                if isinstance(e, NotImplementedError):
-                    message = f"No optimized kernel found for quanto::{name}."
-                else:
-                    message = f"An exception was raised while calling the optimized kernel for quanto::{name}: {e}"
-                warnings.warn(message + " Falling back to default implementation.")
-        return getattr(torch.ops.quanto_py, name)(*args, **kwargs)
-
-
-define("unpack", "(Tensor self, int bits) -> Tensor")
diff --git a/optimum/quanto/library/python/README.md b/optimum/quanto/library/python/README.md
deleted file mode 100644
index 8c33eda0..00000000
--- a/optimum/quanto/library/python/README.md
+++ /dev/null
@@ -1,18 +0,0 @@
-# Quanto library python/pytorch operations
-
-This folder contains the implementations of all `quanto_py::` operations.
-
-This namespace corresponds to the default, python-only implementations of quanto operations.
-
-The operations are defined in `library/ops.py`.
-
-To provide an implementation for an operation, use the following syntax:
-
-```python
-@torch.library.impl("quanto_py::unpack", "default")
-def unpack(packed: torch.Tensor, bits: int) -> torch.Tensor:
-    ...
-```
-
-The implementation **must** support all device types. This is true if it
-is a composition of built-in PyTorch operators.
diff --git a/optimum/quanto/library/python/__init__.py b/optimum/quanto/library/python/__init__.py
deleted file mode 100644
index bb4db893..00000000
--- a/optimum/quanto/library/python/__init__.py
+++ /dev/null
@@ -1,15 +0,0 @@
-# Copyright 2024 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from .unpack import *
diff --git a/optimum/quanto/library/python/unpack.py b/optimum/quanto/library/unpack.py
similarity index 93%
rename from optimum/quanto/library/python/unpack.py
rename to optimum/quanto/library/unpack.py
index f671b0b2..74e3e4be 100644
--- a/optimum/quanto/library/python/unpack.py
+++ b/optimum/quanto/library/unpack.py
@@ -15,7 +15,10 @@
 import torch
 
 
-@torch.library.impl("quanto_py::unpack", "default")
+torch.library.define("quanto::unpack", "(Tensor self, int bits) -> Tensor")
+
+
+@torch.library.impl("quanto::unpack", "default")
 def unpack(packed: torch.Tensor, bits: int) -> torch.Tensor:
     """
     Un-Pack int4 / int2 weights (packed in a uint8) into a torch.uint8 tensor
diff --git a/test/library/test_unpack.py b/test/library/test_unpack.py
index 6cf08db3..125de1fc 100644
--- a/test/library/test_unpack.py
+++ b/test/library/test_unpack.py
@@ -12,24 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from contextlib import nullcontext
 
 import pytest
 import torch
 
-from optimum.quanto.library import disable_extensions
 from optimum.quanto.tensor.packed import pack_weights
 
 
 @pytest.mark.parametrize("bits", [2, 4], ids=["int2", "int4"])
 @pytest.mark.parametrize("shape", [(12,), (32, 32)], ids=["vector", "matrix"])
-@pytest.mark.parametrize("use_ext", [True, False], ids=["ext", "no-ext"])
-def test_unpack(bits, shape, use_ext, device):
+def test_unpack(bits, shape, device):
     qmax = 2**bits
     a = torch.randint(0, qmax, shape, dtype=torch.uint8).to(device)
     packed_a = pack_weights(a, bits)
-    context = nullcontext() if use_ext else disable_extensions()
-    with context:
-        unpacked_a = torch.ops.quanto.unpack(packed_a, bits)
+    unpacked_a = torch.ops.quanto.unpack(packed_a, bits)
     assert unpacked_a.dtype == torch.uint8
     assert torch.equal(unpacked_a, a)
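With the indirection gone, callers hit the single `quanto::unpack` operator directly, exactly as the updated test does. Below is a minimal round-trip sketch of that usage, assuming `optimum-quanto` is installed; the import pulls in `optimum.quanto`, which registers the operators as a side effect.

```python
import torch

# Importing this module triggers optimum.quanto's package init, which registers the quanto:: ops.
from optimum.quanto.tensor.packed import pack_weights

bits = 4
a = torch.randint(0, 2**bits, (32, 32), dtype=torch.uint8)
packed = pack_weights(a, bits)

# Dispatch now goes through the torch dispatcher alone: a device-specific kernel
# (CPU/CUDA/MPS extension) is used when one is registered for the tensor's device,
# and the default pure-PyTorch implementation is used otherwise.
unpacked = torch.ops.quanto.unpack(packed, bits)
assert unpacked.dtype == torch.uint8
assert torch.equal(unpacked, a)
```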