Skip to content

Commit

Permalink
Enable sweeping along N dimension
Browse files: browse the repository at this point in the history
  • Loading branch information
erwei-xilinx committed Nov 7, 2024
1 parent b066443 commit 37a73b0
Showing 1 changed file with 32 additions and 18 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -3,6 +3,7 @@
import argparse

from air.ir import *
from air.dialects.affine import apply as affine_apply
from air.dialects.air import *
from air.dialects.arith import ConstantOp
from air.dialects.memref import AllocOp, DeallocOp, load, store
Expand Down Expand Up @@ -88,12 +89,12 @@ def build_module(k, n, bs, tile_k, tile_n, np_dtype_in, np_dtype_acc, np_dtype_o
func.attributes["link_with"] = StringAttr.get("vm.o")
func.attributes["llvm.emit_c_interface"] = UnitAttr.get()

# We will send an image worth of data in and out
@FuncOp.from_py_func(memrefTyA, memrefTyAS, memrefTyB, memrefTyBS, memrefTyOut)
def vecmat_i8(arg0, arg1, arg2, arg3, arg4):

# The arguments are the input and output
@launch(operands=[arg0, arg1, arg2, arg3, arg4], sizes=[1, 1])
launch_size = [1, n // tile_n]

@launch(operands=[arg0, arg1, arg2, arg3, arg4], sizes=launch_size)
def launch_body(
launch_ivx,
launch_ivy,
Expand All @@ -119,36 +120,49 @@ def launch_body(
sizes=[],
strides=[],
)

# Affine map for launch iv
launch_ivy_map = AffineMap.get(
0,
1,
[
AffineExpr.get_mul(
AffineSymbolExpr.get(0),
AffineConstantExpr.get(tile_n),
)
],
)
launch_offset_y = affine_apply(launch_ivy_map, [launch_ivy])
ChannelPut(
"bL3ToL2",
l3_b_data,
offsets=[],
sizes=[],
strides=[],
offsets=[0, launch_offset_y],
sizes=[k, tile_n],
strides=[n, 1],
)
ChannelPut(
"bL3ToL2",
l3_b_scale,
offsets=[],
sizes=[],
strides=[],
offsets=[0, launch_offset_y],
sizes=[k // bs, tile_n],
strides=[n, 1],
)
ChannelGet(
"cL2ToL3",
l3_c_data,
offsets=[],
sizes=[],
strides=[],
offsets=[launch_offset_y],
sizes=[tile_n],
strides=[1],
)

@segment(name="vecmat_i8_0")
def segment_body():
# L2 MemRefTypes
a_size_l2 = a_size
a_s_size_l2 = a_s_size
b_size_l2 = b_size
b_s_size_l2 = b_s_size
c_size_l2 = c_size
a_size_l2 = [k]
a_s_size_l2 = [k // bs]
b_size_l2 = [k, tile_n]
b_s_size_l2 = [k // bs, tile_n]
c_size_l2 = [tile_n]
l2_mem_space = IntegerAttr.get(T.i32(), MemorySpace.L2)
l2MemrefTyA = MemRefType.get(
shape=a_size_l2,
Expand Down Expand Up @@ -445,7 +459,7 @@ def herd_body(_tx, _ty, _sx, _sy):
)
ival = 0

runner = XRTRunner(verbose=args.verbose)
runner = XRTRunner(verbose=args.verbose, omit_while_true_loop=False)
exit(
runner.run_test(
mlir_module,
Expand Down

0 comments on commit 37a73b0

Please sign in to comment.