diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py
index f08ba7aa72..62aadae57c 100644
--- a/test/dtypes/test_affine_quantized.py
+++ b/test/dtypes/test_affine_quantized.py
@@ -8,7 +8,7 @@
     run_tests,
 )
 
-from torchao.dtypes import CutlassInt4PackedLayout, Int4CPULayout, SemiSparseLayout
+from torchao.dtypes import CutlassInt4PackedLayout, Int4XPULayout, Int4CPULayout, SemiSparseLayout
 from torchao.quantization import (
     float8_weight_only,
     int4_weight_only,
@@ -20,6 +20,7 @@ from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
+    TORCH_VERSION_AT_LEAST_2_7,
     is_sm_at_least_89,
 )
 
@@ -46,6 +47,18 @@ def get_quantization_functions(
                 zero_point_domain=ZeroPointDomain.INT,
             )
         )
+    elif device == "xpu" and TORCH_VERSION_AT_LEAST_2_7:
+        base_functions.append(
+            int4_weight_only(group_size=32, layout=Int4XPULayout())
+        )
+        if int4_zp_int:
+            base_functions.append(
+                int4_weight_only(
+                    group_size=32,
+                    layout=Int4XPULayout(),
+                    zero_point_domain=ZeroPointDomain.INT,
+                )
+            )
     else:
         base_functions.append(int4_weight_only(group_size=32))
     if device == "cuda":
diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index 1087db8cf8..7449ea913f 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -18,7 +18,7 @@ from torch._inductor.utils import run_and_get_code
 
 import torchao
-from torchao.dtypes import Int4CPULayout, TensorCoreTiledLayout
+from torchao.dtypes import Int4CPULayout, Int4XPULayout, TensorCoreTiledLayout
 from torchao.dtypes.utils import is_device
 from torchao.quantization import safe_int_mm
 from torchao.quantization.autoquant import (
@@ -139,6 +139,11 @@ def _int4wo_api(mod):
            mod, int4_weight_only(layout=Int4CPULayout()), set_inductor_config=False
        )
        unwrap_tensor_subclass(mod)
+    elif is_device(next(mod.parameters()).device.type, "xpu") and TORCH_VERSION_AT_LEAST_2_7:
+        quantize_(
+            mod, int4_weight_only(layout=Int4XPULayout()), set_inductor_config=False
+        )
+        unwrap_tensor_subclass(mod)
     elif TORCH_VERSION_AT_LEAST_2_4:
         quantize_(mod, int4_weight_only(), set_inductor_config=False)
         if not TORCH_VERSION_AT_LEAST_2_5:
@@ -1079,6 +1084,8 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
         layout_list = []
         if device == "cpu" and TORCH_VERSION_AT_LEAST_2_6:
             layout_list.append(Int4CPULayout())
+        elif device == "xpu" and TORCH_VERSION_AT_LEAST_2_7:
+            layout_list.append(Int4XPULayout())
         else:
             for inner_k_tiles in [4, 2]:
                 layout_list.append(TensorCoreTiledLayout(inner_k_tiles=inner_k_tiles))
diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py
index 102e76cb1a..fb3f549c07 100644
--- a/test/quantization/test_quant_primitives.py
+++ b/test/quantization/test_quant_primitives.py
@@ -33,6 +33,7 @@
     TORCH_VERSION_AT_LEAST_2_4,
     TORCH_VERSION_AT_LEAST_2_5,
     TORCH_VERSION_AT_LEAST_2_6,
+    TORCH_VERSION_AT_LEAST_2_7,
     is_fbcode,
 )
 
@@ -130,7 +131,8 @@ def _groupwise_affine_quantize_tensor_from_qparams(
     )
 
     if TORCH_VERSION_AT_LEAST_2_5:
-        if not (is_device(w.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6):
+        if (not (is_device(w.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6)) \
+            and (not (is_device(w.device.type, "xpu") and TORCH_VERSION_AT_LEAST_2_7)):
             w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)
 
     return w_int4x8
@@ -739,8 +741,9 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self):
         zeros = torch.randint(0, 15, (10, 2), dtype=torch.int32)
         if TORCH_VERSION_AT_LEAST_2_5:
             input_tmp = input
-            if not (
-                is_device(input.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6
+            if (not (
+                is_device(input.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6)
+                and not (is_device(input.device.type, "xpu") and TORCH_VERSION_AT_LEAST_2_7)
             ):
                 input_tmp = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
             w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(
diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py
index 9cbd4cd2a0..4e25a0a505 100644
--- a/torchao/dtypes/__init__.py
+++ b/torchao/dtypes/__init__.py
@@ -16,6 +16,7 @@
     BlockSparseLayout,
     CutlassInt4PackedLayout,
     Int4CPULayout,
+    Int4XPULayout,
     MarlinQQQLayout,
     MarlinQQQTensor,
     MarlinSparseLayout,
@@ -52,4 +53,5 @@
     "MarlinQQQLayout",
     "Int4CPULayout",
     "CutlassInt4PackedLayout",
+    "Int4XPULayout",
 ]
diff --git a/torchao/dtypes/uintx/__init__.py b/torchao/dtypes/uintx/__init__.py
index 7cf375feb4..a2d96e48da 100644
--- a/torchao/dtypes/uintx/__init__.py
+++ b/torchao/dtypes/uintx/__init__.py
@@ -7,6 +7,9 @@
 from .int4_cpu_layout import (
     Int4CPULayout,
 )
+from .int4_xpu_layout import (
+    Int4XPULayout,
+)
 from .marlin_qqq_tensor import (
     MarlinQQQLayout,
     MarlinQQQTensor,
@@ -36,4 +39,5 @@
     "MarlinQQQTensor",
     "to_marlinqqq_quantized_intx",
     "CutlassInt4PackedLayout",
+    "Int4XPULayout",
 ]
diff --git a/torchao/dtypes/uintx/int4_xpu_layout.py b/torchao/dtypes/uintx/int4_xpu_layout.py
new file mode 100644
index 0000000000..65e0bf6f1a
--- /dev/null
+++ b/torchao/dtypes/uintx/int4_xpu_layout.py
@@ -0,0 +1,283 @@
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import torch
+from torch.utils._python_dispatch import return_and_correct_aliasing
+
+from torchao.dtypes.affine_quantized_tensor import register_layout
+from torchao.dtypes.utils import AQTTensorImpl, Layout, is_device
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_7,
+    fill_defaults,
+)
+
+aten = torch.ops.aten
+
+
+@dataclass(frozen=True)
+class Int4XPULayout(Layout):
+    """Int4 weight-only layout for XPU; requires PyTorch 2.7 or later."""
+
+    pass
+
+
+@register_layout(Int4XPULayout)
+class Int4XPUAQTTensorImpl(AQTTensorImpl):
+    """
+    TensorImpl for the int4 XPU layout of affine quantized tensors. It is used by the XPU
+    tinygemm kernels `_weight_int4pack_mm` (float zero points) and
+    `_weight_int4pack_mm_with_scales_and_zeros` (integer zero points).
+    It stores the original tensor of dimension [n][k] (int32 dtype) as a packed 2-d weight
+    of dimension [n][k / 8] (int32 dtype); the unpacked tensor shape is n * k.
+    Note: the scale and zero point may also be packed together here for the tinygemm kernel.
+    Note: technically the layout should only describe the underlying packed weight (an int
+    Tensor), but since the scale and zero_point are packed into the same tensor impl, which
+    is not done in the plain layout, we created a dedicated layout for AQT for now; this
+    could be improved if int4 AQT is split out into a separate tensor subclass.
+    fields:
+      packed_weight (torch.Tensor): the 2-d packed tensor in the int4 XPU layout
+      scale_and_zero (torch.Tensor or List[torch.Tensor]): either the combined
+        scale-and-zero-point tensor used to map between the floating point and quantized
+        tensors (when scale and zero point share a dtype), or a [scale, zeros] list where
+        scale has the same dtype as the unquantized weight and zeros may have a different
+        (integer) dtype
+    """
+
+    def __new__(
+        cls,
+        packed_weight: torch.Tensor,
+        scale_and_zero: Union[torch.Tensor, List[torch.Tensor]],
+        transposed: bool,
+        _layout: Layout,
+    ):
+        kwargs = {}
+        kwargs["device"] = packed_weight.device
+        kwargs["layout"] = (
+            kwargs.get("layout")
+            if kwargs.get("layout", False)
+            else packed_weight.layout
+        )
+        kwargs["dtype"] = packed_weight.dtype
+        kwargs["requires_grad"] = False
+        shape = packed_weight.shape
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(
+        self,
+        packed_weight: torch.Tensor,
+        scale_and_zero: Union[torch.Tensor, List[torch.Tensor]],
+        transposed: bool,
+        _layout: Layout,
+    ):
+        self.packed_weight = packed_weight
+        self.scale_and_zero = scale_and_zero
+        self.transposed = False
+        self._layout = _layout
+
+    def __tensor_flatten__(self):
+        return ["packed_weight", "scale_and_zero"], [self.transposed, self._layout]
+
+    @classmethod
+    def __tensor_unflatten__(
+        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
+    ):
+        packed_weight, scale_and_zero = (
+            tensor_data_dict["packed_weight"],
+            tensor_data_dict["scale_and_zero"],
+        )
+        (
+            transposed,
+            _layout,
+        ) = tensor_attributes
+        return cls(packed_weight, scale_and_zero, transposed, _layout)
+
+    @classmethod
+    def from_plain(
+        cls,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: Optional[torch.Tensor],
+        _layout: Layout,
+    ):
+        assert isinstance(_layout, Int4XPULayout)
+
+        if TORCH_VERSION_AT_LEAST_2_7:
+            assert (
+                int_data.dtype == torch.int32
+            ), "torch.ops.aten._convert_weight_to_int4pack_xpu expects `int32` dtype"
+            packed_weight = torch.ops.aten._convert_weight_to_int4pack_xpu(
+                int_data,
+                1,  # TODO: remove
+            )
+        else:
+            assert False, "int4 packing is not supported on XPU before PyTorch 2.7"
+
+        scale = scale.reshape(int_data.shape[0], -1)
+        zero_point = zero_point.reshape(int_data.shape[0], -1)
+        if scale.dtype == zero_point.dtype:
+            from torchao.quantization.utils import pack_tinygemm_scales_and_zeros
+
+            scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point)
+        else:
+            scale_and_zero = [scale, zero_point]
+        return cls(packed_weight, scale_and_zero, False, _layout)
+
+    def to(self, *args, **kwargs):
+        kwargs = self._get_to_kwargs(*args, **kwargs)
+        device = kwargs["device"]
+        if not is_device(torch.device(self.device).type, device):
+            raise ValueError(
+                f"Int4XPUAQTTensorImpl does not support conversion from {self.device} to {device}"
+            )
+        return self.__class__(
+            self.packed_weight.to(device),
+            self.scale_and_zero.to(device),
+            self.transposed,
+            self._layout,
+        )
+
+    def _apply_fn_to_data(self, fn):
+        return self.__class__(
+            fn(self.packed_weight),
+            fn(self.scale_and_zero),
+            self.transposed,
+            self._layout,
+        )
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        kwargs = {} if kwargs is None else kwargs
+
+        if func is aten.detach.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
+            )
+
+        if func is aten.clone.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
+            )
+
+        if func is aten.t.default:
+            """we don't need to repack the weight; just rely on the external
+            shape change and record the transpose/no-transpose status
+            """
+            transposed = Int4XPUAQTTensorImpl(
+                args[0].packed_weight,
+                args[0].scale_and_zero,
+                not args[0].transposed,
+                args[0]._layout,
+            )
+            return return_and_correct_aliasing(func, args, kwargs, transposed)
+
+        if func is aten.slice.Tensor:
+            self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
+            if dim == 0:
+                int_data, scale, zero_point = self.get_plain()
+                int_data = aten.slice.Tensor(int_data, dim, start, end, step)
+                # this is to handle padding
+                int_data = self._layout.post_process(int_data)
+                sliced = self.from_plain(int_data, scale, zero_point, self._layout)
+                return return_and_correct_aliasing(func, args, kwargs, sliced)
+            elif dim == 1:
+                int_data, scale, zero_point = self.get_plain()
+                assert step == 1, "Only step == 1 is supported in slicing right now"
+                data_len = int_data.shape[dim]
+                scale_len = scale.shape[dim]
+                ratio = data_len / scale_len
+                start_scale = int(start / ratio)
+                end_scale = int(end / ratio)
+
+                int_data = aten.slice.Tensor(int_data, dim, start, end, step)
+                # this is to handle padding
+                int_data = self._layout.post_process(int_data)
+                scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step)
+                zero_point = aten.slice.Tensor(
+                    zero_point, dim, start_scale, end_scale, step
+                )
+                sliced = self.from_plain(int_data, scale, zero_point, self._layout)
+                return sliced
+            else:
+                raise NotImplementedError(
+                    f"Int4XPUAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported"
+                )
+
+        raise NotImplementedError(
+            f"Int4XPUAQTTensorImpl dispatch: attempting to run {func}, this is not supported"
+        )
+
+    __torch_function__ = torch._C._disabled_torch_function_impl
+
+    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        from torchao.quantization.quant_primitives import (
+            ZeroPointDomain,
+            quantize_affine,
+        )
+
+        # A [scale, zeros] list means the zero points are stored as integers.
+        is_integer_zp = isinstance(self.scale_and_zero, list)
+        if is_integer_zp:
+            scale, zero = self.scale_and_zero
+        else:
+            from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros
+
+            scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero)
+
+        cur_shape = self.shape
+        assert len(cur_shape) == 2
+        original_shape = (cur_shape[0], cur_shape[1] * 2)
+        eye_shape = original_shape[1]
+        groupsize = int(original_shape[1] / scale.shape[-2])
+        block_size = (1, groupsize)
+        device = self.device
+        original_dtype = torch.bfloat16
+        target_dtype = torch.int32
+        quant_min = 0
+        quant_max = 15
+        assert len(block_size) == 2 and block_size[0] == 1
+        if is_integer_zp:
+            zero_point_domain = ZeroPointDomain.INT
+            # Dequantize by multiplying the packed weight against an identity matrix,
+            # then re-quantize with quantize_affine to recover the plain int data.
+            dequantized = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros(
+                torch.eye(eye_shape, device=device, dtype=original_dtype),
+                self.packed_weight,
+                groupsize,
+                scale,
+                zero,
+            )
+            dequantized = dequantized.t().contiguous()
+            # TODO: move this to `unpack_tinygemm_scales_and_zeros`?
+            scale = scale.reshape(scale.shape[:-1]).contiguous()
+            zero = zero.reshape(zero.shape[:-1]).contiguous()
+            int_data = quantize_affine(
+                dequantized,
+                block_size,
+                scale,
+                zero,
+                target_dtype,
+                quant_min,
+                quant_max,
+                zero_point_domain,
+            )
+        else:
+            zero_point_domain = ZeroPointDomain.FLOAT
+            dequantized = torch.ops.aten._weight_int4pack_mm(
+                torch.eye(eye_shape, device=device, dtype=original_dtype),
+                self.packed_weight,
+                groupsize,
+                self.scale_and_zero,
+            )
+            dequantized = dequantized.t().contiguous()
+            # TODO: move this to `unpack_tinygemm_scales_and_zeros`?
+            scale = scale.reshape(scale.shape[:-1]).contiguous()
+            zero = zero.reshape(zero.shape[:-1]).contiguous()
+            int_data = quantize_affine(
+                dequantized,
+                block_size,
+                scale,
+                zero,
+                target_dtype,
+                quant_min,
+                quant_max,
+                zero_point_domain,
+            )
+        return int_data, scale, zero
+
+    def get_layout(self) -> Layout:
+        return self._layout
diff --git a/torchao/prototype/hqq/README.md b/torchao/prototype/hqq/README.md
index 1bdbcd96e1..b4f4f2dc5b 100644
--- a/torchao/prototype/hqq/README.md
+++ b/torchao/prototype/hqq/README.md
@@ -9,7 +9,7 @@ The kernel fuses two ops:
 
 Tested and benchmarked for `HQQ` but could theoretically be used for any asymmetric quantization scheme.
 
-> **NOTE**: Benchmark below is only indicative of performance on consumer-grade `Ampere` GPUs (`A6000` specifically). When tested on `H100`, the performance is on par / marginally worse than native / compiled `torch`.
+> **NOTE**: Benchmark below is only indicative of performance on consumer-grade `Ampere` GPUs (`A6000` specifically). When tested on `H100`, the performance is on par / marginally worse than native / compiled `torch`.
 > The intended use is thus for fine-tuning / training models on non-datacenter GPUs (`80 <= compute capability < 90`). If interested in optimizing the kernel for other architectures, please drop a note in the CUDA-MODE Discord channel.
 
 ### Usage
@@ -83,7 +83,7 @@ Initial benchmarking (on `A6000`) demonstrates promising results, scaling well f
 
 - Times are in `ms`, see `benchmarks/benchmark_hqq.py`.
 - `hqq_ref` is the base `HQQ_Linear` [module](https://github.com/mobiusml/hqq/blob/6d50eee4bcdd99cc10716f1297c5b2803d2b6da4/hqq/core/quantize.py#L349) that is unfused (dequantization followed by call to torch.matmul).
-- `tinygemm` calls `torch.ops.aten._weight_int4pack_mm` or `torch.ops.aten._weight_int4pack_mm_for_cpu`. Implementation is a custom HQQLinear layer that wraps the preprocessing necessary for this kernel, adapted from a benchmark script posted by @mobicham from `CUDA-mode` Discord discussions.
+- `tinygemm` calls `torch.ops.aten._weight_int4pack_mm`, `torch.ops.aten._weight_int4pack_mm_for_cpu`, or `torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros` (depending on the zero point dtype). Implementation is a custom HQQLinear layer that wraps the preprocessing necessary for this kernel, adapted from a benchmark script posted by @mobicham from `CUDA-mode` Discord discussions.
 
 GPU details:
 
diff --git a/torchao/prototype/hqq/hqq_tinygemm_linear.py b/torchao/prototype/hqq/hqq_tinygemm_linear.py
index 24c0efbf82..f82a380b37 100644
--- a/torchao/prototype/hqq/hqq_tinygemm_linear.py
+++ b/torchao/prototype/hqq/hqq_tinygemm_linear.py
@@ -12,7 +12,8 @@ from torch import Tensor, nn
 
 from torchao.dtypes.utils import is_device
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, \
+    TORCH_VERSION_AT_LEAST_2_7
 
 
 class HQQLinearTorchWeightOnlyInt4(torch.nn.Module):
@@ -166,6 +167,10 @@ def process_hqq_quants(self, W_q, meta):
             self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
                 W_q_torch, self.inner_k_tiles
             )
+        elif is_device(W_q.device.type, "xpu") and TORCH_VERSION_AT_LEAST_2_7:
+            self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_xpu(
+                W_q_torch, self.inner_k_tiles
+            )
         else:
             self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
                 W_q_torch, self.inner_k_tiles
@@ -205,7 +210,7 @@ def hqq_quants_to_torch_quants(
             .contiguous()
         )
         if TORCH_VERSION_AT_LEAST_2_5:
-            if not is_device(W_q.device.type, "cpu"):
+            if not is_device(W_q.device.type, "cpu") and not is_device(W_q.device.type, "xpu"):
                 W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8)
 
         # group_dequantize_tensor_from_qparams
@@ -242,6 +247,11 @@ def matmul(self, x):
             c = torch.ops.aten._weight_int4pack_mm_for_cpu(
                 x, self.weight_int4pack, self.groupsize, self.scales_and_zeros
             )
+        elif is_device(x.device.type, "xpu") and TORCH_VERSION_AT_LEAST_2_7 \
+                and not isinstance(self.scales_and_zeros, torch.Tensor):
+            c = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros(
+                x, self.weight_int4pack, self.groupsize, self.scales_and_zeros
+            )
         else:
             c = torch.ops.aten._weight_int4pack_mm(
                 x, self.weight_int4pack, self.groupsize, self.scales_and_zeros
diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py
index cb7c8d0481..156c995327 100644
--- a/torchao/quantization/GPTQ.py
+++ b/torchao/quantization/GPTQ.py
@@ -546,6 +546,14 @@ def linear_forward_int4(
             groupsize,
             scales_and_zeros.to(scales_precision),
         ).to(dtype=x.dtype)
+    elif is_device(x.device.type, "xpu") and TORCH_VERSION_AT_LEAST_2_7:
+        c = torch.ops.aten._weight_int4pack_mm_with_scales_and_zeros(
+            x.to(precision),
+            weight_int4pack,
+            groupsize,
+            scales_and_zeros[0],
+            scales_and_zeros[1],
+        ).to(dtype=x.dtype)
     else:
         c = torch.ops.aten._weight_int4pack_mm(
             x.to(precision),
diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py
index 9715d99e08..450d89421b 100644
--- a/torchao/quantization/subclass.py
+++ b/torchao/quantization/subclass.py
@@ -16,7 +16,8 @@
     quant_int8_dynamic_per_token_linear,
     unpack_tinygemm_scales_and_zeros,
 )
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_6, find_multiple
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_6, find_multiple, \
+    TORCH_VERSION_AT_LEAST_2_7
 
 __all__ = [
     "Int8DynamicallyQuantizedLinearWeight",
@@ -466,6 +467,13 @@ def _quantized_op(act_mat, w_qtensor, bias):
                 w_qtensor.groupsize,
                 w_qtensor.scales_and_zeros,
             )
+        elif is_device(act_mat.device.type, "xpu") and TORCH_VERSION_AT_LEAST_2_7:
+            y = aten._weight_int4pack_mm_with_scales_and_zeros(
+                act_mat.contiguous(),
+                w_qtensor.int_data,
+                w_qtensor.groupsize,
+                w_qtensor.scales_and_zeros,
+            )
         else:
             y = aten._weight_int4pack_mm(
                 act_mat.contiguous(),
@@ -622,6 +630,10 @@ def to_qtensor_components(cls, input_float, groupsize=128, inner_k_tiles=8):
             int_data = aten._convert_weight_to_int4pack_for_cpu(
                 input_int4x8, inner_k_tiles
             )
+        elif is_device(input_float.device.type, "xpu") and TORCH_VERSION_AT_LEAST_2_7:
+            int_data = aten._convert_weight_to_int4pack_for_xpu(
+                input_int4x8, inner_k_tiles
+            )
         else:
             int_data = aten._convert_weight_to_int4pack(input_int4x8, inner_k_tiles)
         return int_data, scales_and_zeros, False, groupsize, inner_k_tiles
diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py
index 74c136ad00..98dfe57a73 100644
--- a/torchao/quantization/utils.py
+++ b/torchao/quantization/utils.py
@@ -20,7 +20,8 @@
     dequantize_affine,
     quantize_affine,
 )
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, \
+    TORCH_VERSION_AT_LEAST_2_7
 
 __all__ = [
     "compute_error",
@@ -399,7 +400,8 @@ def groupwise_affine_quantize_tensor_from_qparams(
         zero_point_domain=zero_point_domain,
     )
     if TORCH_VERSION_AT_LEAST_2_5 and w.shape[-1] > 1:
-        if not (is_device(int_data.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6):
+        if (not (is_device(int_data.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6)) \
+            and (not (is_device(int_data.device.type, "xpu") and TORCH_VERSION_AT_LEAST_2_7)):
             int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
     return int_data
 
@@ -419,6 +421,7 @@ def groupwise_affine_dequantize_tensor_from_qparams(
         TORCH_VERSION_AT_LEAST_2_5
         and (w_int4x8.dtype == torch.uint8 or w_int4x8.shape[-1] > 1)
         and not (is_device(w_int4x8.device.type, "cpu") and TORCH_VERSION_AT_LEAST_2_6)
+        and not (is_device(w_int4x8.device.type, "xpu") and TORCH_VERSION_AT_LEAST_2_7)
     ):
         data = w_int4x8.to(torch.int32)
         high_bits = data >> 4
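
A minimal usage sketch of the new `Int4XPULayout` (not part of the patch itself), mirroring the quantization calls added to the tests above; it assumes an Intel GPU exposed as the `xpu` device, PyTorch >= 2.7, and a bfloat16 model:

```python
import torch

from torchao.dtypes import Int4XPULayout
from torchao.quantization import int4_weight_only, quantize_
from torchao.quantization.quant_primitives import ZeroPointDomain

# Toy linear model on the XPU device; int4 weight-only quantization expects bfloat16.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(
    device="xpu", dtype=torch.bfloat16
)

# Weight-only int4 quantization using the XPU packed layout (float zero points).
quantize_(model, int4_weight_only(group_size=32, layout=Int4XPULayout()))

# Integer zero points instead route through the
# _weight_int4pack_mm_with_scales_and_zeros kernel path added in this patch:
# quantize_(
#     model,
#     int4_weight_only(
#         group_size=32,
#         layout=Int4XPULayout(),
#         zero_point_domain=ZeroPointDomain.INT,
#     ),
# )

x = torch.randn(4, 1024, device="xpu", dtype=torch.bfloat16)
with torch.no_grad():
    y = model(x)
```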