-
Notifications
You must be signed in to change notification settings - Fork 60
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor(library): remove unpack indirection
- Loading branch information
Showing
10 changed files
with
44 additions
and
124 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,23 +1,49 @@ | ||
# Quanto library extensions | ||
|
||
This folder contains the implementations of all `quanto_ext::` operations. | ||
|
||
This namespace corresponds to the device-specifc optimized implementations of quanto operations. | ||
This folder contains device-specific `quanto::` operations. | ||
|
||
Implementations can be provided as part of: | ||
|
||
- the generic C++ pytorch extension under `cpp`, | ||
- the CUDA extension under `cuda`, | ||
- the Metal Performance Shader extension under `mps`. | ||
|
||
The operations are defined in `library/ops.py`. | ||
|
||
To provide an implementation for specific device types, use the following syntax: | ||
To provide a device-specific implementation of an operation that already has a default implementation (such as unpack), use the following syntax: | ||
|
||
```python | ||
@torch.library.impl("quanto_ext::unpack", ["CPU", "CUDA"]) | ||
@torch.library.impl("quanto::unpack", ["CPU", "CUDA"]) | ||
def unpack(packed: torch.Tensor, bits: int) -> torch.Tensor: | ||
return ext().unpack(t, bits) | ||
return ext.unpack(t, bits) | ||
``` | ||
|
||
To declare a new device-specific operation, you need to add it to the library: | ||
|
||
```python | ||
torch.library.define( | ||
"quanto::gemm_f16i4", | ||
"(Tensor input," | ||
" Tensor other," | ||
" Tensor other_scale," | ||
" Tensor other_shift," | ||
" int group_size)" | ||
" -> Tensor", | ||
) | ||
``` | ||
|
||
Please refer to each extension folder to see how to add the actual implementation. | ||
Then you can provide its implementation: | ||
|
||
```python | ||
@torch.library.impl("quanto::gemm_f16i4", ["CUDA"]) | ||
def gemm_f16i4( | ||
input: torch.Tensor, | ||
other: torch.Tensor, | ||
scales: torch.Tensor, | ||
shift: torch.Tensor, | ||
group_size: int, | ||
) -> torch.Tensor: | ||
... | ||
``` | ||
|
||
|
||
Please refer to each extension folder for examples. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters