From 7da572d738b15c0a2a7d7cb2299475062e4dac0c Mon Sep 17 00:00:00 2001 From: mobicham Date: Mon, 28 Oct 2024 16:00:00 +0000 Subject: [PATCH] fix FP8/INT8 + add detailed tests --- gemlite/core.py | 40 ++- .../gemm_A16fWnO16f_int32packing.py | 1 + .../gemm_splitK_A16fWnO16f_int32packing.py | 1 + .../gemv_A16fWnO16f_int32packing.py | 5 +- .../gemv_revsplitK_A16fWnO16f_int32packing.py | 15 +- tests/test_gemlitelineartriton.py | 315 ++++++++++++++++++ 6 files changed, 357 insertions(+), 20 deletions(-) create mode 100644 tests/test_gemlitelineartriton.py diff --git a/gemlite/core.py b/gemlite/core.py index b0d8122..aea3605 100644 --- a/gemlite/core.py +++ b/gemlite/core.py @@ -175,7 +175,7 @@ def get_closest_m(M): # Triton _GROUP_SIZE_WARNED = False; class GemLiteLinearTriton(torch.nn.Module): - SUPPORTED_BITS_TRITON = [1, 2, 4, 8] + SUPPORTED_BITS_TRITON = [1, 2, 4, 8, 16] SUPPORTED_DTYPES = [DType.FP16, DType.FP8, DType.INT8] def __init__( @@ -196,15 +196,18 @@ def __init__( if in_features % 128 != 0 or out_features % 128 != 0: raise NotImplementedError("Invalid input shapes") + group_size = 1 if (group_size is None) else group_size + if(group_size < 128 and (_GROUP_SIZE_WARNED is False)): warnings.warn("Make sure to enable autotuning for group_size lower than 128: `set_autotune({'GEMV_REVSPLITK':True, 'GEMV':True, 'GEMM_SPLITK':True, 'GEMM':True})`") _GROUP_SIZE_WARNED = True + self.in_features = in_features self.out_features = out_features self.orig_shape = (out_features, in_features) self.W_nbits = W_nbits - self.group_size = group_size if group_size != -1 else in_features + self.group_size = group_size self.unpack_mask = 2**self.W_nbits - 1 self.elements_per_sample = 32 // self.W_nbits self.signature = (in_features, out_features, W_nbits, group_size) @@ -259,6 +262,20 @@ def pack(self, W_q: Tensor, scales: Tensor, zeros: Union[Tensor, int], bias: Uni col += 1 self.W_q = self.W_q.t().contiguous() #row-major contiguous() + + #Bias / device + self.bias = None if (bias is None) else torch.nn.Parameter(bias.to(device=self.W_q.device, dtype=self.compute_dtype)) + self.device = self.W_q.device + + #FP16 x FP16 / FP8 x FP8 / INT8 x INT8 - no meta-data case + if((scales is None) and (zeros is None)): + self.zeros = torch.tensor([[0,]]).cuda() + self.scales = torch.tensor([[1,]]).cuda() + self.W_group_mode = 0 + self.channel_scale_mode = 2 if self.scaled_activations else 0 + return + + #The rest of the use-cases require some kind of meta-data if(scales is not None): assert scales.dtype == self.meta_dtype, "Unsupported scales/zeros dtype. Only FP16 is supported." @@ -270,8 +287,9 @@ def pack(self, W_q: Tensor, scales: Tensor, zeros: Union[Tensor, int], bias: Uni self.W_group_mode = -1 #Symmetric no shift - if(zeros is None): + if(zeros is None and self.group_size > 1): assert self.scales is not None, "Zeros and scales and can't be both None for W_group_mode = 2." 
+ self.zeros = zeros self.W_group_mode = 2 else: #Asymmetric or Symmetric with shift @@ -283,7 +301,7 @@ def pack(self, W_q: Tensor, scales: Tensor, zeros: Union[Tensor, int], bias: Uni self.zeros = zeros.view((self.out_features, -1)).t().contiguous() self.W_group_mode = 3 else: #Integer - self.zeros = int(zeros) + self.zeros = int(zeros) if(zeros is not None) else None if(self.scales is not None): self.W_group_mode = 3 #Symmetric with shift else: @@ -293,7 +311,7 @@ def pack(self, W_q: Tensor, scales: Tensor, zeros: Union[Tensor, int], bias: Uni #channel-wise scaling self.channel_scale_mode = 0 - self.meta_is_chanenlwise = self.scales.numel() == self.out_features + self.meta_is_chanenlwise = False if(self.scales is None) else self.scales.numel() == self.out_features #weight-only if((self.scaled_activations == False) and (self.meta_is_chanenlwise == True)): @@ -309,14 +327,20 @@ def pack(self, W_q: Tensor, scales: Tensor, zeros: Union[Tensor, int], bias: Uni self.channel_scale_mode = 3 self.W_group_mode = 1 if(self.zeros is not None) else 0 #only with fma_mode=False + if(isinstance(self.zeros, int)): #Union[Tensor, int] not supported by custom op + self.zeros = torch.tensor(self.zeros, dtype=torch.int32) + if(self.channel_scale_mode in [1, 3]): assert self.W_group_mode not in [3, 4], "Can't use channel_scale_mode with W_group_mode == 3 or 4." if(self.input_dtype == DType.INT8): assert self.W_group_mode in [1], "Only channel-wise symmetric quantization is supported for INT8 inputs." - self.bias = None if (bias is None) else torch.nn.Parameter(bias.to(device=self.W_q.device, dtype=self.compute_dtype)) - self.device = self.W_q.device + #Dummy values + if(self.zeros is None): + self.zeros = torch.tensor([[0,]]).cuda() + if(self.scales is None): + self.scales = torch.tensor([[1,]]).cuda() #TODO: Register buffers @@ -419,4 +443,4 @@ def forward_manual(self, x: Tensor, matmul_type: str="GEMM") -> Tensor: ################################################################################################################################### ################################################################################################################################### -GemLiteLinear = GemLiteLinearTriton # Triton by default +GemLiteLinear = GemLiteLinearTriton # Triton by default \ No newline at end of file diff --git a/gemlite/triton_kernels/gemm_A16fWnO16f_int32packing.py b/gemlite/triton_kernels/gemm_A16fWnO16f_int32packing.py index 1e5f202..9c2808e 100755 --- a/gemlite/triton_kernels/gemm_A16fWnO16f_int32packing.py +++ b/gemlite/triton_kernels/gemm_A16fWnO16f_int32packing.py @@ -233,6 +233,7 @@ def gemm_A16fWnO16f_int32packing_forward(x: Tensor, W_q: Tensor, scales: Tensor, #assert K == W_q.shape[0] * elements_per_sample, "Invalid Input Shapes" output = torch.empty((M, N), device=W_q.device, dtype=DTYPE_TO_TORCH[output_dtype]) + zeros = zeros.item() if (zeros.numel()==1) else zeros grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),) diff --git a/gemlite/triton_kernels/gemm_splitK_A16fWnO16f_int32packing.py b/gemlite/triton_kernels/gemm_splitK_A16fWnO16f_int32packing.py index d343eb2..7f4aacc 100644 --- a/gemlite/triton_kernels/gemm_splitK_A16fWnO16f_int32packing.py +++ b/gemlite/triton_kernels/gemm_splitK_A16fWnO16f_int32packing.py @@ -268,6 +268,7 @@ def gemm_splitK_A16fWnO16f_int32packing_forward(x: Tensor, W_q: Tensor, scales: #assert K == W_q.shape[0] * elements_per_sample, "Invalid Input Shapes" #assert group_size >= 128, "Only group_size >= 128 
is currently supported" output = torch.empty((M, N), device=W_q.device, dtype=DTYPE_TO_TORCH[output_dtype]) + zeros = zeros.item() if (zeros.numel()==1) else zeros grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), META['SPLIT_K']) diff --git a/gemlite/triton_kernels/gemv_A16fWnO16f_int32packing.py b/gemlite/triton_kernels/gemv_A16fWnO16f_int32packing.py index 869ebc8..35488ac 100755 --- a/gemlite/triton_kernels/gemv_A16fWnO16f_int32packing.py +++ b/gemlite/triton_kernels/gemv_A16fWnO16f_int32packing.py @@ -222,13 +222,10 @@ def gemv_A16fWnO16f_int32packing_forward(x: Tensor, W_q: Tensor, scales: Tensor, #assert K == W_q.shape[0] * elements_per_sample, "Invalid Input Shapes" output = torch.empty((M, N), device=W_q.device, dtype=DTYPE_TO_TORCH[output_dtype]) + zeros = zeros.item() if (zeros.numel()==1) else zeros grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(N, meta['BLOCK_SIZE_N']), triton.cdiv(K, meta['BLOCK_SIZE_K'])) - #faster to do channel-wise like this for this kernel - if(channel_scale_mode == 1 and W_group_mode == 1): - channel_scale_mode, W_group_mode = 0, 3 - gemv_A16fWnO16f_int32packing_kernel[grid]( x, W_q, output, scales, zeros, scales_x, diff --git a/gemlite/triton_kernels/gemv_revsplitK_A16fWnO16f_int32packing.py b/gemlite/triton_kernels/gemv_revsplitK_A16fWnO16f_int32packing.py index c995c1f..6c52f2e 100755 --- a/gemlite/triton_kernels/gemv_revsplitK_A16fWnO16f_int32packing.py +++ b/gemlite/triton_kernels/gemv_revsplitK_A16fWnO16f_int32packing.py @@ -85,10 +85,13 @@ def get_autotune_config(): compute_capability = torch.cuda.get_device_capability(0) def get_default_config(): - #4090: default + # #4090: default config = triton.Config({'BLOCK_SIZE_M':1, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':32, 'A_load_order':2, 'meta_evict_policy':'', 'atomic_mode':'relaxed'}, num_warps=4, num_stages=2, pre_hook=init_to_zero("c_ptr")) + #4090: default + #config = triton.Config({'BLOCK_SIZE_M':1, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':32}, num_warps=4, num_stages=2, pre_hook=init_to_zero("c_ptr")) + if(compute_capability == (8, 0)): #A100 config = triton.Config({'BLOCK_SIZE_M':1, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':16, 'A_load_order':2, 'meta_evict_policy':'', 'atomic_mode':'relaxed'}, num_warps=2, num_stages=1, pre_hook=init_to_zero("c_ptr")) @@ -157,13 +160,12 @@ def gemv_revsplitK_A16fWnO16f_int32packing_kernel( a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak b_ptrs = b_ptr + ((offs_k[:, None] // elements_per_sample) * stride_bk + offs_bn[None, :] * stride_bn) q_shift = ((offs_k % elements_per_sample) * W_nbits).to(tl.int32)[:, None] - #################################################################### #Load meta data first, for two passes k_m = (pid_k * (BLOCK_SIZE_K / group_size)).to(tl.int32) if(W_group_mode >= 2): #[2, 3, 4] - scales = tl.load(scales_ptr + offs_bn[None, :] + k_m * stride_meta_g, eviction_policy=meta_evict_policy) + scales = tl.load(scales_ptr + offs_bn[None, :] * stride_meta_n + k_m * stride_meta_g, eviction_policy=meta_evict_policy) else: scales = None @@ -171,7 +173,7 @@ def gemv_revsplitK_A16fWnO16f_int32packing_kernel( if(zero_is_scalar): zeros = zeros_ptr else: - zeros = tl.load(zeros_ptr + offs_bn[None, :] + k_m * stride_meta_g, eviction_policy=meta_evict_policy) + zeros = tl.load(zeros_ptr + offs_bn[None, :] * stride_meta_n + k_m * stride_meta_g, eviction_policy=meta_evict_policy) else: zeros = None @@ -236,13 +238,10 @@ def 
gemv_revsplitK_A16fWnO16f_int32packing_forward(x: Tensor, W_q: Tensor, scale #assert K == W_q.shape[0] * elements_per_sample, "Invalid Input Shapes" output = torch.empty((M, N), device=W_q.device, dtype=DTYPE_TO_TORCH[output_dtype]) + zeros = zeros.item() if (zeros.numel()==1) else zeros grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(N, meta['BLOCK_SIZE_N']), triton.cdiv(K, meta['BLOCK_SIZE_K'] * 2)) - #faster to do channel-wise like this for this kernel - if(channel_scale_mode == 1 and W_group_mode == 1): - channel_scale_mode, W_group_mode = 0, 3 - gemv_revsplitK_A16fWnO16f_int32packing_kernel[grid]( x, W_q, output, scales, zeros, scales_x, diff --git a/tests/test_gemlitelineartriton.py b/tests/test_gemlitelineartriton.py new file mode 100644 index 0000000..4fe4326 --- /dev/null +++ b/tests/test_gemlitelineartriton.py @@ -0,0 +1,315 @@ +#python -m unittest test_gemlitelineartriton.py + +import unittest +import torch +from gemlite.core import GemLiteLinearTriton, DType, set_autotune + +set_autotune({'GEMV_REVSPLITK':False, 'GEMV':False, 'GEMM_SPLITK':False, 'GEMM':False}, exhaustive=False, use_cuda_graph=False) + +device = 'cuda:0' +matmul_types = ['GEMV_REVSPLITK', 'GEMV', 'GEMM_SPLITK', 'GEMM'] + +def gen_data(in_features, out_features, W_nbits, group_size, dtype=torch.float16): + + W_q = torch.randint(0, 2**W_nbits - 1, (out_features, in_features), device=device).to(torch.uint8) + + shape = (out_features, in_features) + gs = W_q.numel() // group_size + scales = torch.ones((gs, 1), device=device, dtype=dtype) * 0.001 + zeros = torch.zeros((gs, 1), device=device, dtype=dtype) * ((2**W_nbits - 1)//2) + W = ((W_q.reshape([-1, group_size]) - zeros) * scales).to(torch.float8_e4m3fn).to(dtype) + + zeros = torch.mean(W_q.reshape([-1, group_size]).float() - (W / scales).float(), axis=1, keepdim=True).to(dtype) + W = ((W_q.reshape([-1, group_size]).to(dtype) - zeros) * scales) + W = W.reshape(shape) + + return W, W_q, scales, zeros + + +in_features, out_features = 4096, 4096*2 +W_nbits, group_size = 4, in_features #128 +W, W_q, scales, zeros = gen_data(in_features, out_features, W_nbits=W_nbits, group_size=group_size) + +class TestGemLiteLinearTriton(unittest.TestCase): + + def test_fp16xfp16(self): + gemlite_linear = GemLiteLinearTriton(W_nbits=16, + group_size=None, + in_features=in_features, + out_features=out_features, + input_dtype=DType.FP16, + output_dtype=DType.FP16, + scaled_activations=False) + + gemlite_linear.pack(W, None, None, None, fma_mode=False); + + #No weight unpacking / dequant + self.assertTrue(gemlite_linear.W_group_mode == 0) + #No channel-wise scaling + self.assertTrue(gemlite_linear.channel_scale_mode == 0) + + tol = 1e-3 + + x = (torch.randn((1, in_features), dtype=torch.float16, device=device) / 10.) 
+ y_ref = torch.matmul(x.half(), W.T) + for matmul_type in matmul_types: + y_gem = gemlite_linear.forward_manual(x, matmul_type=matmul_type) + err = (y_ref - y_gem).abs().mean().item() + self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol)) + + + def test_fp16xWn_asymmetric(self): + #FP16 x Wn / asymmetric + gemlite_linear = GemLiteLinearTriton(W_nbits, + group_size=group_size, + in_features=in_features, + out_features=out_features, + input_dtype=DType.FP16, + output_dtype=DType.FP16) + + + gemlite_linear.pack(W_q, scales, zeros, None, fma_mode=False); + + #Weights are unpacked() then shifted by 7 + self.assertTrue(gemlite_linear.W_group_mode == 1) + #Since the scales are channel-wise, we perform scaling post K-sum + self.assertTrue(gemlite_linear.channel_scale_mode == 1) + + tol = 1e-3 + + x = torch.randn((1, in_features), dtype=torch.float16, device=device) / 10. + y_ref = torch.matmul(x.half(), W.T) + for matmul_type in matmul_types: + y_gem = gemlite_linear.forward_manual(x, matmul_type=matmul_type) + err = (y_ref - y_gem).abs().mean().item() + self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol)) + + + def test_int8xWn_symmetric_no_activation_scaling(self): + #INT8 x Wn - symmetric / no scaling activation scaling + + gemlite_linear = GemLiteLinearTriton(W_nbits, + group_size=group_size, + in_features=in_features, + out_features=out_features, + input_dtype=DType.INT8, + output_dtype=DType.FP32, + scaled_activations=False) + + + gemlite_linear.pack(W_q, scales=scales, zeros=7, bias=None); + + #Weights are unpacked() then shifted by 7 + self.assertTrue(gemlite_linear.W_group_mode == 1) + #Since the scales are channel-wise, we perform scaling post K-sum + self.assertTrue(gemlite_linear.channel_scale_mode == 1) + + x = (torch.randint(-10, 10, (1, in_features), device=device)).to(torch.int8) + + tol = 1e-3 + + y_ref = torch.matmul(x.half(), (W_q.half() - 7).T) * scales + for matmul_type in matmul_types: + y_gem = gemlite_linear.forward_manual(x, matmul_type=matmul_type) + err = (y_ref - y_gem).abs().mean().item() + self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol)) + + + def test_int8xWn_scaled_activations(self): + #INT8 x Wn - activation scaling only + + gemlite_linear = GemLiteLinearTriton(W_nbits=8, + group_size=in_features, + in_features=in_features, + out_features=out_features, + input_dtype=DType.INT8, + output_dtype=DType.FP32, + scaled_activations=True) + + + gemlite_linear.pack(W_q, scales=None, zeros=7, bias=None); + + #Scaling activations + scales_x = torch.ones((1, 1), dtype=torch.float16, device='cuda:0') * 0.001 + def scaled_activations(x): + return x, scales_x + gemlite_linear.scale_activations = scaled_activations + + + #Weights are unpacked() then shifted by 7 + self.assertTrue(gemlite_linear.W_group_mode == 1) + #Activations only are scaled + self.assertTrue(gemlite_linear.channel_scale_mode == 2) + + tol = 1e-3 + + x = (torch.randint(-10, 10, (1, in_features), device=device)).to(torch.int8) + y_ref = torch.matmul(x.half(), (W_q.half() - 7).T) * scales_x + for matmul_type in matmul_types: + y_gem = gemlite_linear.forward_manual(x, matmul_type=matmul_type) + err = (y_ref - y_gem).abs().mean().item() + self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol)) + + + def test_int8Wn_scaled_weights_scaled_activations(self): + #INT8 x Wn - activation scaling only + + gemlite_linear = GemLiteLinearTriton(W_nbits=8, + group_size=in_features, + in_features=in_features, + out_features=out_features, + input_dtype=DType.INT8, + 
output_dtype=DType.FP32,
+                                             scaled_activations=True)
+
+
+        gemlite_linear.pack(W_q, scales=scales, zeros=7, bias=None);
+
+        #Scaling activations
+        scales_x = torch.ones((1, 1), dtype=torch.float16, device='cuda:0') * 0.001
+        def scaled_activations(x):
+            return x, scales_x
+        gemlite_linear.scale_activations = scaled_activations
+
+
+        #Weights are unpacked() then shifted by 7
+        self.assertTrue(gemlite_linear.W_group_mode == 1)
+        #Both weights and activations are scaled
+        self.assertTrue(gemlite_linear.channel_scale_mode == 3)
+
+        tol = 1e-3
+
+        x = (torch.randint(-10, 10, (1, in_features), device=device)).to(torch.int8)
+        y_ref = torch.matmul(x.half(), (W_q.half() - 7).T) * (scales * scales_x)
+        for matmul_type in matmul_types:
+            y_gem = gemlite_linear.forward_manual(x, matmul_type=matmul_type)
+            err = (y_ref - y_gem).abs().mean().item()
+            self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol))
+
+
+
+    def test_fp8xfp8(self):
+        #FP8 x FP8 - no scaling
+
+        gemlite_linear = GemLiteLinearTriton(W_nbits=8,
+                                             group_size=None,
+                                             in_features=in_features,
+                                             out_features=out_features,
+                                             input_dtype=DType.FP8,
+                                             output_dtype=DType.FP16,
+                                             scaled_activations=False)
+
+
+        gemlite_linear.pack(W.to(torch.float8_e4m3fn), None, None, None, fma_mode=False);
+
+        #No weight unpacking / dequant
+        self.assertTrue(gemlite_linear.W_group_mode == 0)
+        #No channel-wise scaling
+        self.assertTrue(gemlite_linear.channel_scale_mode == 0)
+
+        tol = 5e-3 #needs higher tolerance with fp8
+
+        x = (torch.randn((1, in_features), dtype=torch.float16, device=device) / 10.).to(torch.float8_e4m3fn)
+        y_ref = torch.matmul(x.half(), W.T)
+        for matmul_type in matmul_types:
+            y_gem = gemlite_linear.forward_manual(x, matmul_type=matmul_type)
+            err = (y_ref - y_gem).abs().mean().item()
+            self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol))
+
+
+    def test_fp8xfp8_scaled_weights_scaled_activations(self):
+        #FP8 x FP8 - both activations and weights are scaled
+
+        gemlite_linear = GemLiteLinearTriton(W_nbits=8,
+                                             group_size=in_features,
+                                             in_features=in_features,
+                                             out_features=out_features,
+                                             input_dtype=DType.FP8,
+                                             output_dtype=DType.FP16,
+                                             scaled_activations=True)
+
+        gemlite_linear.pack(W.to(torch.float8_e4m3fn), scales=scales, zeros=None, bias=None);
+
+        #Scaling activations
+        scales_x = torch.ones((1, 1), dtype=torch.float16, device='cuda:0') * 0.1
+        def scaled_activations(x):
+            return x, scales_x
+        gemlite_linear.scale_activations = scaled_activations
+
+
+        #No weight unpacking / dequant
+        self.assertTrue(gemlite_linear.W_group_mode == 0)
+        #Both activations and weights are scaled
+        self.assertTrue(gemlite_linear.channel_scale_mode == 3)
+
+        tol = 5e-3 #needs higher tolerance with fp8
+
+        x = (torch.randn((1, in_features), dtype=torch.float16, device=device) / 10.).to(torch.float8_e4m3fn)
+        y_ref = torch.matmul(x.half(), W.T) * (scales * scales_x)
+        for matmul_type in matmul_types:
+            y_gem = gemlite_linear.forward_manual(x, matmul_type=matmul_type)
+            err = (y_ref - y_gem).abs().mean().item()
+            self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol))
+
+
+    def test_fp8xWn_scaled_activations(self):
+        #FP8 x Wn - asymmetric, with activation scaling
+
+        gemlite_linear = GemLiteLinearTriton(W_nbits,
+                                             group_size=group_size,
+                                             in_features=in_features,
+                                             out_features=out_features,
+                                             input_dtype=DType.FP8,
+                                             output_dtype=DType.FP16,
+                                             scaled_activations=True)
+
+
+        gemlite_linear.pack(W_q, scales, zeros, None, fma_mode=False);
+
+        #weight unpacking and shift
+        self.assertTrue(gemlite_linear.W_group_mode == 1)
+        #activations and weights are scaled post accumulation
+        self.assertTrue(gemlite_linear.channel_scale_mode == 3)
+
+        #Scaling activations
+        scales_x = torch.ones((1, 1), dtype=torch.float16, device='cuda:0') * 0.01
+        def scaled_activations(x):
+            return x, scales_x
+        gemlite_linear.scale_activations = scaled_activations
+
+        tol = 5e-3 #needs higher tolerance with fp8
+
+        x = (torch.randn((1, in_features), dtype=torch.float16, device=device) / 10.).to(torch.float8_e4m3fn)
+        y_ref = torch.matmul(x.half(), W.T) * scales_x
+        for matmul_type in matmul_types:
+            y_gem = gemlite_linear.forward_manual(x, matmul_type=matmul_type)
+            err = (y_ref - y_gem).abs().mean().item()
+            self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol))
+
+
+    def test_fp8xWn_no_activation_scaling(self):
+        #FP8 x Wn - asymmetric, no activation scaling
+
+        gemlite_linear = GemLiteLinearTriton(W_nbits,
+                                             group_size=group_size,
+                                             in_features=in_features,
+                                             out_features=out_features,
+                                             input_dtype=DType.FP8,
+                                             output_dtype=DType.FP16,
+                                             scaled_activations=False)
+
+        gemlite_linear.pack(W_q, scales, zeros, None, fma_mode=False);
+
+        #Weight shift only
+        self.assertTrue(gemlite_linear.W_group_mode == 1)
+        #weight scaling only - post accumulation
+        self.assertTrue(gemlite_linear.channel_scale_mode == 1)
+
+        tol = 5e-3 #needs higher tolerance with fp8
+
+        x = (torch.randn((1, in_features), dtype=torch.float16, device=device) / 10.).to(torch.float8_e4m3fn)
+        y_ref = torch.matmul(x.half(), W.T)
+        for matmul_type in matmul_types:
+            y_gem = gemlite_linear.forward_manual(x, matmul_type=matmul_type)
+            err = (y_ref - y_gem).abs().mean().item()
+            self.assertTrue(err < tol, str(err) + ', expected < ' + str(tol))
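+
+
+#Optional convenience entry point (a minimal sketch, assuming only the standard library):
+#lets the suite be run directly with `python tests/test_gemlitelineartriton.py`,
+#in addition to the `python -m unittest` invocation noted at the top of this file.
+if __name__ == '__main__':
+    unittest.main()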