Self index stress tested #439

Draft
wants to merge 2 commits into base: main
26 changes: 25 additions & 1 deletion iree/turbine/kernel/ops/wave_ops.py
@@ -45,6 +45,14 @@ def allocate(
...


def self_index(
idx: IndexExpr,
dtype: DataType,
elements_per_thread: Optional[IndexExpr | int] = None,
) -> "Register":
...


def extract(
register: "Register",
offsets: tuple[IndexExpr],
@@ -722,7 +730,7 @@ def infer_type(self):
if lhs_dim_set.isdisjoint(rhs_dim_set):
raise ValueError(
"BinaryPyOp requires lhs and rhs shape to be at least broadcastable."
)
f" got {lhs_type.symbolic_shape} vs {rhs_type.symbolic_shape}")
broadcasted_type = lhs_type if lhs_dim_set > rhs_dim_set else rhs_type
self.type = broadcasted_type

@@ -934,6 +942,22 @@ def type(self) -> "Memory":
return Memory[(*self.shape, self.address_space, self.dtype)]


@define_op("self_index")
@dataclass
class SelfIndex(CustomOp):
idx: IndexExpr
dtype: DataType
elements_per_thread: Optional[IndexExpr | int]

@property
def indexing_dims(self) -> list[IndexSymbol]:
return [self.idx]

@property
def type(self) -> "Register":
return Register[(self.idx, self.dtype)]


@define_op("shared_memory_barrier")
@dataclass
class SharedMemoryBarrier(CustomOp):
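
A rough usage sketch of the new op, to show the intended call shape. The kernel scaffolding, the symbol names, and the assumption that self_index is re-exported on the tkw namespace next to the other wave ops are illustrative, not taken from this diff:

# Illustrative sketch only: calling self_index from a wave kernel body.
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw

M = tkl.sym.M                                # placeholder dimension symbol
ELEMS_PER_THREAD = tkl.sym.ELEMS_PER_THREAD  # placeholder symbol

def body():
    # Per the SelfIndex.type property above, this yields a Register[(M, i64)]
    # whose lanes hold their own index along M.
    lane_idx = tkw.self_index(M, tkl.i64, elements_per_thread=ELEMS_PER_THREAD)
    return lane_idx
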
102 changes: 70 additions & 32 deletions iree/turbine/kernel/wave/codegen.py
@@ -74,6 +74,7 @@
reshape,
scheduling_barrier,
scheduling_group_barrier,
self_index,
set_symbol,
shared_memory_barrier,
shuffle,
@@ -590,6 +591,69 @@ def decorator(
###############################################################################


def _get_start_index(i: IndexSequence | IndexExpr) -> IndexExpr:
if isinstance(i, IndexSequence):
i = i.start

return i


def _get_start_indices(
src_indices: dict[IndexExpr, IndexSequence | IndexExpr]
) -> list[IndexExpr]:
start_indices = []
for dim_indexing in src_indices:
i = _get_start_index(src_indices[dim_indexing])
start_indices.append(i)

return start_indices


def _build_start_indices(
emitter: WaveEmitter,
src_indices: dict[IndexExpr, IndexSequence | IndexExpr],
dynamic_values: dict[IndexExpr, Any] = {},
) -> list[OpResult]:
return [
gen_sympy_index(add_emitter_subs(emitter, dynamic_values), i)
for i in _get_start_indices(src_indices)
]


@handle_op(self_index)
def handle_self_index(emitter: WaveEmitter, node: fx.Node):
try:
iterator, dtype, elements_per_thread = node.args
except ValueError as e:
raise ValidationError("Malformed arguments") from e

index = get_custom(node).index
var = index[iterator]
offset = subs_idxc(var.start)
size = elements_per_thread * subs_idxc(var.size)
stride = subs_idxc(var.stride)

start = _build_start_indices(emitter, {iterator: var})[0]

element_type = IrType.parse(dtype.ir_type_asm())
index_type = IrType.parse("index")
vector_shape = cast_py_literal(emitter, [size])

vector_index_type = VectorType.get(vector_shape, index_type)
vector_type = VectorType.get(vector_shape, element_type)

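# Per-lane indices are computed as start + vector.step() * stride, then
# index_cast from the index type to the requested element type.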
step = vector_d.step(vector_index_type)
stride_cst = arith_d.ConstantOp(
index_type,
get_constant_attr(cast_py_literal(emitter, stride), index_type))
stride_vec = vector_d.splat(vector_index_type, stride_cst)
scaled = arith_d.MulIOp(step, stride_vec)
offset = vector_d.splat(vector_index_type, start)
shifted = arith_d.AddIOp(scaled, offset)
casted_i = arith_d.IndexCastOp(vector_type, shifted).result

emitter.bind_node_proxy(node, IRProxyValue(casted_i))


@handle_op(register)
def handle_register(emitter: WaveEmitter, node: fx.Node):
try:
@@ -624,35 +688,6 @@ def handle_allocate(emitter: WaveEmitter, node: fx.Node):
emitter.bind_node_proxy(node, IRProxyValue(alloc))


def _get_start_index(i: IndexSequence | IndexExpr) -> IndexExpr:
if isinstance(i, IndexSequence):
i = i.start

return i


def _get_start_indices(
src_indices: dict[IndexExpr, IndexSequence | IndexExpr]
) -> list[IndexExpr]:
start_indices = []
for dim_indexing in src_indices:
i = _get_start_index(src_indices[dim_indexing])
start_indices.append(i)

return start_indices


def _build_start_indices(
emitter: WaveEmitter,
src_indices: dict[IndexExpr, IndexSequence | IndexExpr],
dynamic_values: dict[IndexExpr, Any] = {},
) -> list[OpResult]:
return [
gen_sympy_index(add_emitter_subs(emitter, dynamic_values), i)
for i in _get_start_indices(src_indices)
]


def _get_fastest_index(indices: dict[IndexExpr, IndexSequence]):
"""
This function takes in indices of a Node, extracts their sizes
@@ -931,7 +966,8 @@ def handle_write(emitter: WaveEmitter, node: fx.Node):

assert (
tuple(insert_type.shape) == vector_shape
), f"Shape doesn't match: {tuple(insert_type.shape)} and {(vector_shape)}"
), f"Shape doesn't match: {tuple(insert_type.shape)} and {(vector_shape)}" + \
f" in register {register} and elements_per_thread {elements_per_thread}"

if not hasattr(node, "index"):
raise ValidationError("codegen expected write to have index attr.")
@@ -1141,7 +1177,9 @@ def handle_generic_binary(emitter: WaveEmitter, node: fx.Node):
rhs = cast_py_value(emitter, rhs)

if lhs.ir_value.type != rhs.ir_value.type:
raise ValidationError("Expected lhs and rhs to have same type.")
raise ValidationError(
"Expected lhs and rhs to have same type."
f" Got: {lhs.ir_value.type} vs {rhs.ir_value.type}")

lhs = lhs.ir_value
rhs = rhs.ir_value
@@ -1525,7 +1563,7 @@ def handle_broadcast(emitter: WaveEmitter, node: fx.Node):
if not VectorType.isinstance(vector_type):
raise NotImplementedError("Scalar src is not implemented yet for shuffleOp.")
assert vector_type.rank == 1
assert vector_type.shape[0] == 1
assert vector_type.shape[0] == 1, f"expected vector_type.shape[0] == 1 but got {vector_type}"

# Extract and Splat
# If by chance broadcast size matches current size, we can return src.
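
The handler above reduces to computing start + step * stride per lane and casting to the requested element type. A small pure-Python model of that arithmetic (illustrative only, no MLIR or emitter involved):

# Pure-Python model of the values handle_self_index materializes:
# a length-`size` vector holding start + i * stride for i = 0..size-1.
def self_index_values(start: int, stride: int, size: int) -> list[int]:
    step = list(range(size))             # vector.step
    scaled = [i * stride for i in step]  # splat(stride); arith.muli
    return [start + s for s in scaled]   # splat(start); arith.addi

# e.g. start=8, stride=2, size=4 -> [8, 10, 12, 14]
assert self_index_values(8, 2, 4) == [8, 10, 12, 14]
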
18 changes: 9 additions & 9 deletions lit_tests/kernel/wave/attention.py
@@ -53,7 +53,7 @@
STORE_ELEMS_PER_THREAD = tkl.sym.STORE_ELEMS_PER_THREAD


@run_test
# @run_test
def test_evoformer():
# B, BN, K2, H, K1, M, N
shape = (1, 256, 256, 4, 32, 256, 32)
@@ -233,7 +233,7 @@ def repeat(
# The reason why we can't set K1 to be dynamic is because K1 is the
# tile size we use for expanding the K1 MMA. We could set K1 to be
# dynamic if we tiled the K1 dimension with a tile size of BLOCK_K1.
@run_test
# @run_test
def test_dynamic_attention_pipelined():
shape = (8, 128, 128, 64, 256)
# Expose user-constraints
@@ -373,7 +373,7 @@ def repeat(
# CHECK-COUNT-16: vector.maskedstore {{.*}}


@run_test
# @run_test
def test_attention_pipelined():
shape = (8, 128, 128, 64, 256)
# Expose user-constraints
@@ -499,7 +499,7 @@ def repeat(
# CHECK-COUNT-1: {{.*}} = amdgpu.mfma


@run_test
# @run_test
def test_flash_decoding():
shape = (8, 128, 128, 64, 256)
mfma_variant = tkw.MMAType.F32_16x16x16_F16
@@ -581,7 +581,7 @@ def test_flash_decoding():
# CHECK-COUNT-1: vector.scatter


@run_test
# @run_test
def test_attention_32x32x8():
shape = (8, 128, 128, 64, 256)
# Expose user-constraints
@@ -718,7 +718,7 @@ def repeat(
# CHECK-COUNT-4: vector.store {{.*}}: memref<8x128x128xf32{{.*}}>, vector<4xf32>


@run_test
# @run_test
def test_dynamic_attention_32x32x8():
shape = (8, 128, 128, 64, 256)
# Expose user-constraints
@@ -856,7 +856,7 @@ def repeat(
# CHECK-COUNT-3: vector.maskedstore {{.*}} : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>, vector<4xi1>, vector<4xf32>


@run_test
# @run_test
def test_attention():
shape = AttentionShape(
num_query_heads=8,
@@ -910,7 +910,7 @@ def test_attention():
# CHECK-COUNT-8: {{.*}} = amdgpu.mfma


@run_test
# @run_test
def test_attention_bias():
shape = (8, 128, 128, 64, 256)
# Expose user-constraints
@@ -1034,7 +1034,7 @@ def repeat(
# CHECK-COUNT-8: {{.*}} = amdgpu.mfma


@run_test
# @run_test
def test_paged_flash_decoding():
shape = paged_decode_attention_shape(
num_query_heads=128,
Empty file added playground/__init__.py
Empty file.