diff --git a/bench/library/benchmark.py b/bench/library/benchmark.py index a3e5159c..bad8cfa6 100644 --- a/bench/library/benchmark.py +++ b/bench/library/benchmark.py @@ -9,6 +9,17 @@ from quanto.library import disable_extensions +def get_quantize_symmetric_bench(src_dtype, dst_dtype, per_axis, device): + a = torch.rand([10240, 10240], dtype=src_dtype).to(device) + scale = torch.full((10240,), 0.5) if per_axis else torch.tensor(0.5) + scale = scale.to(src_dtype).to(device) + + def bench_fn(): + return torch.ops.quanto.quantize_symmetric(a, scale, dst_dtype) + + return bench_fn + + def get_unpack_bench(bits, device): qmax = 2**bits a = torch.randint(0, qmax, [10240, 10240], dtype=torch.uint8).to(device) @@ -69,6 +80,9 @@ def elapsed_time(self, other): GET_BENCH_FUNCTIONS = { + "quantize_symmetric_fp32_int8_per_tensor": lambda device: get_quantize_symmetric_bench( + torch.float32, torch.int8, False, device + ), "unpack_2bit": lambda device: get_unpack_bench(2, device), "unpack_4bit": lambda device: get_unpack_bench(4, device), } @@ -89,7 +103,7 @@ def main(): device = torch.device("cpu") else: device = torch.device(args.device) - all_kernels = ["unpack_2bit", "unpack_4bit"] + all_kernels = GET_BENCH_FUNCTIONS.keys() kernels = all_kernels if args.kernel is None else [args.kernel] for kernel in kernels: get_bench_fn = GET_BENCH_FUNCTIONS[kernel] diff --git a/quanto/library/ext/cpp/__init__.py b/quanto/library/ext/cpp/__init__.py index f04e8bec..adc011fe 100644 --- a/quanto/library/ext/cpp/__init__.py +++ b/quanto/library/ext/cpp/__init__.py @@ -19,6 +19,7 @@ def ext(): _ext = load( name="quanto_cpp", sources=[ + f"{module_path}/quantize.cpp", f"{module_path}/unpack.cpp", f"{module_path}/pybind_module.cpp", ], @@ -27,6 +28,11 @@ def ext(): return _ext +@impl("quanto_ext::quantize_symmetric", ["CPU"]) +def quantize_symmetric_cpp(t: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype): + return ext().quantize_symmetric(t, scale, dtype) + + 
@impl("quanto_ext::unpack", ["CPU", "CUDA"]) def unpack_cpp(t: torch.Tensor, bits: int): return ext().unpack(t, bits) diff --git a/quanto/library/ext/cpp/pybind_module.cpp b/quanto/library/ext/cpp/pybind_module.cpp index ee700e75..3b25bb03 100644 --- a/quanto/library/ext/cpp/pybind_module.cpp +++ b/quanto/library/ext/cpp/pybind_module.cpp @@ -1,7 +1,20 @@ #include <torch/extension.h> +#include "quantize.h" #include "unpack.h" +// !IMPORTANT! Some python objects such as dtype, device, are not mapped to C++ types, +// and need to be explicitly converted using dedicated helpers before calling a C++ method. +// As a consequence, when an operation takes such an object as parameter, instead +// of creating a binding directly to the C++ method, you must create a binding to a +// lambda method that converts the unmapped types and calls the C++ method. +// See the binding of quantize_symmetric for instance. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("quantize_symmetric", + [](const torch::Tensor& t, const torch::Tensor& scale, py::object dtype) { + return quantize_symmetric(t, + scale, + torch::python::detail::py_object_to_dtype(dtype)); + }, "quantize_symmetric"); m.def("unpack", &unpack, "unpack"); } diff --git a/quanto/library/ext/cpp/quantize.cpp b/quanto/library/ext/cpp/quantize.cpp new file mode 100644 index 00000000..dfddfe24 --- /dev/null +++ b/quanto/library/ext/cpp/quantize.cpp @@ -0,0 +1,64 @@ +#include "quantize.h" +#include <cmath> + + +template <typename T> +torch::Tensor quantize_symmetric_per_tensor(const torch::Tensor& input, const torch::Tensor& scale) { + torch::Tensor output = torch::empty_like(input, c10::TensorOptions(c10::kChar).dtype(), LEGACY_CONTIGUOUS_MEMORY_FORMAT); + auto qdata = reinterpret_cast<int8_t*>(output.data_ptr()); + auto numel = input.numel(); + const T* const data = input.data_ptr<T>(); + float float_scale = scale.data_ptr<T>()[0]; + float inv_scale = float_scale == 0 ? 
1.0f : 1.0f / float_scale; + for (const auto i : c10::irange(numel)) { + int64_t qvalue = lrintf(std::nearbyint(data[i] * inv_scale)); + qvalue = std::max<int64_t>(-127LL, std::min<int64_t>(qvalue, 127LL)); + qdata[i] = static_cast<int8_t>(qvalue); + } + return output; +} + + +int get_scale_axis(const torch::Tensor& scale) { + int axis = -1; + auto scale_dims = scale.sizes(); + for (int i = 0; i < scale_dims.size(); ++i) { + if (scale_dims[i] != 1) { + axis = i; + } + } + return axis; +} + + +torch::Tensor quantize_symmetric_char(const torch::Tensor& input, + const torch::Tensor& scale) { + int axis = get_scale_axis(scale); + if (axis == -1) { + auto scale_dtype = scale.dtype(); + if (scale_dtype == at::ScalarType::Float) { + return quantize_symmetric_per_tensor<float>(input, scale); + } + if (scale_dtype == at::ScalarType::Half) { + return quantize_symmetric_per_tensor<at::Half>(input, scale); + } + TORCH_CHECK(false, "Unsupported scale dtype:", scale_dtype) + } + TORCH_CHECK(false, "symmetric per-axis is not supported") +} + + +torch::Tensor quantize_symmetric(const torch::Tensor& input, + const torch::Tensor& scale, + at::ScalarType dtype) { + bool scalar_scale = (scale.sizes().size() == 0); + bool broadcastable_scale = (input.sizes().size() == scale.sizes().size()); + TORCH_CHECK(scalar_scale || broadcastable_scale, + "Quantization scale must be scalar or broadcastable to the base tensor.") + TORCH_CHECK((scale.dtype() == at::ScalarType::Float) || (scale.dtype() == at::ScalarType::Half), + "Quantization scale must be float or float16.") + if (dtype == at::ScalarType::Char) { + return quantize_symmetric_char(input, scale); + } + TORCH_CHECK_NOT_IMPLEMENTED(false, "quantize_symmetric not supported for ", dtype) +} diff --git a/quanto/library/ext/cpp/quantize.h b/quanto/library/ext/cpp/quantize.h new file mode 100644 index 00000000..63d93dca --- /dev/null +++ b/quanto/library/ext/cpp/quantize.h @@ -0,0 +1,5 @@ +#include <torch/extension.h> + +torch::Tensor quantize_symmetric(const torch::Tensor& input, + const torch::Tensor& 
scale, + at::ScalarType dtype);