From 31c2efa6d17a8a81f7f278bf4e8d7398a5eab071 Mon Sep 17 00:00:00 2001 From: Jeff Fifield Date: Thu, 16 May 2024 13:33:17 -0600 Subject: [PATCH] Use air dialect bindings for test, add xrt wrapper --- python/air/backend/xrt.py | 166 ++++++++++++++++++++++++++ test/xrt/02_mul_shim_1x1/run.py | 202 ++++++++++++++------------------ 2 files changed, 255 insertions(+), 113 deletions(-) create mode 100644 python/air/backend/xrt.py diff --git a/python/air/backend/xrt.py b/python/air/backend/xrt.py new file mode 100644 index 000000000..24f1705fc --- /dev/null +++ b/python/air/backend/xrt.py @@ -0,0 +1,166 @@ +# ./python/air/backend/xrt_backend.py -*- Python -*- +# +# Copyright (C) 2024, Advanced Micro Devices, Inc. +# SPDX-License-Identifier: MIT + +import air.ir +import air.passmanager + +from .abc import AirBackend + +import air.compiler.util +import air.compiler.aircc.main as aircc + +import numpy as np +import pyxrt as xrt + + +class XRTBackend(AirBackend): + """Main entry-point for the xrt based AIR backend. + + Args: + verbose: verbose + xclbin: xclbin filename to use + kernel: kernel name to use + insts: instruction filename to use + """ + + def __init__( + self, + verbose=False, + xclbin="air.xclbin", + kernel="MLIR_AIE", + insts="air.insts.txt", + ): + super().__init__() + self.opts_xclbin = xclbin + self.opts_kernel = kernel + self.opts_insts = insts + self.verbose = verbose + + def __del__(self): + self.unload() + + def compile(self, air_module: air.ir.Module, pipeline=None): + """Compiles an AIR module for the NPU / XRT Runtime with aircc. + + The module is expected to be AIR dialect IR. Unless 'pipeline' is + specified, the the input IR is passed directly to aircc. If 'pipeline' + is specified, it is passed to aircc as the 'pipeline' command line options. + + Args: + air_module: The MLIR module consisting of funcs in the AIR dialect. + pipeline: aircc optimization pipeline to use. + verbose: verbose + Returns: + An opaque, backend specific compiled artifact object that can be + passed to `load`. + """ + + with air.ir.Context(): + + if self.verbose: + print("AIR Module:") + print(air_module) + + aircc_options = [ + "--device", + "npu1_4col", + "air.mlir", + "-xchesscc", + "-xbridge", + "-o", + self.opts_xclbin, + ] + + if self.verbose: + aircc_options = aircc_options + ["-v"] + + aircc.run(air_module, aircc_options) + + return air_module + + def load(self, module): + """Load a compiled artifact into the air runtime. + + Returns: A callable that can be used to invoke the loaded module. + The callable takes a list of numpy arrays. Each numpy array is + assumed to be an input/output tensor. The callable also returns a + list of numpy arrays, one for each tensor.""" + + # create the device, xclbin and context + self.device = xrt.device(0) + self.xclbin = xrt.xclbin(self.opts_xclbin) + self.device.register_xclbin(self.xclbin) + self.context = xrt.hw_context(self.device, self.xclbin.get_uuid()) + + # find and load the kernel + kernels = self.xclbin.get_kernels() + try: + xkernel = [k for k in kernels if self.opts_kernel in k.get_name()][0] + except: + print(f"Kernel '{self.opts_kernel}' not found in '{self.opts_xclbin}'") + exit(-1) + self.kernel = xrt.kernel(self.context, xkernel.get_name()) + + # load the instructions as a numpy array + with open(self.opts_insts, "r") as f: + instr_text = f.read().split("\n") + instr_text = [l for l in instr_text if l != ""] + self.instr_v = np.array([int(i, 16) for i in instr_text], dtype=np.uint32) + + self.bo_instr = xrt.bo( + self.device, + len(self.instr_v) * 4, + xrt.bo.cacheable, + self.kernel.group_id(0), + ) + self.bo_instr.write(self.instr_v, 0) + + # 1) create and sync the buffers + # 2) invoke the kernel + # 3) sync the buffers + # 4) return the contents of the buffers + def invoker(*args): + + # limit arg length to 5 + if len(args) > 5: + raise ValueError("Too many arguments") + sizes_in_bytes = [a.size * a.itemsize for a in args] + bos = [ + xrt.bo(self.device, s, xrt.bo.host_only, self.kernel.group_id(i + 2)) + for i, s in enumerate(sizes_in_bytes) + ] + + self.bo_instr.sync(xrt.xclBOSyncDirection.XCL_BO_SYNC_BO_TO_DEVICE) + for i, a in enumerate(args): + bos[i].write(a, 0) + bos[i].sync(xrt.xclBOSyncDirection.XCL_BO_SYNC_BO_TO_DEVICE) + + h = self.kernel(self.bo_instr, len(self.instr_v), *bos) + h.wait() + + for i, a in enumerate(args): + bos[i].sync(xrt.xclBOSyncDirection.XCL_BO_SYNC_BO_FROM_DEVICE) + return tuple( + [ + bos[i].read(s, 0).view(args[i].dtype) + for i, s in enumerate(sizes_in_bytes) + ] + ) + + return invoker + + def compile_and_load(self, module): + """Compile and load a module in one step.""" + c = self.compile(module) + return self.load(c) + + def unload(self): + """Unload any loaded module and shutdown the air runtime.""" + self.kernel = None + self.context = None + self.xclbin = None + self.device = None + self.bo_instr = None + self.instr_v = None diff --git a/test/xrt/02_mul_shim_1x1/run.py b/test/xrt/02_mul_shim_1x1/run.py index 1f22eaa9c..e3c77fe12 100644 --- a/test/xrt/02_mul_shim_1x1/run.py +++ b/test/xrt/02_mul_shim_1x1/run.py @@ -1,24 +1,24 @@ -# run.py -*- Python -*- -# -# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2024, Advanced Micro Devices, Inc. # SPDX-License-Identifier: MIT +# RUN: %PYTHON %s | FileCheck %s -import air -from air.ir import * -from air.passmanager import * -from air.dialects import air as airdialect -from air.dialects import arith, func, linalg, memref -from air.dialects.linalg.opdsl.lang import * -from air.compiler.util import run_transform +import air.backend.xrt as xrt_backend import air.compiler.aircc.main as aircc +from air.dialects.air import * +from air.dialects.func import FuncOp, ReturnOp +from air.dialects.linalg import elemwise_binary +from air.dialects.linalg.opdsl.lang import BinaryFn, TypeFn +from air.dialects.memref import AllocOp, DeallocOp +from air.dialects.scf import for_, yield_ +from air.ir import * import numpy as np -import pyxrt as xrt -import sys import filelock from bfloat16 import bfloat16 +verbose = False + sizes = [ [1024], ] @@ -32,135 +32,111 @@ # (bfloat16, bfloat16), ] -opts_xclbin = 'aie.xclbin' -opts_kernel = 'MLIR_AIE' -opts_insts = opts_xclbin.removesuffix('.xclbin') + ".insts.txt" def to_type(dtype): if dtype == np.int32: - return IntegerType.get_signless(32) + return T.i32() if dtype == np.int16: - return IntegerType.get_signless(16) + return T.i16() if dtype == np.float32: return F32Type.get() if dtype == bfloat16: return BF16Type.get() return None -def generate_add_module(shape, idtype, odtype): - module = Module.create() - with InsertionPoint(module.body): - @func.FuncOp.from_py_func( - MemRefType.get(shape, idtype), MemRefType.get(shape, idtype), MemRefType.get(shape, odtype)) - def mul(lhs, rhs, out): - linalg.elemwise_binary( - lhs, - rhs, - outs=[out], - fun=BinaryFn.mul, - cast=TypeFn.cast_unsigned) - return - - transform_ir_string = """ - transform.with_pdl_patterns { - ^bb0(%arg0: !pdl.operation): - pdl.pattern @match_copy : benefit(1) { - %args = pdl.operands - %results = pdl.types - %op = pdl.operation "memref.copy"(%args : !pdl.range) -> (%results : !pdl.range) - pdl.rewrite %op with "transform.dialect" - } - transform.sequence %arg0 : !pdl.operation failures(propagate) { - ^bb1(%arg1: !pdl.operation): - %l0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!pdl.operation) -> !pdl.operation - %l1, %outer_tile_loops:1 = transform.air.linalg_tile %l0 [1024] - %l2, %inner_tile_loops:1 = transform.air.linalg_tile %l1 [32] - transform.air.linalg_promote %l2 {"operands_to_promote"=[0,1,2], "memory_space"="L1"} - %herds = transform.air.par_to_herd %outer_tile_loops#0 - %copies = transform.pdl_match @match_copy in %arg0 : (!pdl.operation) -> !pdl.operation - %h = transform.air.copy_to_dma %copies - } - } - """ - - pm = PassManager.parse('builtin.module(func.func(linalg-generalize-named-ops))') - pm.run(module.operation) - transform_ir = Module.parse(transform_ir_string) - run_transform(transform_ir, module) - - pm = PassManager.parse('builtin.module(func.func(canonicalize,cse))') - pm.run(module.operation) - return module -def run_test(size, idtype, odtype): +def build_module(shape, idtype, odtype, tile_size): with Context() as ctx, Location.unknown(): - mlir_input_type = to_type(idtype) - mlir_output_type = to_type(odtype) + module = Module.create() + with InsertionPoint(module.body): + memrefTyIn = MemRefType.get(shape, to_type(idtype)) + memrefTyOut = MemRefType.get(shape, to_type(odtype)) + ChannelOp("ChanA") + ChannelOp("ChanB") + ChannelOp("ChanC") + + @FuncOp.from_py_func(memrefTyIn, memrefTyIn, memrefTyOut) + def mul(arg0, arg1, arg2): + @launch(operands=[arg0, arg1, arg2]) + def launch_body(a, b, c): + ChannelPut("ChanA", [], a) + ChannelPut("ChanB", [], b) + ChannelGet("ChanC", [], c) + + @segment(name="segment_0") + def segment_body(): + @herd(name="herd_0", sizes=[1, 1]) + def herd_body(x, y, sx, sy): + mem_space = IntegerAttr.get(T.i32(), MemorySpace.L1) + itile_type = MemRefType.get( + shape=[tile_size], + element_type=to_type(idtype), + memory_space=mem_space, + ) + otile_type = MemRefType.get( + shape=[tile_size], + element_type=to_type(odtype), + memory_space=mem_space, + ) + for _ in for_(shape[0] // tile_size): + tile_a = AllocOp(itile_type, [], []) + tile_b = AllocOp(itile_type, [], []) + tile_c = AllocOp(otile_type, [], []) + ChannelGet("ChanA", [], tile_a) + ChannelGet("ChanB", [], tile_b) + elemwise_binary( + tile_a, + tile_b, + outs=[tile_c], + fun=BinaryFn.mul, + cast=TypeFn.cast_unsigned, + ) + ChannelPut("ChanC", [], tile_c) + DeallocOp(tile_a) + DeallocOp(tile_b) + DeallocOp(tile_c) + yield_([]) + HerdTerminatorOp() + + SegmentTerminatorOp() + + LaunchTerminatorOp() + + return module - mlir_module = generate_add_module(size, mlir_input_type, mlir_output_type) - aircc_options = ['--device', 'npu1_1col', 'air.mlir', '-xchesscc', '-xbridge', '-o', opts_xclbin] - aircc.run(mlir_module, aircc_options) +def run_test(size, idtype, odtype): - with open(opts_insts, 'r') as f: - instr_text = f.read().split('\n') - instr_text = [l for l in instr_text if l != ''] - instr_v = np.array([int(i,16) for i in instr_text], dtype=np.uint32) + mlir_module = build_module(size, idtype, odtype, 32) input_a = (np.random.rand(*size) * 127).astype(idtype).reshape(size) input_b = (np.random.rand(*size) * 127).astype(idtype).reshape(size) ref = (input_a * input_b).astype(odtype) + input_c = np.ones_like(ref) - with filelock.FileLock("/tmp/npu.lock"): - device = xrt.device(0) - xclbin = xrt.xclbin(opts_xclbin) - kernels = xclbin.get_kernels() - try: - xkernel = [k for k in kernels if opts_kernel in k.get_name()][0] - except: - print(f"Kernel '{opts_kernel}' not found in '{opts_xclbin}'") - exit(-1) - - device.register_xclbin(xclbin) - context = xrt.hw_context(device, xclbin.get_uuid()) - kernel = xrt.kernel(context, xkernel.get_name()) + backend = xrt_backend.XRTBackend(verbose=verbose) - in_size_bytes = input_a.size * input_a.itemsize - out_size_bytes = ref.size * ref.itemsize - - bo_instr = xrt.bo(device, len(instr_v)*4, xrt.bo.cacheable, kernel.group_id(0)) - bo_a = xrt.bo(device, in_size_bytes, xrt.bo.host_only, kernel.group_id(2)) - bo_b = xrt.bo(device, in_size_bytes, xrt.bo.host_only, kernel.group_id(3)) - bo_c = xrt.bo(device, out_size_bytes, xrt.bo.host_only, kernel.group_id(4)) - - bo_instr.write(instr_v, 0) - bo_instr.sync(xrt.xclBOSyncDirection.XCL_BO_SYNC_BO_TO_DEVICE) - - bo_a.write(input_a, 0) - bo_a.sync(xrt.xclBOSyncDirection.XCL_BO_SYNC_BO_TO_DEVICE) - - bo_b.write(input_b, 0) - bo_b.sync(xrt.xclBOSyncDirection.XCL_BO_SYNC_BO_TO_DEVICE) - - h = kernel(bo_instr, len(instr_v), bo_a, bo_b, bo_c) - h.wait() + # run the module + with filelock.FileLock("/tmp/npu.lock"): + mul = backend.compile_and_load(mlir_module) + (_, _, output_c) = mul(input_a, input_b, input_c) - bo_c.sync(xrt.xclBOSyncDirection.XCL_BO_SYNC_BO_FROM_DEVICE) - output_buffer = bo_c.read(out_size_bytes, 0).view(odtype) + backend.unload() - print("input:", input_a) - print("input:", input_b) - print("output:", output_buffer) + print("inputA:", input_a) + print("inputB:", input_b) + print("output:", output_c) - if np.allclose(ref, output_buffer, 0.01): + if np.allclose(ref, output_c, 0.01): print("PASS!") return 1 else: print("failed.") return 0 + passed = 0 -for (idtype, odtype) in dtypes: +for idtype, odtype in dtypes: for size in sizes: try: print("Testing size:", size, "dtype:", idtype, odtype) @@ -168,10 +144,10 @@ def run_test(size, idtype, odtype): except Exception as e: print(e) -num_tests = len(sizes)*len(dtypes) +num_tests = len(sizes) * len(dtypes) if passed != num_tests: - print (f"failed. {passed}/{num_tests}") + print(f"failed. {passed}/{num_tests}") exit(-1) else: - print (f"PASSED! {passed}/{num_tests}") + print(f"PASSED! {passed}/{num_tests}") exit(0)