fix FP8/INT8 + add detailed tests
mobicham committed Oct 28, 2024
1 parent de0759c commit 7da572d
Showing 6 changed files with 357 additions and 20 deletions.
40 changes: 32 additions & 8 deletions gemlite/core.py
@@ -175,7 +175,7 @@ def get_closest_m(M):
# Triton
_GROUP_SIZE_WARNED = False;
class GemLiteLinearTriton(torch.nn.Module):
SUPPORTED_BITS_TRITON = [1, 2, 4, 8]
SUPPORTED_BITS_TRITON = [1, 2, 4, 8, 16]
SUPPORTED_DTYPES = [DType.FP16, DType.FP8, DType.INT8]

def __init__(
@@ -196,15 +196,18 @@ def __init__(
if in_features % 128 != 0 or out_features % 128 != 0:
raise NotImplementedError("Invalid input shapes")

group_size = 1 if (group_size is None) else group_size

if(group_size < 128 and (_GROUP_SIZE_WARNED is False)):
warnings.warn("Make sure to enable autotuning for group_size lower than 128: `set_autotune({'GEMV_REVSPLITK':True, 'GEMV':True, 'GEMM_SPLITK':True, 'GEMM':True})`")
_GROUP_SIZE_WARNED = True


self.in_features = in_features
self.out_features = out_features
self.orig_shape = (out_features, in_features)
self.W_nbits = W_nbits
self.group_size = group_size if group_size != -1 else in_features
self.group_size = group_size
self.unpack_mask = 2**self.W_nbits - 1
self.elements_per_sample = 32 // self.W_nbits
self.signature = (in_features, out_features, W_nbits, group_size)
@@ -259,6 +262,20 @@ def pack(self, W_q: Tensor, scales: Tensor, zeros: Union[Tensor, int], bias: Uni
col += 1

self.W_q = self.W_q.t().contiguous() #row-major contiguous()

#Bias / device
self.bias = None if (bias is None) else torch.nn.Parameter(bias.to(device=self.W_q.device, dtype=self.compute_dtype))
self.device = self.W_q.device

#FP16 x FP16 / FP8 x FP8 / INT8 x INT8 - no meta-data case
if((scales is None) and (zeros is None)):
self.zeros = torch.tensor([[0,]]).cuda()
self.scales = torch.tensor([[1,]]).cuda()
self.W_group_mode = 0
self.channel_scale_mode = 2 if self.scaled_activations else 0
return

#The rest of the use-cases require some kind of meta-data

if(scales is not None):
assert scales.dtype == self.meta_dtype, "Unsupported scales/zeros dtype. Only FP16 is supported."
@@ -270,8 +287,9 @@ def pack(self, W_q: Tensor, scales: Tensor, zeros: Union[Tensor, int], bias: Uni
self.W_group_mode = -1

#Symmetric no shift
if(zeros is None):
if(zeros is None and self.group_size > 1):
assert self.scales is not None, "Zeros and scales and can't be both None for W_group_mode = 2."
self.zeros = zeros
self.W_group_mode = 2
else:
#Asymmetric or Symmetric with shift
@@ -283,7 +301,7 @@ def pack(self, W_q: Tensor, scales: Tensor, zeros: Union[Tensor, int], bias: Uni
self.zeros = zeros.view((self.out_features, -1)).t().contiguous()
self.W_group_mode = 3
else: #Integer
self.zeros = int(zeros)
self.zeros = int(zeros) if(zeros is not None) else None
if(self.scales is not None):
self.W_group_mode = 3 #Symmetric with shift
else:
@@ -293,7 +311,7 @@ def pack(self, W_q: Tensor, scales: Tensor, zeros: Union[Tensor, int], bias: Uni

#channel-wise scaling
self.channel_scale_mode = 0
self.meta_is_chanenlwise = self.scales.numel() == self.out_features
self.meta_is_chanenlwise = False if(self.scales is None) else self.scales.numel() == self.out_features

#weight-only
if((self.scaled_activations == False) and (self.meta_is_chanenlwise == True)):
@@ -309,14 +327,20 @@ def pack(self, W_q: Tensor, scales: Tensor, zeros: Union[Tensor, int], bias: Uni
self.channel_scale_mode = 3
self.W_group_mode = 1 if(self.zeros is not None) else 0 #only with fma_mode=False

if(isinstance(self.zeros, int)): #Union[Tensor, int] not supported by custom op
self.zeros = torch.tensor(self.zeros, dtype=torch.int32)

if(self.channel_scale_mode in [1, 3]):
assert self.W_group_mode not in [3, 4], "Can't use channel_scale_mode with W_group_mode == 3 or 4."

if(self.input_dtype == DType.INT8):
assert self.W_group_mode in [1], "Only channel-wise symmetric quantization is supported for INT8 inputs."

self.bias = None if (bias is None) else torch.nn.Parameter(bias.to(device=self.W_q.device, dtype=self.compute_dtype))
self.device = self.W_q.device
#Dummy values
if(self.zeros is None):
self.zeros = torch.tensor([[0,]]).cuda()
if(self.scales is None):
self.scales = torch.tensor([[1,]]).cuda()

#TODO: Register buffers

@@ -419,4 +443,4 @@ def forward_manual(self, x: Tensor, matmul_type: str="GEMM") -> Tensor:

###################################################################################################################################
###################################################################################################################################
GemLiteLinear = GemLiteLinearTriton # Triton by default
GemLiteLinear = GemLiteLinearTriton # Triton by default
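
For reference, the new no-meta-data branch added to `pack()` above (FP16 x FP16 / FP8 x FP8 / INT8 x INT8) boils down to the standalone sketch below. This is just the visible branch logic rewritten as a plain function for illustration, not the GemLite API; the real code also moves the dummy tensors to CUDA.

```python
# Standalone sketch of the no-meta-data branch added to pack() in this commit.
# Not the library API -- only the branch logic shown in the diff, as a plain function.
import torch

def no_metadata_modes(scales, zeros, scaled_activations):
    """When neither scales nor zeros are given (plain FP16/FP8/INT8 matmul),
    pack() now returns early with dummy meta tensors and W_group_mode = 0."""
    if scales is None and zeros is None:
        dummy_zeros  = torch.tensor([[0]])                    # never used for dequantization
        dummy_scales = torch.tensor([[1]])
        W_group_mode = 0                                      # skip group-wise dequantization
        channel_scale_mode = 2 if scaled_activations else 0   # scale activations only
        return dummy_scales, dummy_zeros, W_group_mode, channel_scale_mode
    raise ValueError("scales/zeros given: the regular meta-data path applies")

# e.g. INT8 x INT8 with per-token activation scales:
print(no_metadata_modes(None, None, scaled_activations=True)[2:])   # (0, 2)
```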
1 change: 1 addition & 0 deletions gemlite/triton_kernels/gemm_A16fWnO16f_int32packing.py
@@ -233,6 +233,7 @@ def gemm_A16fWnO16f_int32packing_forward(x: Tensor, W_q: Tensor, scales: Tensor,

#assert K == W_q.shape[0] * elements_per_sample, "Invalid Input Shapes"
output = torch.empty((M, N), device=W_q.device, dtype=DTYPE_TO_TORCH[output_dtype])
zeros = zeros.item() if (zeros.numel()==1) else zeros

grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),)

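The one-line change above (repeated in the other forward wrappers below) undoes the tensor-wrapping that `pack()` now performs for integer zero-points, since the custom-op signature cannot take `Union[Tensor, int]`. A minimal sketch of that round-trip, with an illustrative value:

```python
# Sketch of the scalar zero-point round-trip: pack() stores an int zero-point as a
# 1-element int32 tensor (custom ops don't accept Union[Tensor, int]); each forward
# wrapper converts it back to a Python scalar before launching the Triton kernel.
import torch

zero_point = 8                                             # illustrative value
zeros = torch.tensor(zero_point, dtype=torch.int32)        # as stored by pack()
zeros = zeros.item() if (zeros.numel() == 1) else zeros    # the line added here
assert zeros == 8 and isinstance(zeros, int)               # a plain scalar again at launch time
```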
gemlite/triton_kernels/gemm_splitK_A16fWnO16f_int32packing.py
@@ -268,6 +268,7 @@ def gemm_splitK_A16fWnO16f_int32packing_forward(x: Tensor, W_q: Tensor, scales:
#assert K == W_q.shape[0] * elements_per_sample, "Invalid Input Shapes"
#assert group_size >= 128, "Only group_size >= 128 is currently supported"
output = torch.empty((M, N), device=W_q.device, dtype=DTYPE_TO_TORCH[output_dtype])
zeros = zeros.item() if (zeros.numel()==1) else zeros

grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), META['SPLIT_K'])

5 changes: 1 addition & 4 deletions gemlite/triton_kernels/gemv_A16fWnO16f_int32packing.py
@@ -222,13 +222,10 @@ def gemv_A16fWnO16f_int32packing_forward(x: Tensor, W_q: Tensor, scales: Tensor,

#assert K == W_q.shape[0] * elements_per_sample, "Invalid Input Shapes"
output = torch.empty((M, N), device=W_q.device, dtype=DTYPE_TO_TORCH[output_dtype])
zeros = zeros.item() if (zeros.numel()==1) else zeros

grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(N, meta['BLOCK_SIZE_N']), triton.cdiv(K, meta['BLOCK_SIZE_K']))

#faster to do channel-wise like this for this kernel
if(channel_scale_mode == 1 and W_group_mode == 1):
channel_scale_mode, W_group_mode = 0, 3

gemv_A16fWnO16f_int32packing_kernel[grid](
x, W_q, output,
scales, zeros, scales_x,
15 changes: 7 additions & 8 deletions gemlite/triton_kernels/gemv_revsplitK_A16fWnO16f_int32packing.py
@@ -85,10 +85,13 @@ def get_autotune_config():
compute_capability = torch.cuda.get_device_capability(0)

def get_default_config():
#4090: default
# #4090: default
config = triton.Config({'BLOCK_SIZE_M':1, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':32, 'A_load_order':2, 'meta_evict_policy':'', 'atomic_mode':'relaxed'},
num_warps=4, num_stages=2, pre_hook=init_to_zero("c_ptr"))

#4090: default
#config = triton.Config({'BLOCK_SIZE_M':1, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':32}, num_warps=4, num_stages=2, pre_hook=init_to_zero("c_ptr"))

if(compute_capability == (8, 0)): #A100
config = triton.Config({'BLOCK_SIZE_M':1, 'BLOCK_SIZE_N':256, 'BLOCK_SIZE_K':16, 'A_load_order':2, 'meta_evict_policy':'', 'atomic_mode':'relaxed'},
num_warps=2, num_stages=1, pre_hook=init_to_zero("c_ptr"))
@@ -157,21 +160,20 @@ def gemv_revsplitK_A16fWnO16f_int32packing_kernel(
a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
b_ptrs = b_ptr + ((offs_k[:, None] // elements_per_sample) * stride_bk + offs_bn[None, :] * stride_bn)
q_shift = ((offs_k % elements_per_sample) * W_nbits).to(tl.int32)[:, None]

####################################################################
#Load meta data first, for two passes
k_m = (pid_k * (BLOCK_SIZE_K / group_size)).to(tl.int32)

if(W_group_mode >= 2): #[2, 3, 4]
scales = tl.load(scales_ptr + offs_bn[None, :] + k_m * stride_meta_g, eviction_policy=meta_evict_policy)
scales = tl.load(scales_ptr + offs_bn[None, :] * stride_meta_n + k_m * stride_meta_g, eviction_policy=meta_evict_policy)
else:
scales = None

if(W_group_mode == 1 or W_group_mode >= 3): #[1, 3, 4]
if(zero_is_scalar):
zeros = zeros_ptr
else:
zeros = tl.load(zeros_ptr + offs_bn[None, :] + k_m * stride_meta_g, eviction_policy=meta_evict_policy)
zeros = tl.load(zeros_ptr + offs_bn[None, :] * stride_meta_n + k_m * stride_meta_g, eviction_policy=meta_evict_policy)
else:
zeros = None

@@ -236,13 +238,10 @@ def gemv_revsplitK_A16fWnO16f_int32packing_forward(x: Tensor, W_q: Tensor, scale

#assert K == W_q.shape[0] * elements_per_sample, "Invalid Input Shapes"
output = torch.empty((M, N), device=W_q.device, dtype=DTYPE_TO_TORCH[output_dtype])
zeros = zeros.item() if (zeros.numel()==1) else zeros

grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(N, meta['BLOCK_SIZE_N']), triton.cdiv(K, meta['BLOCK_SIZE_K'] * 2))

#faster to do channel-wise like this for this kernel
if(channel_scale_mode == 1 and W_group_mode == 1):
channel_scale_mode, W_group_mode = 0, 3

gemv_revsplitK_A16fWnO16f_int32packing_kernel[grid](
x, W_q, output,
scales, zeros, scales_x,
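
The kernel-side fix above multiplies the output-column offsets by `stride_meta_n` when loading scales/zeros, instead of implicitly assuming a stride of 1 along N. Below is a toy, CPU-only illustration of the addressing difference (not the Triton kernel itself); the buffer layout is made up so that the column stride is not 1 and the difference is visible.

```python
# Toy illustration of the stride_meta_n fix: pointer-style indexing into a scales
# buffer whose column stride is not 1. Shapes and values are invented for the example.
import torch

out_features, num_groups = 8, 4
buf = torch.arange(out_features * num_groups, dtype=torch.float16)  # underlying storage
scales = buf.view(out_features, num_groups).t()                     # (num_groups, out_features) view
stride_meta_g, stride_meta_n = scales.stride()                      # here: (1, 4) -> column stride != 1

k_m = 2                               # group index for this K-block
offs_bn = torch.arange(out_features)  # output-column offsets

old_load   = buf[offs_bn + k_m * stride_meta_g]                   # old addressing: stride_meta_n missing
fixed_load = buf[offs_bn * stride_meta_n + k_m * stride_meta_g]   # addressing as in the patched kernel

assert torch.equal(fixed_load, scales[k_m])      # reads the correct row of per-group scales
assert not torch.equal(old_load, scales[k_m])    # the old addressing reads the wrong elements
```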
