[TileLang][Dev] Enhance Layout Inference Pass to infer with complex parallel primitives (#268)

* fix for relax

* lint fix

* save import bitblas time

* bug fix for tl backend

* support input transform_kind

* hint identifier

* annotate hint type for dequantize

* enhance swizzling

* Enhance for hardware aware tuning

* test fix

* remove pad factor

* introduce legalize dynamic pass

* update 3rdparty

* testfix

* test code  commit

* enhance typing and fix test for int4 dequantize gemm

* lint fix
LeiWang1999 authored Dec 16, 2024
1 parent fe8e435 commit f250ec5
Showing 10 changed files with 216 additions and 80 deletions.
4 changes: 4 additions & 0 deletions bitblas/base/roller/hint.py
@@ -173,6 +173,10 @@ def __init__(self) -> None:
# Config for block reduction
self.block_reduction_depth = None # type: int

# TL Specific
# Split-K factor for SM waste optimization
self.split_k_factor: int = 1

# Experimental
self._raxis_order = []
self._step = []
4 changes: 2 additions & 2 deletions bitblas/ops/general_flashatten/tilelang/flashatten.py
@@ -60,7 +60,7 @@ def apply_config(
block_N=64,
num_stages=2,
threads=128,
enable_rasterization=False,
enable_rasterization: bool = False,
):
batch, heads, seq_len, dim = self.batch, self.heads, self.seq_len, self.dim
trans_K = self.trans_K
@@ -185,7 +185,7 @@ def flashatten_blocked(
num_stages=2,
threads=128,
is_causal=False,
enable_rasterization=False, # Enhance L2 Locality
enable_rasterization: bool = False, # Enhance L2 Locality
):
Q_shape = (batch, seq_len, heads, dim) if not trans_Q else (batch, dim, heads, seq_len)
K_shape = (batch, seq_len, heads, dim) if not trans_K else (batch, dim, heads, seq_len)
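The `enable_rasterization` flag annotated above (commented "Enhance L2 Locality") reorders how thread blocks sweep the output grid so that neighbouring blocks reuse the same input tiles while they are still resident in L2. A minimal sketch of the general grouped-rasterization idea, with an illustrative group size and remapping that is not TileLang's actual swizzle:

```python
# Sketch: remap a linear block id to (bm, bn) so that `group_m` consecutive
# blocks share the same bn (output column) before moving on. The group width
# and mapping are illustrative assumptions, not TileLang's implementation.
def rasterize(linear_bid: int, grid_m: int, grid_n: int, group_m: int = 8):
    blocks_per_group = group_m * grid_n
    group = linear_bid // blocks_per_group
    within = linear_bid % blocks_per_group
    first_bm = group * group_m
    rows_in_group = min(group_m, grid_m - first_bm)  # last group may be short
    bm = first_bm + within % rows_in_group
    bn = within // rows_in_group
    return bm, bn

# Every (bm, bn) pair is still produced exactly once, just in an order that
# keeps the working set of A-row tiles bounded by `group_m`.
grid_m, grid_n = 12, 7
order = [rasterize(i, grid_m, grid_n) for i in range(grid_m * grid_n)]
assert sorted(order) == [(m, n) for m in range(grid_m) for n in range(grid_n)]
```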
2 changes: 2 additions & 0 deletions bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py
@@ -47,6 +47,8 @@ def get_roller_configs(self, arch: TileDevice = None, topk: int = 10):
return self.serialize_hints_to_configs(roller_hints)

def get_hardware_aware_configs(self, arch: TileDevice = None, topk=10):
if arch is None:
arch = self.arch
return self.get_roller_configs(arch, topk)

# check if required shared memory cache
22 changes: 11 additions & 11 deletions bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py
@@ -381,7 +381,7 @@ def apply_config(
warp_col_tiles: Optional[int] = None,
chunk: Optional[int] = None,
num_stages: Optional[int] = None,
enable_rasterization=False,
enable_rasterization: bool = False,
):
assert block_row_warps is not None, "block_row_warps is required"
assert block_col_warps is not None, "block_col_warps is required"
@@ -578,7 +578,7 @@ def apply_config(
warp_col_tiles=32,
chunk=16,
num_stages=2,
enable_rasterization=False,
enable_rasterization: bool = False,
):

M = self.maybe_dynamic(self.M, "m")
@@ -706,8 +706,8 @@ def main(
micro_size_x,
micro_size_k,
):
A_shared[i, k, ii, kk] = A[by * (block_M // micro_size_x),
ko * (block_K // micro_size_k), ii, kk]
A_shared[i, k, ii, kk] = A[by * (block_M // micro_size_x) + i,
ko * (block_K // micro_size_k) + k, ii, kk]
else:
T.copy(A[by * block_M, ko * block_K], A_shared)

@@ -850,7 +850,7 @@ def apply_config(
warp_col_tiles: Optional[int] = None,
chunk: Optional[int] = None,
num_stages: Optional[int] = None,
enable_rasterization=False,
enable_rasterization: bool = False,
):
assert block_row_warps is not None, "block_row_warps is required"
assert block_col_warps is not None, "block_col_warps is required"
@@ -1061,7 +1061,7 @@ def apply_config(
warp_col_tiles=32,
chunk=16,
num_stages=2,
enable_rasterization=False,
enable_rasterization: bool = False,
):

M = self.maybe_dynamic(self.M, "m")
@@ -1183,8 +1183,8 @@ def main(
micro_size_x,
micro_size_k,
):
A_shared[i, k, ii, kk] = A[by * (block_M // micro_size_x),
ko * (block_K // micro_size_k), ii, kk]
A_shared[i, k, ii, kk] = A[by * (block_M // micro_size_x) + i,
ko * (block_K // micro_size_k) + k, ii, kk]
else:
T.copy(A[by * block_M, ko * block_K], A_shared)

@@ -1264,7 +1264,7 @@ def matmul_blocked(
accum_dtype="float16",
num_stages=2,
threads=128,
enable_rasterization=False, # Enhance L2 Locality
enable_rasterization: bool = False, # Enhance L2 Locality
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
@@ -1316,7 +1316,7 @@ def matmul_macro_tensorcore(
warp_col_tiles,
chunk,
num_stages=2,
enable_rasterization=False,
enable_rasterization: bool = False,
):
assert trans_A is False, "Currently only support Matrix A is not transposed"
assert trans_B is True, "Currently only support Matrix B is transposed"
@@ -1445,7 +1445,7 @@ def matmul_macro_tensorcore_weight_propagation_level_ldmatrix(
warp_col_tiles,
chunk,
num_stages=2,
enable_rasterization=False,
enable_rasterization: bool = False,
):
assert trans_A is False, "Currently only support Matrix A is not transposed"
assert trans_B is True, "Currently only support Matrix B is transposed"
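The two `main` hunks above fix the shared-memory load for the micro-tiled A layout: the source indices now include the per-tile offsets `+ i` and `+ k`, so each micro-tile of the block is read instead of the block's first tile over and over. A small numpy sketch of the corrected addressing (tile sizes here are illustrative, not the kernel defaults):

```python
import numpy as np

# A stored in a 4D micro-tiled layout: A[m_tile, k_tile, micro_m, micro_k].
micro_size_x, micro_size_k = 16, 16
block_M, block_K = 64, 32
M, K = 256, 128

A = np.arange(M * K, dtype=np.float32).reshape(
    M // micro_size_x, K // micro_size_k, micro_size_x, micro_size_k)

by, ko = 2, 1  # block indices along M and K
A_shared = np.empty(
    (block_M // micro_size_x, block_K // micro_size_k, micro_size_x, micro_size_k),
    dtype=np.float32)

for i in range(block_M // micro_size_x):
    for k in range(block_K // micro_size_k):
        # The fix: "+ i" and "+ k" select each micro-tile inside the block;
        # without them every slot re-reads the block's first tile.
        A_shared[i, k] = A[by * (block_M // micro_size_x) + i,
                           ko * (block_K // micro_size_k) + k]

assert np.array_equal(
    A_shared[1, 1],
    A[by * (block_M // micro_size_x) + 1, ko * (block_K // micro_size_k) + 1])
```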
@@ -66,6 +66,8 @@ def get_roller_configs(self, arch: TileDevice = None, topk: int = 10):
return self.serialize_hints_to_configs(roller_hints)

def get_hardware_aware_configs(self, arch: TileDevice = None, topk=10):
if arch is None:
arch = self.arch
return self.get_roller_configs(arch, topk)

# check if required shared memory cache
@@ -68,6 +68,8 @@ def get_roller_configs(self, arch: TileDevice = None, topk: int = 10):
return self.serialize_hints_to_configs(roller_hints)

def get_hardware_aware_configs(self, arch: TileDevice = None, topk=10):
if arch is None:
arch = self.arch
return self.get_roller_configs(arch, topk)

# check if required shared memory cache
@@ -12,7 +12,7 @@
from bitblas.ops.general_matmul.tirscript import (
matmul_dequantize_select_implementation,)
from bitblas.tl.mma_macro_generator import (TensorCoreIntrinEmitter, INT4TensorCoreIntrinEmitter)
from bitblas.base.arch import TileDevice
from bitblas.base.arch import TileDevice, is_cuda_arch
from bitblas.base.roller.hint import Hint
from bitblas.base.roller.rasterization import NoRasterization
from bitblas.base.utils import get_roller_hints_from_func
@@ -37,8 +37,9 @@ class MatmulDequantizeFineGrainedScheduler(MatmulDequantizeBaseScheduler):
chunk: int = 32 # Usually determines the K-dimension split size

# Other Optimization Parameters
num_stages: int = 2
num_stages: int = 0
enable_rasterization: bool = False # Enhance L2 Locality
split_k_factor: int = 1 # Split-K factor for SM waste optimization

class TLHint(BaseTLHint):

@@ -76,6 +77,7 @@ def from_roller_hint(cls, hint: Hint):
tl_hint.chunk = chunk
tl_hint.num_stages = num_stages
tl_hint.enable_rasterization = enable_rasterization
tl_hint.split_k_factor = hint.split_k_factor

return tl_hint

@@ -88,6 +90,7 @@ def get_config_params(self):
"chunk": self.chunk,
"num_stages": self.num_stages,
"enable_rasterization": self.enable_rasterization,
"split_k_factor": self.split_k_factor,
}

def __repr__(self):
@@ -99,7 +102,8 @@ def __repr__(self):
f"block_K={self.chunk},"
f"threads={self.block_row_warps * self.block_col_warps * warp_size},"
f"num_stages={self.num_stages},"
f"enable_rasterization={self.enable_rasterization}"
f"enable_rasterization={self.enable_rasterization},"
f"split_k_factor={self.split_k_factor}"
"}")

def get_hint_type(self) -> str:
@@ -108,7 +112,61 @@ def get_hint_type(self) -> str:
def serialize_hints_to_configs(self, hints: List[Hint]):
configs = []
for hint in hints:
# Extract static shape dimensions for matrix multiplication
M, N, K = self.M, self.N, self.K

# Determine if the shapes are statically defined (not dynamic)
is_static_shape = isinstance(M, int) and isinstance(N, int) and isinstance(K, int)

# Check if the architecture is CUDA-based
arch_is_cuda = is_cuda_arch(self.arch)

# If the architecture is CUDA and we have a static shape, proceed with optimization
if arch_is_cuda and is_static_shape:
sm_waste_threshold = 5e-2 # Allow at most 5% SM waste
num_sms = self.arch.compute_max_core # Get the maximum number of streaming multiprocessors

# Compute block sizes based on the configuration
block_M = hint.block[0] # Block size in the M dimension
block_N = hint.block[1] # Block size in the N dimension
block_K = hint.rstep[0] # Block size in the K dimension

# Calculate the grid dimensions in M and N directions
grid_m = M // block_M
grid_n = N // block_N
total_grids = grid_m * grid_n # Total number of grids

# Initialize the split-k factor (used to distribute K-dimension work across blocks)
split_k_factor = 1

# Optimize the split-k factor to minimize SM waste
while True:
# Total grids after applying split-k
total_grids_split_k = total_grids * split_k_factor

# Calculate the waste in SMs after split-k distribution
waste_sm_splitk = total_grids_split_k - (total_grids_split_k //
num_sms) * num_sms
waste_sm_splitk_ratio = waste_sm_splitk / total_grids_split_k

# If the SM waste ratio is within the allowed threshold, stop optimization
if waste_sm_splitk_ratio <= sm_waste_threshold:
break

# Double the split-k factor and check if the resulting K-dimension size is too large
expand_split_k = split_k_factor * 2
if expand_split_k * block_K >= K:
break

# Update the split-k factor for the next iteration
split_k_factor = expand_split_k

# Note: The optimized split_k_factor can be stored or applied to the config if needed
hint.split_k_factor = split_k_factor

# Convert the hint to a configuration object using the TLHint mapping
config = self.TLHint.from_roller_hint(hint)

configs.append(config)
return configs
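The commented loop above chooses the smallest power-of-two `split_k_factor` whose grid leaves a tail of at most 5% of its blocks outside full SM waves, and stops doubling once each K slice would shrink below one `block_K` tile. A self-contained sketch of the same heuristic; the function name and the example sizes are illustrative, not part of the BitBLAS API:

```python
def pick_split_k_factor(M: int, N: int, K: int,
                        block_M: int, block_N: int, block_K: int,
                        num_sms: int, sm_waste_threshold: float = 5e-2) -> int:
    """Minimal sketch of the split-K selection loop in serialize_hints_to_configs."""
    total_grids = (M // block_M) * (N // block_N)
    split_k_factor = 1
    while True:
        total_grids_split_k = total_grids * split_k_factor
        # Blocks left over after filling whole waves of SMs, relative to the grid.
        waste_sm_splitk = total_grids_split_k % num_sms
        if waste_sm_splitk / total_grids_split_k <= sm_waste_threshold:
            break
        expand_split_k = split_k_factor * 2
        # Stop once a further split would leave less than one block_K tile per slice.
        if expand_split_k * block_K >= K:
            break
        split_k_factor = expand_split_k
    return split_k_factor

# Example with assumed sizes: a 1024x1024x8192 GEMM with 128x128x32 tiles on a
# hypothetical 108-SM GPU fills only 64 of 108 SMs per wave, so split-K is
# raised until the partial-wave tail is small enough or K runs out of tiles.
print(pick_split_k_factor(1024, 1024, 8192, 128, 128, 32, num_sms=108))
```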

@@ -123,6 +181,7 @@ def with_default_config(self):

num_stages = getattr(self, "num_stages", 2)
enable_rasterization = getattr(self, "enable_rasterization", False)
split_k_factor = getattr(self, "split_k_factor", 1)

return self.apply_config(
block_row_warps=block_row_warps,
@@ -132,6 +191,7 @@ def with_default_config(self):
chunk=chunk,
num_stages=num_stages,
enable_rasterization=enable_rasterization,
split_k_factor=split_k_factor,
)

def apply_config(
@@ -142,7 +202,8 @@ def apply_config(
warp_col_tiles: Optional[int] = None,
chunk: Optional[int] = None,
num_stages: Optional[int] = None,
enable_rasterization=False,
enable_rasterization: bool = False,
split_k_factor: Optional[int] = None,
):
assert block_row_warps is not None, "block_row_warps is required"
assert block_col_warps is not None, "block_col_warps is required"
@@ -204,6 +265,8 @@ def apply_config(
Qzeros_shape = ((K // group_size), N // storage_nbit * num_bits)
Bias_shape = (N,)

splitK = K // split_k_factor

A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, block_K // num_elems_per_byte)
B_dequantize_shared_shape = (block_N, block_K)
@@ -253,7 +316,15 @@ def apply_config(
chunk=chunk,
)

cache_write_required = self.check_require_cache()
enable_split_k = split_k_factor > 1

def check_require_cache():
conditions = [False]
conditions.append(self.check_require_cache())
conditions.append(enable_split_k)
return any(conditions)

cache_write_required = check_require_cache()

@T.prim_func
def general_dequant_matmul(
@@ -267,7 +338,8 @@ def general_dequant_matmul(
Bias: T.Buffer(Bias_shape, in_dtype),
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k_factor,
threads=threads) as (bx, by, bz):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype)
@@ -296,10 +368,13 @@ def general_dequant_matmul(

T.clear(C_frag)

for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
for ko in T.Pipelined(T.ceildiv(splitK, block_K), num_stages=num_stages):

T.copy(A[by * block_M, ko * block_K], A_shared)
T.copy(B[bx * block_N, ko * block_K // num_elems_per_byte], B_shared)
T.copy(A[by * block_M, bz * splitK + ko * block_K], A_shared)
T.copy(
B[bx * block_N,
bz * (splitK // num_elems_per_byte) + ko * block_K // num_elems_per_byte],
B_shared)

for i in T.serial(block_N * block_K // num_elems_per_byte //
(threads * local_size_compressed)):
@@ -359,6 +434,7 @@ def general_dequant_matmul(

# Matrix multiplication on fragments
mma_emitter.mma(A_frag, B_frag, C_frag)

if cache_write_required:
# Store the result back to C shared memory
mma_emitter.stmatrix(
Expand All @@ -377,13 +453,24 @@ def general_dequant_matmul(
] += Bias[bx * block_N + j]

# Store results from shared memory to global memory
for i, j in T.Parallel(block_M, block_N):
C[by * block_M + i, bx * block_N + j] = C_shared[
i // micro_size_x,
j // micro_size_y,
i % micro_size_x,
j % micro_size_y,
]
if enable_split_k:
for i, j in T.Parallel(block_M, block_N // 2):
T.atomic_addx2(
C[by * block_M + i, bx * block_N + j * 2], C_shared[
i // micro_size_x,
j // micro_size_y,
i % micro_size_x,
j % micro_size_y,
])
else:
for i, j in T.Parallel(block_M, block_N):
C[by * block_M + i, bx * block_N + j] = C_shared[
i // micro_size_x,
j // micro_size_y,
i % micro_size_x,
j % micro_size_y,
]

else:
# Store the result back to C global memory
mma_emitter.stmatrix(
@@ -463,14 +550,17 @@ def apply_config(
warp_col_tiles: Optional[int] = None,
chunk: Optional[int] = None,
num_stages: Optional[int] = None,
enable_rasterization=False,
enable_rasterization: bool = False,
split_k_factor: Optional[int] = None,
):
assert block_row_warps is not None, "block_row_warps is required"
assert block_col_warps is not None, "block_col_warps is required"
assert warp_row_tiles is not None, "warp_row_tiles is required"
assert warp_col_tiles is not None, "warp_col_tiles is required"
assert chunk is not None, "chunk is required"
assert num_stages is not None, "num_stages is required"
# unused variable
split_k_factor = split_k_factor

M = self.maybe_dynamic(self.M, "m")
N, K = self.N, self.K
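With `split_k_factor > 1`, `general_dequant_matmul` gains a third grid dimension `bz`: each block reduces only `splitK = K // split_k_factor` of the reduction axis and accumulates its partial tile into C with `T.atomic_addx2`, which is also why the cache-write path is forced on. A small numpy sketch of the same accumulation pattern (shapes are illustrative, and the zero-initialized C is an assumption of the sketch, since accumulating blocks need a cleared output):

```python
import numpy as np

M, N, K = 64, 64, 256
split_k_factor = 4
splitK = K // split_k_factor

A = np.random.rand(M, K).astype(np.float32)
B = np.random.rand(N, K).astype(np.float32)  # B kept transposed, as in the kernel

C = np.zeros((M, N), dtype=np.float32)  # cleared up front: blocks only accumulate
for bz in range(split_k_factor):
    k0 = bz * splitK
    # Each bz "block" computes a partial GEMM over its own K slice ...
    partial = A[:, k0:k0 + splitK] @ B[:, k0:k0 + splitK].T
    # ... and adds it into global C (atomic_addx2 in the TileLang kernel).
    C += partial

assert np.allclose(C, A @ B.T, atol=1e-3)
```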