From fdf34f39e19b1bb31f97798d1999a30f714e7617 Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Thu, 25 Jul 2024 15:11:42 -0500 Subject: [PATCH 1/2] Peel off the last iteration and remove masked load inside the loop And use multiple_of --- scripts/amd/gemm/matmul_kernel.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/scripts/amd/gemm/matmul_kernel.py b/scripts/amd/gemm/matmul_kernel.py index d5f854f3d8a1..6a52fb2c27a0 100644 --- a/scripts/amd/gemm/matmul_kernel.py +++ b/scripts/amd/gemm/matmul_kernel.py @@ -41,16 +41,21 @@ def matmul_kernel( bias = tl.load(bias_ptrs, mask=offs_am < M, other=0.0) acc_dtype = tl.float32 if a_ptr.type.element_ty != tl.int8 else tl.int32 accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype) - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): - if EVEN_K: - a = tl.load(a_ptrs) - b = tl.load(b_ptrs) - else: - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + + max_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K) + + for k in range(0, max_k-1): + a = tl.load(tl.multiple_of(a_ptrs, (1, 16))) + b = tl.load(tl.multiple_of(b_ptrs, (16, 1))) accumulator += tl.dot(a, b) a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk + + k = max_k - 1 + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator += tl.dot(a, b) + c = accumulator.to(c_ptr.type.element_ty) if BIAS: c += bias[:, None] From 8705aac9989f493acab6a6877d6b6d692e2106b2 Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Sat, 27 Jul 2024 15:57:06 -0500 Subject: [PATCH 2/2] Work around the bug in stream-pipeline --- scripts/amd/gemm/matmul_kernel.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/scripts/amd/gemm/matmul_kernel.py b/scripts/amd/gemm/matmul_kernel.py 
index 6a52fb2c27a0..0fe37fc298d6 100644 --- a/scripts/amd/gemm/matmul_kernel.py +++ b/scripts/amd/gemm/matmul_kernel.py @@ -52,8 +52,11 @@ def matmul_kernel( b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk k = max_k - 1 - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + offs_k = k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + a_ptrsX = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak + b_ptrsX = b_ptr + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn + a = tl.load(a_ptrsX, mask=offs_k[None, :] < K, other=0.0) + b = tl.load(b_ptrsX, mask=offs_k[:, None] < K, other=0.0) accumulator += tl.dot(a, b) c = accumulator.to(c_ptr.type.element_ty)