From 5ebcca036943e7b1929f9be4beeeb0cf8d3fd170 Mon Sep 17 00:00:00 2001 From: mobicham Date: Mon, 9 Dec 2024 15:40:53 +0000 Subject: [PATCH] power of 2 multiple for caching gemms --- gemlite/triton_kernels/gemm_A16fWnO16f_int32packing.py | 5 ++--- .../triton_kernels/gemm_splitK_A16fWnO16f_int32packing.py | 4 ++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/gemlite/triton_kernels/gemm_A16fWnO16f_int32packing.py b/gemlite/triton_kernels/gemm_A16fWnO16f_int32packing.py index d19c306..6028d57 100755 --- a/gemlite/triton_kernels/gemm_A16fWnO16f_int32packing.py +++ b/gemlite/triton_kernels/gemm_A16fWnO16f_int32packing.py @@ -11,12 +11,11 @@ KEYS = ['M', 'N', 'K', 'group_size', 'elements_per_sample'] MATMUL_TYPE = "GEMM" -# code based https://github.com/fpgaminer/GPTQ-triton def kernel_config_pruner(configs, nargs, **kwargs): global KEYS from ..core import GEMLITE_TRITON_CONFIG_CACHE - m = max(2 ** int(math.ceil(math.log2(nargs['M']))), 16) #Need at least 16 here for tl.dot + m = max(2 ** int(math.ceil(math.log2(nargs['M']))), 16) n = nargs['N'] k = nargs['K'] g = nargs['group_size'] @@ -24,7 +23,7 @@ def kernel_config_pruner(configs, nargs, **kwargs): #Check cache if(MATMUL_TYPE in GEMLITE_TRITON_CONFIG_CACHE): - _signature = str(tuple([nargs[i] for i in KEYS])) + _signature = str(tuple([m, n, k, g, e])) if(_signature in GEMLITE_TRITON_CONFIG_CACHE[MATMUL_TYPE]): _config = copy.deepcopy(GEMLITE_TRITON_CONFIG_CACHE[MATMUL_TYPE][_signature]) _num_stages = _config.pop('num_stages') diff --git a/gemlite/triton_kernels/gemm_splitK_A16fWnO16f_int32packing.py b/gemlite/triton_kernels/gemm_splitK_A16fWnO16f_int32packing.py index fc7d5a6..996d089 100644 --- a/gemlite/triton_kernels/gemm_splitK_A16fWnO16f_int32packing.py +++ b/gemlite/triton_kernels/gemm_splitK_A16fWnO16f_int32packing.py @@ -15,7 +15,7 @@ def kernel_config_pruner(configs, nargs, **kwargs): global KEYS from ..core import GEMLITE_TRITON_CONFIG_CACHE - m = nargs['M'] + m = 2 ** int(math.ceil(math.log2(nargs['M']))) n = nargs['N'] k = nargs['K'] g = nargs['group_size'] @@ -23,7 +23,7 @@ def kernel_config_pruner(configs, nargs, **kwargs): #Check cache if(MATMUL_TYPE in GEMLITE_TRITON_CONFIG_CACHE): - _signature = str(tuple([nargs[i] for i in KEYS])) + _signature = str(tuple([m, n, k, g, e])) if(_signature in GEMLITE_TRITON_CONFIG_CACHE[MATMUL_TYPE]): _config = copy.deepcopy(GEMLITE_TRITON_CONFIG_CACHE[MATMUL_TYPE][_signature]) _num_stages = _config.pop('num_stages')