diff --git a/bitblas/base/roller/hint.py b/bitblas/base/roller/hint.py
index 36a1fb7a0..8bb0f624c 100644
--- a/bitblas/base/roller/hint.py
+++ b/bitblas/base/roller/hint.py
@@ -173,6 +173,10 @@ def __init__(self) -> None:
         # Config for block reduction
         self.block_reduction_depth = None  # type: int
 
+        # TL Specific
+        # Split-K factor for SM waste optimization
+        self.split_k_factor: int = 1
+
         # Experimental
         self._raxis_order = []
         self._step = []
diff --git a/bitblas/ops/general_flashatten/tilelang/flashatten.py b/bitblas/ops/general_flashatten/tilelang/flashatten.py
index 9d76c6dfd..d2a5b2857 100644
--- a/bitblas/ops/general_flashatten/tilelang/flashatten.py
+++ b/bitblas/ops/general_flashatten/tilelang/flashatten.py
@@ -60,7 +60,7 @@ def apply_config(
         block_N=64,
         num_stages=2,
         threads=128,
-        enable_rasterization=False,
+        enable_rasterization: bool = False,
     ):
         batch, heads, seq_len, dim = self.batch, self.heads, self.seq_len, self.dim
         trans_K = self.trans_K
@@ -185,7 +185,7 @@ def flashatten_blocked(
     num_stages=2,
     threads=128,
     is_causal=False,
-    enable_rasterization=False,  # Enhance L2 Locality
+    enable_rasterization: bool = False,  # Enhance L2 Locality
 ):
     Q_shape = (batch, seq_len, heads, dim) if not trans_Q else (batch, dim, heads, seq_len)
     K_shape = (batch, seq_len, heads, dim) if not trans_K else (batch, dim, heads, seq_len)
diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py
index 9b60d547d..c80d10fbb 100644
--- a/bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py
+++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py
@@ -47,6 +47,8 @@ def get_roller_configs(self, arch: TileDevice = None, topk: int = 10):
         return self.serialize_hints_to_configs(roller_hints)
 
     def get_hardware_aware_configs(self, arch: TileDevice = None, topk=10):
+        if arch is None:
+            arch = self.arch
         return self.get_roller_configs(arch, topk)
 
     # check if required shared memory cache
diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py
index 0d178e36a..4e56a15f3 100644
--- a/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py
+++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py
@@ -381,7 +381,7 @@ def apply_config(
         warp_col_tiles: Optional[int] = None,
         chunk: Optional[int] = None,
         num_stages: Optional[int] = None,
-        enable_rasterization=False,
+        enable_rasterization: bool = False,
     ):
         assert block_row_warps is not None, "block_row_warps is required"
         assert block_col_warps is not None, "block_col_warps is required"
@@ -578,7 +578,7 @@ def apply_config(
         warp_col_tiles=32,
         chunk=16,
         num_stages=2,
-        enable_rasterization=False,
+        enable_rasterization: bool = False,
     ):
 
         M = self.maybe_dynamic(self.M, "m")
@@ -706,8 +706,8 @@ def main(
                             micro_size_x,
                             micro_size_k,
                     ):
-                        A_shared[i, k, ii, kk] = A[by * (block_M // micro_size_x),
-                                                   ko * (block_K // micro_size_k), ii, kk]
+                        A_shared[i, k, ii, kk] = A[by * (block_M // micro_size_x) + i,
+                                                   ko * (block_K // micro_size_k) + k, ii, kk]
                 else:
                     T.copy(A[by * block_M, ko * block_K], A_shared)
 
@@ -850,7 +850,7 @@ def apply_config(
         warp_col_tiles: Optional[int] = None,
         chunk: Optional[int] = None,
        num_stages: Optional[int] = None,
-        enable_rasterization=False,
+        enable_rasterization: bool = False,
     ):
         assert block_row_warps is not None, "block_row_warps is required"
         assert block_col_warps is not None, "block_col_warps is required"
@@ -1061,7 +1061,7 @@ def apply_config(
         warp_col_tiles=32,
         chunk=16,
         num_stages=2,
-        enable_rasterization=False,
+        enable_rasterization: bool = False,
     ):
 
         M = self.maybe_dynamic(self.M, "m")
@@ -1183,8 +1183,8 @@ def main(
                             micro_size_x,
                             micro_size_k,
                     ):
-                        A_shared[i, k, ii, kk] = A[by * (block_M // micro_size_x),
-                                                   ko * (block_K // micro_size_k), ii, kk]
+                        A_shared[i, k, ii, kk] = A[by * (block_M // micro_size_x) + i,
+                                                   ko * (block_K // micro_size_k) + k, ii, kk]
                 else:
                     T.copy(A[by * block_M, ko * block_K], A_shared)
 
@@ -1264,7 +1264,7 @@ def matmul_blocked(
     accum_dtype="float16",
     num_stages=2,
     threads=128,
-    enable_rasterization=False,  # Enhance L2 Locality
+    enable_rasterization: bool = False,  # Enhance L2 Locality
 ):
     A_shape = (K, M) if trans_A else (M, K)
     B_shape = (N, K) if trans_B else (K, N)
@@ -1316,7 +1316,7 @@ def matmul_macro_tensorcore(
     warp_col_tiles,
     chunk,
     num_stages=2,
-    enable_rasterization=False,
+    enable_rasterization: bool = False,
 ):
     assert trans_A is False, "Currently only support Matrix A is not transposed"
     assert trans_B is True, "Currently only support Matrix B is transposed"
@@ -1445,7 +1445,7 @@ def matmul_macro_tensorcore_weight_propagation_level_ldmatrix(
     warp_col_tiles,
     chunk,
     num_stages=2,
-    enable_rasterization=False,
+    enable_rasterization: bool = False,
 ):
     assert trans_A is False, "Currently only support Matrix A is not transposed"
     assert trans_B is True, "Currently only support Matrix B is transposed"
diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_simt.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_simt.py
index 4bdb26f6d..8fcb53f7f 100644
--- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_simt.py
+++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_simt.py
@@ -66,6 +66,8 @@ def get_roller_configs(self, arch: TileDevice = None, topk: int = 10):
         return self.serialize_hints_to_configs(roller_hints)
 
     def get_hardware_aware_configs(self, arch: TileDevice = None, topk=10):
+        if arch is None:
+            arch = self.arch
         return self.get_roller_configs(arch, topk)
 
     # check if required shared memory cache
diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore.py
index 9f0ac8165..4c91bc144 100644
--- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore.py
+++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore.py
@@ -68,6 +68,8 @@ def get_roller_configs(self, arch: TileDevice = None, topk: int = 10):
         return self.serialize_hints_to_configs(roller_hints)
 
     def get_hardware_aware_configs(self, arch: TileDevice = None, topk=10):
+        if arch is None:
+            arch = self.arch
         return self.get_roller_configs(arch, topk)
 
     # check if required shared memory cache
diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py
index a0242f99b..ebbdafcc6 100644
--- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py
+++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_finegrained.py
@@ -12,7 +12,7 @@ from bitblas.ops.general_matmul.tirscript import (
     matmul_dequantize_select_implementation,)
 from bitblas.tl.mma_macro_generator import (TensorCoreIntrinEmitter,
                                             INT4TensorCoreIntrinEmitter)
-from bitblas.base.arch import TileDevice
+from bitblas.base.arch import TileDevice, is_cuda_arch
 from bitblas.base.roller.hint import Hint
 from bitblas.base.roller.rasterization import NoRasterization
 from bitblas.base.utils import get_roller_hints_from_func
@@ -37,8 +37,9 @@ class MatmulDequantizeFineGrainedScheduler(MatmulDequantizeBaseScheduler):
     chunk: int = 32  # Usually determines the K-dimension split size
 
     # Other Optimization Parameters
-    num_stages: int = 2
+    num_stages: int = 0
     enable_rasterization: bool = False  # Enhance L2 Locality
+    split_k_factor: int = 1  # Split-K factor for SM waste optimization
 
     class TLHint(BaseTLHint):
@@ -76,6 +77,7 @@ def from_roller_hint(cls, hint: Hint):
             tl_hint.chunk = chunk
             tl_hint.num_stages = num_stages
             tl_hint.enable_rasterization = enable_rasterization
+            tl_hint.split_k_factor = hint.split_k_factor
 
             return tl_hint
 
@@ -88,6 +90,7 @@ def get_config_params(self):
                 "chunk": self.chunk,
                 "num_stages": self.num_stages,
                 "enable_rasterization": self.enable_rasterization,
+                "split_k_factor": self.split_k_factor,
             }
 
         def __repr__(self):
@@ -99,7 +102,8 @@ def __repr__(self):
                     f"block_K={self.chunk},"
                    f"threads={self.block_row_warps * self.block_col_warps * warp_size},"
                     f"num_stages={self.num_stages},"
-                    f"enable_rasterization={self.enable_rasterization}"
+                    f"enable_rasterization={self.enable_rasterization},"
+                    f"split_k_factor={self.split_k_factor}"
                     "}")
 
         def get_hint_type(self) -> str:
@@ -108,7 +112,61 @@ def get_hint_type(self) -> str:
 
     def serialize_hints_to_configs(self, hints: List[Hint]):
         configs = []
         for hint in hints:
+            # Extract static shape dimensions for matrix multiplication
+            M, N, K = self.M, self.N, self.K
+
+            # Determine if the shapes are statically defined (not dynamic)
+            is_static_shape = isinstance(M, int) and isinstance(N, int) and isinstance(K, int)
+
+            # Check if the architecture is CUDA-based
+            arch_is_cuda = is_cuda_arch(self.arch)
+
+            # If the architecture is CUDA and we have a static shape, proceed with optimization
+            if arch_is_cuda and is_static_shape:
+                sm_waste_threshold = 5e-2  # Allow at most 5% SM waste
+                num_sms = self.arch.compute_max_core  # Get the maximum number of streaming multiprocessors
+
+                # Compute block sizes based on the configuration
+                block_M = hint.block[0]  # Block size in the M dimension
+                block_N = hint.block[1]  # Block size in the N dimension
+                block_K = hint.rstep[0]  # Block size in the K dimension
+
+                # Calculate the grid dimensions in M and N directions
+                grid_m = M // block_M
+                grid_n = N // block_N
+                total_grids = grid_m * grid_n  # Total number of grids
+
+                # Initialize the split-k factor (used to distribute K-dimension work across blocks)
+                split_k_factor = 1
+
+                # Optimize the split-k factor to minimize SM waste
+                while True:
+                    # Total grids after applying split-k
+                    total_grids_split_k = total_grids * split_k_factor
+
+                    # Calculate the waste in SMs after split-k distribution
+                    waste_sm_splitk = total_grids_split_k - (total_grids_split_k //
+                                                             num_sms) * num_sms
+                    waste_sm_splitk_ratio = waste_sm_splitk / total_grids_split_k
+
+                    # If the SM waste ratio is within the allowed threshold, stop optimization
+                    if waste_sm_splitk_ratio <= sm_waste_threshold:
+                        break
+
+                    # Double the split-k factor and check if the resulting K-dimension size is too large
+                    expand_split_k = split_k_factor * 2
+                    if expand_split_k * block_K >= K:
+                        break
+
+                    # Update the split-k factor for the next iteration
+                    split_k_factor = expand_split_k
+
+                # Note: The optimized split_k_factor can be stored or applied to the config if needed
+                hint.split_k_factor = split_k_factor
+
+            # Convert the hint to a configuration object using the TLHint mapping
             config = self.TLHint.from_roller_hint(hint)
+
             configs.append(config)
         return configs
@@ -123,6 +181,7 @@ def with_default_config(self):
 
         num_stages = getattr(self, "num_stages", 2)
         enable_rasterization = getattr(self, "enable_rasterization", False)
+        split_k_factor = getattr(self, "split_k_factor", 1)
 
         return self.apply_config(
             block_row_warps=block_row_warps,
@@ -132,6 +191,7 @@ def with_default_config(self):
             chunk=chunk,
             num_stages=num_stages,
             enable_rasterization=enable_rasterization,
+            split_k_factor=split_k_factor,
         )
 
     def apply_config(
@@ -142,7 +202,8 @@ def apply_config(
         warp_col_tiles: Optional[int] = None,
         chunk: Optional[int] = None,
        num_stages: Optional[int] = None,
-        enable_rasterization=False,
+        enable_rasterization: bool = False,
+        split_k_factor: Optional[int] = None,
     ):
         assert block_row_warps is not None, "block_row_warps is required"
         assert block_col_warps is not None, "block_col_warps is required"
@@ -204,6 +265,8 @@ def apply_config(
         Qzeros_shape = ((K // group_size), N // storage_nbit * num_bits)
         Bias_shape = (N,)
 
+        splitK = K // split_k_factor
+
         A_shared_shape = (block_M, block_K)
         B_shared_shape = (block_N, block_K // num_elems_per_byte)
         B_dequantize_shared_shape = (block_N, block_K)
@@ -253,7 +316,15 @@ def apply_config(
             chunk=chunk,
         )
 
-        cache_write_required = self.check_require_cache()
+        enable_split_k = split_k_factor > 1
+
+        def check_require_cache():
+            conditions = [False]
+            conditions.append(self.check_require_cache())
+            conditions.append(enable_split_k)
+            return any(conditions)
+
+        cache_write_required = check_require_cache()
 
         @T.prim_func
         def general_dequant_matmul(
@@ -267,7 +338,8 @@ def general_dequant_matmul(
                 Bias: T.Buffer(Bias_shape, in_dtype),
         ):
             with T.Kernel(
-                    T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
+                    T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k_factor,
+                    threads=threads) as (bx, by, bz):
                 A_shared = T.alloc_shared(A_shared_shape, in_dtype)
                 B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
                 B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype)
@@ -296,10 +368,13 @@ def general_dequant_matmul(
 
                 T.clear(C_frag)
 
-                for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
+                for ko in T.Pipelined(T.ceildiv(splitK, block_K), num_stages=num_stages):
 
-                    T.copy(A[by * block_M, ko * block_K], A_shared)
-                    T.copy(B[bx * block_N, ko * block_K // num_elems_per_byte], B_shared)
+                    T.copy(A[by * block_M, bz * splitK + ko * block_K], A_shared)
+                    T.copy(
+                        B[bx * block_N,
+                          bz * (splitK // num_elems_per_byte) + ko * block_K // num_elems_per_byte],
+                        B_shared)
 
                     for i in T.serial(block_N * block_K // num_elems_per_byte //
                                       (threads * local_size_compressed)):
@@ -359,6 +434,7 @@ def general_dequant_matmul(
                         # Matrix multiplication on fragments
                         mma_emitter.mma(A_frag, B_frag, C_frag)
 
+
                 if cache_write_required:
                     # Store the result back to C shared memory
                     mma_emitter.stmatrix(
@@ -377,13 +453,24 @@ def general_dequant_matmul(
                            ] += Bias[bx * block_N + j]
 
                     # Store results from shared memory to global memory
-                    for i, j in T.Parallel(block_M, block_N):
-                        C[by * block_M + i, bx * block_N + j] = C_shared[
-                            i // micro_size_x,
-                            j // micro_size_y,
-                            i % micro_size_x,
-                            j % micro_size_y,
-                        ]
+                    if enable_split_k:
+                        for i, j in T.Parallel(block_M, block_N // 2):
+                            T.atomic_addx2(
+                                C[by * block_M + i, bx * block_N + j * 2], C_shared[
+                                    i // micro_size_x,
+                                    j // micro_size_y,
+                                    i % micro_size_x,
+                                    j % micro_size_y,
+                                ])
+                    else:
+                        for i, j in T.Parallel(block_M, block_N):
+                            C[by * block_M + i, bx * block_N + j] = C_shared[
+                                i // micro_size_x,
+                                j // micro_size_y,
+                                i % micro_size_x,
+                                j % micro_size_y,
+                            ]
+
                 else:
                     # Store the result back to C global memory
                     mma_emitter.stmatrix(
@@ -463,7 +550,8 @@ def apply_config(
         warp_col_tiles: Optional[int] = None,
         chunk: Optional[int] = None,
         num_stages: Optional[int] = None,
-        enable_rasterization=False,
+        enable_rasterization: bool = False,
+        split_k_factor: Optional[int] = None,
     ):
         assert block_row_warps is not None, "block_row_warps is required"
         assert block_col_warps is not None, "block_col_warps is required"
@@ -471,6 +559,8 @@ def apply_config(
         assert warp_col_tiles is not None, "warp_col_tiles is required"
         assert chunk is not None, "chunk is required"
         assert num_stages is not None, "num_stages is required"
+        # unused variable
+        split_k_factor = split_k_factor
 
         M = self.maybe_dynamic(self.M, "m")
         N, K = self.N, self.K
diff --git a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py
index 453a76924..eb1b5c93e 100644
--- a/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py
+++ b/bitblas/ops/general_matmul/tilelang/dequantize/matmul_dequantize_tensorcore_weight_transform.py
@@ -49,7 +49,8 @@ def apply_config(
         warp_col_tiles: Optional[int] = None,
         chunk: Optional[int] = None,
         num_stages: Optional[int] = None,
-        enable_rasterization=False,
+        enable_rasterization: bool = False,
+        split_k_factor: Optional[int] = None,
     ):
         assert block_row_warps is not None, "block_row_warps is required"
         assert block_col_warps is not None, "block_col_warps is required"
@@ -140,7 +141,6 @@ def apply_config(
             micro_size_y,
             micro_size_k // num_elems_per_byte,
         )
-
         C_shared_shape = (
             block_M // micro_size_x,
             block_N // micro_size_y,
@@ -148,6 +148,8 @@ def apply_config(
             micro_size_y,
         )
 
+        shared_scope = "shared"
+
         import_source: Optional[str] = None
         func_name: str = ""
         if fast_decoding is True:
@@ -187,12 +189,21 @@ def apply_config(
             num_elems_per_byte=num_elems_per_byte,
         )
 
+        splitK = K // split_k_factor
+        enable_split_k = split_k_factor > 1
+
+        def check_require_cache():
+            conditions = [False]
+            conditions.append(self.check_require_cache())
+            conditions.append(enable_split_k)
+            return any(conditions)
+
+        cache_write_required = check_require_cache()
+
         vec_load_qb = 16
         if block_N * block_K // num_elems_per_byte // threads < vec_load_qb:
             vec_load_qb = block_N * block_K // num_elems_per_byte // threads
 
-        cache_write_required = self.check_require_cache()
-
         @T.prim_func
         def general_dequant_matmul(
                 A: T.Buffer(A_shape, in_dtype),
@@ -205,10 +216,11 @@ def general_dequant_matmul(
                 Bias: T.Buffer(Bias_shape, in_dtype),
         ):
             with T.Kernel(
-                    T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
-                A_shared = T.alloc_shared(A_shared_shape, in_dtype)
-                B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
-                C_shared = T.alloc_shared(C_shared_shape, out_dtype)
+                    T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k_factor,
+                    threads=threads) as (bx, by, bz):
+                A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
+                B_shared = T.alloc_shared(B_shared_shape, storage_dtype, scope=shared_scope)
+                C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
                 A_frag = T.alloc_local((warp_rows * fragement_size_a), in_dtype)
                 B_frag = T.alloc_local((warp_cols * fragement_size_b // num_elems_per_byte),
@@ -229,7 +241,7 @@ def general_dequant_matmul(
 
                 T.clear(C_frag)
 
-                for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
+                for ko in T.Pipelined(T.ceildiv(splitK, block_K), num_stages=num_stages):
 
                     if is_a_smooth:
                         for i, k, ii, kk in T.Parallel(
@@ -238,28 +250,25 @@ def general_dequant_matmul(
                                 micro_size_x,
                                 micro_size_k,
                         ):
-                            A_shared[i, k, ii, kk] = A[by * (block_M // micro_size_x),
-                                                       ko * (block_K // micro_size_k), ii, kk]
+                            A_shared[i, k, ii,
+                                     kk] = A[by * (block_M // micro_size_x) + i,
+                                             bz * splitK + ko * (block_K // micro_size_k) + k, ii,
+                                             kk]
                     else:
-                        T.copy(A[by * block_M, ko * block_K], A_shared)
-
-                    for i in T.serial(block_N * block_K // num_elems_per_byte //
-                                      (threads * vec_load_qb)):
-                        for v in T.vectorized(0, vec_load_qb):
-                            idx = i * threads * vec_load_qb + tx * vec_load_qb + v
-                            vkk = idx % (micro_size_k // num_elems_per_byte)
-                            vjj = (idx // (micro_size_k // num_elems_per_byte)) % micro_size_y
-                            vk = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y) % (
-                                block_K // micro_size_k)
-                            vj = (idx // (micro_size_k // num_elems_per_byte) // micro_size_y //
-                                  (block_K // micro_size_k)) % (
-                                      block_N // micro_size_y)
-                            B_shared[vj, vk, vjj, vkk] = B[
-                                bx * (block_N // micro_size_y) + vj,
-                                ko * (block_K // micro_size_k) + vk,
-                                vjj,
-                                vkk,
-                            ]
+                        T.copy(A[by * block_M, bz * splitK + ko * block_K], A_shared)
+
+                    for j, k, jj, kk in T.Parallel(
+                            block_N // micro_size_y,
+                            block_K // micro_size_k,
+                            micro_size_y,
+                            (micro_size_k // num_elems_per_byte),
+                    ):
+                        B_shared[j, k, jj, kk] = B[
+                            bx * (block_N // micro_size_y) + j,
+                            bz * (splitK // micro_size_k) + ko * (block_K // micro_size_k) + k,
+                            jj,
+                            kk,
+                        ]
 
                     # Perform the matrix multiplication on tensor core fragments
                     for ki in T.serial(0, (block_K // micro_size_k)):
@@ -336,14 +345,38 @@ def general_dequant_matmul(
                                 j % micro_size_y,
                             ] += Bias[j]
 
-                    # Store results from shared memory to global memory
-                    for i, j in T.Parallel(block_M, block_N):
-                        C[by * block_M + i, bx * block_N + j] = C_shared[
-                            i // micro_size_x,
-                            j // micro_size_y,
-                            i % micro_size_x,
-                            j % micro_size_y,
-                        ]
+                    # Store results from shared memory to global memory
+                    if enable_split_k:
+                        # only for fp16
+                        if DataType(out_dtype).bits == 16:
+                            for i, j in T.Parallel(block_M, block_N // 2):
+                                m, n = by * block_M + i, bx * block_N + j * 2
+                                T.atomic_addx2(
+                                    C[m, n], C_shared[
+                                        i // micro_size_x,
+                                        (j * 2) // micro_size_y,
+                                        i % micro_size_x,
+                                        (j * 2) % micro_size_y,
+                                    ])
+                        else:
+                            for i, j in T.Parallel(block_M, block_N):
+                                m, n = by * block_M + i, bx * block_N + j
+                                C[m, n] = C_shared[
+                                    i // micro_size_x,
+                                    j // micro_size_y,
+                                    i % micro_size_x,
+                                    j % micro_size_y,
+                                ]
+                    else:
+                        for i, j in T.Parallel(block_M, block_N):
+                            m, n = by * block_M + i, bx * block_N + j
+                            C[m, n] = C_shared[
+                                i // micro_size_x,
+                                j // micro_size_y,
+                                i % micro_size_x,
+                                j % micro_size_y,
+                            ]
+
                 else:
                     mma_emitter.stmatrix(
                         C_frag,
@@ -694,7 +727,8 @@ def apply_config(
         warp_col_tiles: Optional[int] = None,
         chunk: Optional[int] = None,
         num_stages: Optional[int] = None,
-        enable_rasterization=False,
+        enable_rasterization: bool = False,
+        split_k_factor: Optional[int] = None,
     ):
         assert block_row_warps is not None, "block_row_warps is required"
         assert block_col_warps is not None, "block_col_warps is required"
@@ -702,6 +736,8 @@ def apply_config(
         assert warp_col_tiles is not None, "warp_col_tiles is required"
         assert chunk is not None, "chunk is required"
         assert num_stages is not None, "num_stages is required"
+        # unused variable
+        split_k_factor = split_k_factor
 
         M = self.maybe_dynamic(self.M, "m")
         N, K = self.N, self.K
@@ -878,8 +914,8 @@ def general_dequant_matmul(
                                 micro_size_x,
                                 micro_size_k,
                         ):
-                            A_shared[i, k, ii, kk] = A[by * (block_M // micro_size_x),
-                                                       ko * (block_K // micro_size_k), ii, kk]
+                            A_shared[i, k, ii, kk] = A[by * (block_M // micro_size_x) + i,
+                                                       ko * (block_K // micro_size_k) + k, ii, kk]
                     else:
                         T.copy(A[by * block_M, ko * block_K], A_shared)
 
diff --git a/testing/python/operators/test_general_matmul_tilelang_impl.py b/testing/python/operators/test_general_matmul_tilelang_impl.py
index 5deaeaf41..e412e2298 100644
--- a/testing/python/operators/test_general_matmul_tilelang_impl.py
+++ b/testing/python/operators/test_general_matmul_tilelang_impl.py
@@ -86,7 +86,7 @@ def assert_matmul_macro_tensorcore_correctness(
     warp_col_tiles=16,
     chunk=32,
     num_stages=2,
-    enable_rasterization=False,
+    enable_rasterization: bool = False,
 ):
     matmul = matmul_macro_tensorcore(
         M=M,
@@ -144,7 +144,7 @@ def assert_tl_matmul_with_ladder_weight_only_transform_correctness(
     warp_col_tiles=16,
     chunk=32,
     num_stages=2,
-    enable_rasterization=False,
+    enable_rasterization: bool = False,
 ):
     matmul = matmul_macro_tensorcore_weight_propagation_level_ldmatrix(
         M=M,
diff --git a/testing/python/operators/test_general_matmul_tilelang_kernel.py b/testing/python/operators/test_general_matmul_tilelang_kernel.py
index 676543304..e89701af8 100644
--- a/testing/python/operators/test_general_matmul_tilelang_kernel.py
+++ b/testing/python/operators/test_general_matmul_tilelang_kernel.py
@@ -90,7 +90,7 @@ def assert_matmul_blocked_apply_config_correctness(
     accum_dtype="float16",
     num_stages=2,
     threads=128,
-    enable_rasterization=False,
+    enable_rasterization: bool = False,
 ):
     matmul = MatmulBlockScheduler(
         M=M,
@@ -196,7 +196,7 @@ def assert_matmul_fine_grained_apply_config_correctness(
     warp_col_tiles=16,
     chunk=32,
     num_stages=2,
-    enable_rasterization=False,
+    enable_rasterization: bool = False,
 ):
 
     matmul = MatmulFineGrainScheduler(
@@ -316,7 +316,7 @@ def assert_matmul_weight_propagation_apply_config_correctness(
     warp_col_tiles=16,
     chunk=32,
     num_stages=2,
-    enable_rasterization=False,
+    enable_rasterization: bool = False,
 ):
 
     matmul = MatmulWeightPropagationScheduler(
@@ -438,7 +438,7 @@ def assert_matmul_int4_fine_grained_apply_config_correctness(
     warp_col_tiles=16,
     chunk=32,
     num_stages=2,
-    enable_rasterization=False,
+    enable_rasterization: bool = False,
 ):
 
     matmul = MatmulINT4FineGrainScheduler(
@@ -566,7 +566,7 @@ def assert_matmul_int4_weight_propagation_apply_config__correctness(
     warp_col_tiles=16,
     chunk=32,
     num_stages=2,
-    enable_rasterization=False,
+    enable_rasterization: bool = False,
 ):
 
     matmul = MatmulINT4WeightPropagationScheduler(
@@ -737,7 +737,7 @@ def assert_matmul_fine_grained_dequant_int4_apply_config_correctness(
     warp_col_tiles=16,
     chunk=32,
     num_stages=2,
-    enable_rasterization=False,
+    enable_rasterization: bool = False,
 ):
     matmul = MatmulINT4DequantizeFineGrainedScheduler(
         M=M,
@@ -941,7 +941,7 @@ def assert_matmul_weight_transform_dequant_int4_apply_config_correctness(
     warp_col_tiles=16,
     chunk=32,
     num_stages=2,
-    enable_rasterization=False,
+    enable_rasterization: bool = False,
 ):
     matmul = MatmulINT4DequantizeWeightPropagationScheduler(
         M=M,
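Reviewer note: the heuristic added to serialize_hints_to_configs above can be read in isolation as the sketch below. The helper name select_split_k_factor and the example sizes are hypothetical; in the scheduler the block sizes come from hint.block / hint.rstep, num_sms from self.arch.compute_max_core, and the result is written back onto the hint as hint.split_k_factor.

def select_split_k_factor(M: int, N: int, K: int,
                          block_M: int, block_N: int, block_K: int,
                          num_sms: int,
                          sm_waste_threshold: float = 5e-2) -> int:
    """Keep doubling split-K while too many launched blocks fall into a
    partial wave across the SMs, stopping once the waste ratio is within
    the threshold or a further split would shrink a K slice below block_K."""
    total_grids = (M // block_M) * (N // block_N)
    split_k_factor = 1
    while True:
        total_grids_split_k = total_grids * split_k_factor
        # Remainder of blocks in the last (partial) wave of num_sms blocks.
        waste_sm_splitk = total_grids_split_k % num_sms
        if waste_sm_splitk / total_grids_split_k <= sm_waste_threshold:
            break
        expand_split_k = split_k_factor * 2
        if expand_split_k * block_K >= K:
            break
        split_k_factor = expand_split_k
    return split_k_factor


# Illustrative only: a skinny GEMM on a 108-SM device ends up with factor 64.
print(select_split_k_factor(M=16, N=4096, K=16384,
                            block_M=16, block_N=64, block_K=32, num_sms=108))

The two break conditions mirror the diff exactly: either the partial-wave waste drops to at most 5% of the launched blocks, or doubling again would make each K slice no larger than a single block_K chunk.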
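A second sketch, covering only the semantics visible in the diff: each bz slice of the new third kernel-grid axis computes a partial GEMM over its splitK range and accumulates it into C, which the kernels implement with T.atomic_addx2 on pairs of 16-bit values (and which presumably relies on C being zero-initialized before launch). The NumPy reference below is illustrative, using a plain row-major (M, K) x (K, N) layout rather than the transposed, quantized B layouts of the actual kernels.

import numpy as np


def splitk_matmul_reference(A: np.ndarray, B: np.ndarray, split_k_factor: int) -> np.ndarray:
    """Illustrative split-K reference: partial products over K slices summed into C."""
    M, K = A.shape
    K2, N = B.shape
    assert K == K2 and K % split_k_factor == 0
    splitK = K // split_k_factor
    C = np.zeros((M, N), dtype=np.float32)
    for bz in range(split_k_factor):  # plays the role of the bz grid axis
        k0 = bz * splitK
        # In the kernels this accumulation is an atomic add into global C.
        C += A[:, k0:k0 + splitK] @ B[k0:k0 + splitK, :]
    return C


A = np.random.rand(16, 256).astype(np.float32)
B = np.random.rand(256, 64).astype(np.float32)
assert np.allclose(splitk_matmul_reference(A, B, split_k_factor=4), A @ B, atol=1e-3)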