refactor(library): remove unpack indirection
dacorvo committed Oct 3, 2024
1 parent 9a76f25 commit 5fc63e6
Showing 10 changed files with 44 additions and 124 deletions.
3 changes: 1 addition & 2 deletions optimum/quanto/library/__init__.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 from .extensions import *
-from .ops import *
-from .python import *
 from .qbytes_mm import *
 from .quantize import *
+from .unpack import *
42 changes: 34 additions & 8 deletions optimum/quanto/library/extensions/README.md
@@ -1,23 +1,49 @@
 # Quanto library extensions
 
-This folder contains the implementations of all `quanto_ext::` operations.
-
-This namespace corresponds to the device-specifc optimized implementations of quanto operations.
+This folder contains device-specific `quanto::` operations.
 
 Implementations can be provided as part of:
 
 - the generic C++ pytorch extension under `cpp`,
 - the CUDA extension under `cuda`,
 - the Metal Performance Shader extension under `mps`.
 
-The operations are defined in `library/ops.py`.
-
-To provide an implementation for specific device types, use the following syntax:
+To provide a device-specific implementation of an operation that already has a default implementation (such as unpack), use the following syntax:
 
 ```python
-@torch.library.impl("quanto_ext::unpack", ["CPU", "CUDA"])
+@torch.library.impl("quanto::unpack", ["CPU", "CUDA"])
 def unpack(packed: torch.Tensor, bits: int) -> torch.Tensor:
-    return ext().unpack(t, bits)
+    return ext.unpack(packed, bits)
 ```
 
+To declare a new device-specific operation, you need to add it to the library:
+
+```python
+torch.library.define(
+    "quanto::gemm_f16i4",
+    "(Tensor input,"
+    " Tensor other,"
+    " Tensor other_scale,"
+    " Tensor other_shift,"
+    " int group_size)"
+    " -> Tensor",
+)
+```
+
-Please refer to each extension folder to see how to add the actual implementation.
+Then you can provide its implementation:
+
+```python
+@torch.library.impl("quanto::gemm_f16i4", ["CUDA"])
+def gemm_f16i4(
+    input: torch.Tensor,
+    other: torch.Tensor,
+    scales: torch.Tensor,
+    shift: torch.Tensor,
+    group_size: int,
+) -> torch.Tensor:
+    ...
+```
+
+Please refer to each extension folder for examples.
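Editor's note: as extra context for the define-then-implement pattern the new README describes, here is a minimal, self-contained sketch using a hypothetical `mylib::mul_pow2` op (the namespace, op name, and semantics are illustrative only, not part of quanto):

```python
import torch

# Declare the op schema once (hypothetical op, mirroring the quanto::gemm_f16i4 pattern).
torch.library.define("mylib::mul_pow2", "(Tensor self, int bits) -> Tensor")


# Default implementation, used on any device without a device-specific override.
@torch.library.impl("mylib::mul_pow2", "default")
def mul_pow2(t: torch.Tensor, bits: int) -> torch.Tensor:
    # Multiply by 2**bits; a stand-in for a real kernel.
    return t * (2**bits)


# Once defined and implemented, the op dispatches through torch.ops.
x = torch.ones(4, dtype=torch.int64)
print(torch.ops.mylib.mul_pow2(x, 2))  # tensor([4, 4, 4, 4])
```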
2 changes: 1 addition & 1 deletion optimum/quanto/library/extensions/cpp/__init__.py
@@ -30,6 +30,6 @@
 )
 
 
-@torch.library.impl("quanto_ext::unpack", ["CPU"])
+@torch.library.impl("quanto::unpack", ["CPU"])
 def unpack_cpp(t: torch.Tensor, bits: int):
     return ext.lib.unpack(t, bits)
2 changes: 1 addition & 1 deletion optimum/quanto/library/extensions/cuda/__init__.py
@@ -71,7 +71,7 @@ def get_max_cuda_arch():
 )
 
 
-@torch.library.impl("quanto_ext::unpack", ["CUDA"])
+@torch.library.impl("quanto::unpack", ["CUDA"])
 def unpack_cuda(t: torch.Tensor, bits: int):
     return ext.lib.unpack(t, bits)
 
2 changes: 1 addition & 1 deletion optimum/quanto/library/extensions/mps/__init__.py
@@ -30,6 +30,6 @@
 )
 
 
-@torch.library.impl("quanto_ext::unpack", "MPS")
+@torch.library.impl("quanto::unpack", "MPS")
 def unpack_mps(t: torch.Tensor, bits: int):
     return ext.lib.unpack(t, bits)
70 changes: 0 additions & 70 deletions optimum/quanto/library/ops.py

This file was deleted.

18 changes: 0 additions & 18 deletions optimum/quanto/library/python/README.md

This file was deleted.

15 changes: 0 additions & 15 deletions optimum/quanto/library/python/__init__.py

This file was deleted.

optimum/quanto/library/unpack.py
@@ -15,7 +15,10 @@
 import torch
 
 
-@torch.library.impl("quanto_py::unpack", "default")
+torch.library.define("quanto::unpack", "(Tensor self, int bits) -> Tensor")
+
+
+@torch.library.impl("quanto::unpack", "default")
 def unpack(packed: torch.Tensor, bits: int) -> torch.Tensor:
     """
     Un-Pack int4 / int2 weights (packed in a uint8) into a torch.uint8 tensor
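Editor's note: the body of the default implementation is collapsed in the diff above. As a rough sketch only (not the file's verbatim contents), a pure-PyTorch unpack of int4/int2 values could look like this, assuming values are packed along the first dimension:

```python
import torch


def unpack_4bit(packed: torch.Tensor) -> torch.Tensor:
    # Each uint8 byte holds two int4 values: low nibble first, then high nibble.
    return torch.cat([packed & 0x0F, (packed >> 4) & 0x0F])


def unpack_2bit(packed: torch.Tensor) -> torch.Tensor:
    # Each uint8 byte holds four int2 values, from least to most significant bits.
    return torch.cat([(packed >> shift) & 0x03 for shift in (0, 2, 4, 6)])
```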
9 changes: 2 additions & 7 deletions test/library/test_unpack.py
@@ -12,24 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from contextlib import nullcontext
-
 import pytest
 import torch
 
-from optimum.quanto.library import disable_extensions
 from optimum.quanto.tensor.packed import pack_weights
 
 
 @pytest.mark.parametrize("bits", [2, 4], ids=["int2", "int4"])
 @pytest.mark.parametrize("shape", [(12,), (32, 32)], ids=["vector", "matrix"])
-@pytest.mark.parametrize("use_ext", [True, False], ids=["ext", "no-ext"])
-def test_unpack(bits, shape, use_ext, device):
+def test_unpack(bits, shape, device):
     qmax = 2**bits
     a = torch.randint(0, qmax, shape, dtype=torch.uint8).to(device)
     packed_a = pack_weights(a, bits)
-    context = nullcontext() if use_ext else disable_extensions()
-    with context:
-        unpacked_a = torch.ops.quanto.unpack(packed_a, bits)
+    unpacked_a = torch.ops.quanto.unpack(packed_a, bits)
     assert unpacked_a.dtype == torch.uint8
     assert torch.equal(unpacked_a, a)
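Editor's note: outside of pytest, the round trip the simplified test exercises looks roughly like the sketch below, assuming that importing `optimum.quanto` is enough to register the `quanto::` ops:

```python
import torch

import optimum.quanto  # noqa: F401  (assumed to register the quanto:: ops on import)
from optimum.quanto.tensor.packed import pack_weights

a = torch.randint(0, 16, (32, 32), dtype=torch.uint8)  # int4 values in uint8 storage
packed = pack_weights(a, 4)  # pack two int4 values per byte
unpacked = torch.ops.quanto.unpack(packed, 4)  # dispatches to extension or default impl
assert torch.equal(unpacked, a)
```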
