Skip to content

Commit

Permalink
power of 2 multiple for caching gemms
Browse files Browse the repository at this point in the history
  • Loading branch information
mobicham committed Dec 9, 2024
1 parent 245c9bc commit 5ebcca0
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 5 deletions.
5 changes: 2 additions & 3 deletions gemlite/triton_kernels/gemm_A16fWnO16f_int32packing.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,19 @@
KEYS = ['M', 'N', 'K', 'group_size', 'elements_per_sample']
MATMUL_TYPE = "GEMM"

# code based https://github.com/fpgaminer/GPTQ-triton
def kernel_config_pruner(configs, nargs, **kwargs):
global KEYS
from ..core import GEMLITE_TRITON_CONFIG_CACHE

m = max(2 ** int(math.ceil(math.log2(nargs['M']))), 16) #Need at least 16 here for tl.dot
m = max(2 ** int(math.ceil(math.log2(nargs['M']))), 16)
n = nargs['N']
k = nargs['K']
g = nargs['group_size']
e = nargs['elements_per_sample']

#Check cache
if(MATMUL_TYPE in GEMLITE_TRITON_CONFIG_CACHE):
_signature = str(tuple([nargs[i] for i in KEYS]))
_signature = str(tuple([m, n, k, g, e]))
if(_signature in GEMLITE_TRITON_CONFIG_CACHE[MATMUL_TYPE]):
_config = copy.deepcopy(GEMLITE_TRITON_CONFIG_CACHE[MATMUL_TYPE][_signature])
_num_stages = _config.pop('num_stages')
Expand Down
4 changes: 2 additions & 2 deletions gemlite/triton_kernels/gemm_splitK_A16fWnO16f_int32packing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@ def kernel_config_pruner(configs, nargs, **kwargs):
global KEYS
from ..core import GEMLITE_TRITON_CONFIG_CACHE

m = nargs['M']
m = 2 ** int(math.ceil(math.log2(nargs['M'])))
n = nargs['N']
k = nargs['K']
g = nargs['group_size']
e = nargs['elements_per_sample']

#Check cache
if(MATMUL_TYPE in GEMLITE_TRITON_CONFIG_CACHE):
_signature = str(tuple([nargs[i] for i in KEYS]))
_signature = str(tuple([m, n, k, g, e]))
if(_signature in GEMLITE_TRITON_CONFIG_CACHE[MATMUL_TYPE]):
_config = copy.deepcopy(GEMLITE_TRITON_CONFIG_CACHE[MATMUL_TYPE][_signature])
_num_stages = _config.pop('num_stages')
Expand Down

0 comments on commit 5ebcca0

Please sign in to comment.