diff --git a/include/Finch/FinchOps.td b/include/Finch/FinchOps.td
index 19426b5..e6cdb29 100644
--- a/include/Finch/FinchOps.td
+++ b/include/Finch/FinchOps.td
@@ -209,7 +209,7 @@ def Finch_GetLevelOp : Finch_Op<"getlevel", [Pure]> {
   }];
 }
 
-def Finch_AssignOp : Finch_Op<"assign", [AllTypesMatch<["in","out"]>]> {
+def Finch_AssignOp : Finch_Op<"assign", [MemRefsNormalizable, AllTypesMatch<["in","out"]>]> {
   let summary = "Finch getlevel op";
   let description = [{
     ```mlir
diff --git a/lib/Finch/FinchPasses.cpp b/lib/Finch/FinchPasses.cpp
index ff737ff..b79838b 100644
--- a/lib/Finch/FinchPasses.cpp
+++ b/lib/Finch/FinchPasses.cpp
@@ -144,20 +144,46 @@ class FinchAssignRewriter : public OpRewritePattern<finch::AssignOp> {
                                 PatternRewriter &rewriter) const {
     Value out = op.getOperand(0);
     Value in = op.getOperand(1);
-
-    Operation* defOp = out.getDefiningOp();
-    if (isa<finch::AccessOp>(defOp)) {
+
+    // finch.assign %out = %out is a no-op.
+    if (in == out) {
+      rewriter.eraseOp(op);
+      return success();
+    }
+
+    Operation* loadOp = out.getDefiningOp();
+    if (isa<finch::AccessOp>(loadOp)) {
+      // Wait until the looplet pass finishes lowering.
       return failure();
     }
 
-    assert(isa<memref::LoadOp>(defOp) && "Currently Assign can only convert memref.load after elementlevel");
-    assert(defOp->getNumOperands() == 2 && "Currently only accept non-scalar tensor");
+    assert(isa<memref::LoadOp>(loadOp) && "Currently Assign can only convert memref.load after elementlevel");
+    assert(loadOp->getNumOperands() == 2 && "Currently only accept non-scalar tensor");
+
+    // Value "in" depends on "out", i.e. the assign is a reduction:
+    // e.g.,
+    //   %in = arith.addf %out, %1
+    //   finch.assign %out = %in
+    bool isReduction = false;
+    for (Operation *user : loadOp->getUsers()) {
+      if (in.getDefiningOp() == user) {
+        isReduction = true;
+        break;
+      }
+    }
 
-    auto sourceMemref = defOp->getOperand(0);
-    auto sourcePos = defOp->getOperand(1);
+    auto sourceMemref = loadOp->getOperand(0);
+    auto sourcePos = loadOp->getOperand(1);
+    auto storeOp = rewriter.replaceOpWithNewOp<memref::StoreOp>(
+        op, in, sourceMemref, sourcePos);
 
-    rewriter.replaceOpWithNewOp<memref::StoreOp>(
-        op, in, sourceMemref, sourcePos);
+
+    // TODO: seriously consider replacing this with finch.assign %out += %in
+    if (isReduction) {
+      rewriter.setInsertionPointToStart(storeOp->getBlock());
+      Operation* newLoadOp = rewriter.clone(*loadOp);
+      rewriter.replaceOpUsesWithinBlock(loadOp, newLoadOp->getResult(0), storeOp->getBlock());
+    }
 
     return success();
   }
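Reviewer note: a minimal before/after sketch of the reduction path this hunk adds. The surrounding loop and all value names are illustrative, not taken from the repository:

```mlir
// Before lowering: %vc is loaded once, outside the loop, and %vabc feeds
// back into the location %vc was read from, so isReduction is true.
%vc = memref.load %val[%p] : memref<?xf32>
scf.for %j = %lb to %ub step %c1 {
  %v = arith.constant 1.0 : f32
  %vabc = arith.addf %vc, %v : f32
  finch.assign %vc = %vabc : f32
}

// After: the assign becomes a store, and the cloned load at the top of the
// store's block makes each iteration re-read the accumulator it just wrote.
%vc = memref.load %val[%p] : memref<?xf32>
scf.for %j = %lb to %ub step %c1 {
  %vc_new = memref.load %val[%p] : memref<?xf32>
  %v = arith.constant 1.0 : f32
  %vabc = arith.addf %vc_new, %v : f32
  memref.store %vabc, %val[%p] : memref<?xf32>
}
```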
diff --git a/test14.mlir b/test14.mlir
new file mode 100644
index 0000000..c1b96f2
--- /dev/null
+++ b/test14.mlir
@@ -0,0 +1,296 @@
+// RUN: finch-opt %s | finch-opt | FileCheck %s
+// ./bin/finch-opt ../test12.mlir --inline --finch-simplifier --finch-instantiate --finch-looplet-pass --finch-simplifier --finch-instantiate --finch-looplet-pass --finch-simplifier --finch-instantiate --finch-looplet-pass --sparsifier | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=/Users/jaeyeonwon/llvm-project/build/lib/libmlir_runner_utils.dylib,/Users/jaeyeonwon/llvm-project/build/lib/libmlir_c_runner_utils.dylib
+
+#CSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
+module {
+
+  func.func private @printMemrefF32(%ptr: memref<*xf32>) attributes {llvm.emit_c_interface}
+  func.func private @printMemrefInd(%ptr: memref<*xindex>) attributes {llvm.emit_c_interface}
+
+  func.func @buffers_from_sparsematrix(%jump : index) -> (memref<?xindex>, memref<?xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c8 = arith.constant 8 : index
+    %f1 = arith.constant 1.0 : f32
+
+    %6 = tensor.empty() : tensor<32x32xf32, #CSR>
+    %7 = scf.for %i = %c0 to %c8 step %c1 iter_args(%vin = %6) -> tensor<32x32xf32, #CSR> {
+      %ii = arith.muli %i, %jump : index
+      %tmp = scf.for %j = %c0 to %c8 step %c1 iter_args(%vin2 = %vin) -> tensor<32x32xf32, #CSR> {
+        %jj = arith.muli %j, %jump : index
+        %intj = arith.index_castui %j : index to i32
+        %fj = arith.uitofp %intj : i32 to f32
+        //%vout = tensor.insert %f1 into %vin2[%ii, %jj] : tensor<32x32xf32, #CSR>
+        %vout = tensor.insert %fj into %vin2[%ii, %jj] : tensor<32x32xf32, #CSR>
+        scf.yield %vout : tensor<32x32xf32, #CSR>
+      }
+      scf.yield %tmp : tensor<32x32xf32, #CSR>
+    }
+
+    %8 = sparse_tensor.load %7 hasInserts : tensor<32x32xf32, #CSR>
+    sparse_tensor.print %8 : tensor<32x32xf32, #CSR>
+
+    %9 = sparse_tensor.positions %8 {level = 0 : index} : tensor<32x32xf32, #CSR> to memref<?xindex>
+    %10 = sparse_tensor.coordinates %8 {level = 0 : index} : tensor<32x32xf32, #CSR> to memref<?xindex>
+
+    %11 = sparse_tensor.positions %8 {level = 1 : index} : tensor<32x32xf32, #CSR> to memref<?xindex>
+    %12 = sparse_tensor.coordinates %8 {level = 1 : index} : tensor<32x32xf32, #CSR> to memref<?xindex>
+
+    %13 = sparse_tensor.values %8 : tensor<32x32xf32, #CSR> to memref<?xf32>
+    return %9, %10, %11, %12, %13 : memref<?xindex>, memref<?xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
+  }
+
+  func.func private @dense_level(%pos: index, %shape: index) -> !finch.looplet {
+    %l0 = finch.lookup
+      body = {
+        ^bb(%idx : index):
+          %0 = arith.muli %pos, %shape : index
+          %1 = arith.addi %0, %idx : index
+          %2 = finch.nextlevel %1 : (index) -> (!finch.looplet)
+          finch.return %2 : !finch.looplet
+      }
+    return %l0 : !finch.looplet
+  }
+
+  func.func private @sparse_level(%pos: index, %ptr: memref<?xindex>, %crd: memref<?xindex>) -> !finch.looplet {
+    %fp_0 = arith.constant -0.0 : f32
+    %fp_1 = arith.constant 1.0 : f32
+    %c1 = arith.constant 1 : index
+
+    %l0 = finch.stepper
+      seek={
+        ^bb0(%idx : index):
+          %firstpos = func.call @binarysearch(%pos, %idx, %ptr, %crd) : (index, index, memref<?xindex>, memref<?xindex>) -> (index)
+          finch.return %firstpos : index
+      }
+      stop={
+        ^bb(%pos1 : index):
+          %currcrd = memref.load %crd[%pos1] : memref<?xindex>
+          %stopub = arith.addi %currcrd, %c1 : index
+          finch.return %stopub : index
+      }
+      body={
+        ^bb(%pos1 : index):
+          %currcrd = memref.load %crd[%pos1] : memref<?xindex>
+
+          %zero_run = finch.run %fp_0 : (f32) -> (!finch.looplet)
+          %nonzero_run = finch.nextlevel %pos1 : (index) -> (!finch.looplet)
+          %seq = finch.sequence %currcrd, %zero_run, %nonzero_run : (index, !finch.looplet, !finch.looplet) -> (!finch.looplet)
+          finch.return %seq : !finch.looplet
+      }
+      next={
+        ^bb0(%pos1 : index):
+          %nextpos = arith.addi %pos1, %c1 : index
+          finch.return %nextpos : index
+      }
+
+    // nextpos = pos+1
+    // nextoffset = ptr[pos+1]
+    // curroffset = ptr[pos]
+    // empty = curroffset == nextoffset
+    // if (empty) {
+    //   return 0
+    // } else {
+    //   lastoffset = nextoffset - 1
+    //   last_nnz_crd = crd[lastoffset]
+    //   last_nnz_ub = last_nnz_crd + 1
+    //   return last_nnz_ub
+    // }
+    %nextpos = arith.addi %pos, %c1 : index
+    %nextoffset = memref.load %ptr[%nextpos] : memref<?xindex>
+    %curroffset = memref.load %ptr[%pos] : memref<?xindex>
+    %empty = arith.cmpi eq, %curroffset, %nextoffset : index
+    %zero_ub = scf.if %empty -> (index) {
+      %c0 = arith.constant 0 : index
+      scf.yield %c0 : index
+    } else {
+      %lastoffset = arith.subi %nextoffset, %c1 : index
+      %last_nnz_crd = memref.load %crd[%lastoffset] : memref<?xindex>
+      %last_nnz_ub = arith.addi %last_nnz_crd, %c1 : index
+      scf.yield %last_nnz_ub : index
+    }
+
+    %zero_run = finch.run %fp_0 : (f32) -> (!finch.looplet)
+    %l1 = finch.sequence %zero_ub, %l0, %zero_run : (index, !finch.looplet, !finch.looplet) -> (!finch.looplet)
+
+    return %l1 : !finch.looplet
+  }
+
+  func.func private @element_level(%pos: index, %val : memref<?xf32>) -> !finch.looplet {
+    %currval = memref.load %val[%pos] : memref<?xf32>
+    %run = finch.run %currval : (f32) -> (!finch.looplet)
+    return %run : !finch.looplet
+  }
+
+  func.func private @binarysearch(%pos: index, %idx : index, %ptr : memref<?xindex>, %crd : memref<?xindex>) -> index {
+    // i = ptr[pos];
+    // while (i < ptr[pos+1] && crd[i] < idx) i++;
+    // return i;
+    %c1 = arith.constant 1 : index
+    %offset = memref.load %ptr[%pos] : memref<?xindex>
+
+    %nextpos = arith.addi %pos, %c1 : index
+    %nextoffset = memref.load %ptr[%nextpos] : memref<?xindex>
+
+    %search = scf.while (%i = %offset) : (index) -> (index) {
+      %cmp1 = arith.cmpi ult, %i, %nextoffset : index
+      %cmp2 = scf.if %cmp1 -> (i1) {
+        %currcrd = memref.load %crd[%i] : memref<?xindex>
+        %cmp = arith.cmpi ult, %currcrd, %idx : index
+        scf.yield %cmp : i1
+      } else {
+        %false = arith.constant 0 : i1
+        scf.yield %false : i1
+      }
+      scf.condition(%cmp2) %i : index
+    } do {
+    ^bb(%i: index):
+      %next = arith.addi %i, %c1 : index
+      scf.yield %next : index
+    }
+    return %search : index
+  }
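Reviewer note: as a sanity check on @binarysearch and on the zero_ub computation in @sparse_level, here is a hand-run on one hypothetical CSR row; the ptr/crd contents below are invented for illustration:

```mlir
// Hypothetical row at %pos = 0 with ptr = [0, 3], crd = [1, 4, 5]:
//   binarysearch(0, 4, ptr, crd): i starts at ptr[0] = 0;
//     crd[0] = 1 < 4      -> i = 1;
//     crd[1] = 4 not < 4  -> loop stops; returns 1, the first stored
//     position whose coordinate is >= the query index.
//   zero_ub in @sparse_level: nextoffset = ptr[1] = 3, curroffset = ptr[0] = 0,
//     so the row is non-empty and zero_ub = crd[3-1] + 1 = 6: indices [0, 6)
//     are handled by the stepper %l0, indices [6, shape) by the trailing
//     zero run in the enclosing finch.sequence.
```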
+
+  func.func @main() {
+    // sparse_tensor -> extract (pos, crd, val) memrefs from the sparse tensor
+    // -> build a looplet representation from those memrefs
+    // -> perform the computation in the finch dialect
+    // -> lower the finch loops to LLVM with the finch passes -> run
+
+    /////////////////////////////////
+    // Defining 2D Tensor with Looplets
+    /////////////////////////////////
+
+    %c0 = arith.constant 0 : index
+    %c2 = arith.constant 2 : index
+    %c3 = arith.constant 3 : index
+    %buff:5 = call @buffers_from_sparsematrix(%c2) : (index) -> (memref<?xindex>, memref<?xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>)
+
+    %bx = memref.cast %buff#0 : memref<?xindex> to memref<*xindex>
+    %by = memref.cast %buff#1 : memref<?xindex> to memref<*xindex>
+    %bz = memref.cast %buff#2 : memref<?xindex> to memref<*xindex>
+    %bw = memref.cast %buff#3 : memref<?xindex> to memref<*xindex>
+    %bv = memref.cast %buff#4 : memref<?xf32> to memref<*xf32>
+
+    /////////////////////////////////
+    // Wrap memrefs into Looplets
+    /////////////////////////////////
+    %ptr1A = memref.cast %buff#0 : memref<?xindex> to memref<?xindex>
+    %crd1A = memref.cast %buff#1 : memref<?xindex> to memref<?xindex>
+    %ptr2A = memref.cast %buff#2 : memref<?xindex> to memref<?xindex>
+    %crd2A = memref.cast %buff#3 : memref<?xindex> to memref<?xindex>
+    %valA = memref.cast %buff#4 : memref<?xf32> to memref<?xf32>
+
+    %shape = arith.constant 32 : index
+    %ALvl0 = finch.definelevel {
+      ^bb0(%pos : index):
+        %l = func.call @dense_level(%pos, %shape) : (index, index) -> !finch.looplet
+        finch.return %l : !finch.looplet
+    }
+
+    %ALvl1 = finch.definelevel {
+      ^bb0(%pos : index):
+        %l = func.call @sparse_level(%pos, %ptr2A, %crd2A) : (index, memref<?xindex>, memref<?xindex>) -> !finch.looplet
+        finch.return %l : !finch.looplet
+    }
+
+    %ALvl2 = finch.definelevel {
+      ^bb(%pos: index):
+        %l = func.call @element_level(%pos, %valA) : (index, memref<?xf32>) -> !finch.looplet
+        finch.return %l : !finch.looplet
+    }
+
+    %BLvl0 = finch.definelevel {
+      ^bb0(%pos : index):
+        %l = func.call @dense_level(%pos, %shape) : (index, index) -> !finch.looplet
+        finch.return %l : !finch.looplet
+    }
+
+    %BLvl1 = finch.definelevel {
+      ^bb0(%pos : index):
+        %l = func.call @sparse_level(%pos, %ptr2A, %crd2A) : (index, memref<?xindex>, memref<?xindex>) -> !finch.looplet
+        finch.return %l : !finch.looplet
+    }
+
+    %BLvl2 = finch.definelevel {
+      ^bb(%pos: index):
+        %l = func.call @element_level(%pos, %valA) : (index, memref<?xf32>) -> !finch.looplet
+        finch.return %l : !finch.looplet
+    }
+
+    %c1 = arith.constant 1 : index
+    %val3 = memref.alloc(%shape) : memref<?xf32>
+    %fp_m1 = arith.constant 0.0 : f32
+
+    scf.for %j = %c0 to %shape step %c1 {
+      memref.store %fp_m1, %val3[%j] : memref<?xf32>
+    }
+
+    %CLvl0 = finch.definelevel {
+      ^bb0(%pos : index):
+        %l = func.call @dense_level(%pos, %shape) : (index, index) -> !finch.looplet
+        finch.return %l : !finch.looplet
+    }
+
+    %CLvl1 = finch.definelevel {
+      ^bb(%pos: index):
+        %l = func.call @element_level(%pos, %val3) : (index, memref<?xf32>) -> !finch.looplet
+        finch.return %l : !finch.looplet
+    }
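Reviewer note: C is a dense 32-vector backed by %val3, built from a dense level over i feeding an element level. Based on the level definitions earlier in this file, an access chain on C (using the names from the main loop below) is expected to flatten roughly as follows once the looplet passes run; this is a sketch, not verified pass output:

```mlir
%p1c = finch.access %l0c, %i : index   // dense level: %p1c = %c0 * %shape + %i
%l1c = finch.getlevel %CLvl1, %p1c : (!finch.looplet, index) -> (!finch.looplet)
%vc  = finch.access %l1c, %i : f32     // element level: memref.load %val3[%p1c]
```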
+
+    /////////////////////////////////////
+    ////// Main Code
+    /////////////////////////////////////
+
+    %fp_0 = arith.constant 0.0 : f32
+    %sum = memref.alloc() : memref<f32>
+    memref.store %fp_0, %sum[] : memref<f32>
+
+    %b0 = arith.constant 0 : index
+    %bb = arith.constant 2 : index
+    %b1 = arith.constant 32 : index
+
+    %l0a = finch.getlevel %ALvl0, %c0 : (!finch.looplet, index) -> (!finch.looplet)
+    %l0b = finch.getlevel %BLvl0, %c0 : (!finch.looplet, index) -> (!finch.looplet)
+    %l0c = finch.getlevel %CLvl0, %c0 : (!finch.looplet, index) -> (!finch.looplet)
+
+    scf.for %i = %b0 to %b1 step %c1 {
+      %p1a = finch.access %l0a, %i : index
+      %p1b = finch.access %l0b, %i : index
+      %p1c = finch.access %l0c, %i : index
+
+      %l1c = finch.getlevel %CLvl1, %p1c : (!finch.looplet, index) -> (!finch.looplet)
+      %vc = finch.access %l1c, %i : f32
+
+      scf.for %j = %b0 to %b1 step %c1 {
+        %l1a = finch.getlevel %ALvl1, %p1a : (!finch.looplet, index) -> (!finch.looplet)
+        %p2a = finch.access %l1a, %j : index
+        %l1b = finch.getlevel %BLvl1, %p1b : (!finch.looplet, index) -> (!finch.looplet)
+        %p2b = finch.access %l1b, %j : index
+
+        %l2a = finch.getlevel %ALvl2, %p2a : (!finch.looplet, index) -> (!finch.looplet)
+        %va = finch.access %l2a, %j : f32
+        %l2b = finch.getlevel %BLvl2, %p2b : (!finch.looplet, index) -> (!finch.looplet)
+        %vb = finch.access %l2b, %j : f32
+
+        %vab = arith.mulf %va, %vb : f32
+        %vabc = arith.addf %vc, %vab : f32
+        finch.assign %vc = %vabc : f32
+      }
+    }
+
+    //// Print %val3
+    %z = memref.cast %val3 : memref<?xf32> to memref<*xf32>
+    call @printMemrefF32(%z) : (memref<*xf32>) -> ()
+
+    return
+  }
+}
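Reviewer note: the main loop computes the row reduction C[i] = sum over j of A[i,j] * B[i,j]; since the B levels reuse A's buffers (%ptr2A, %crd2A, %valA), this squares and accumulates A's entries. The inner `finch.assign %vc = %vabc` with `%vabc = arith.addf %vc, %vab` is exactly the isReduction pattern the FinchPasses.cpp change targets: after lowering, the accumulator should be re-loaded at the top of the inner loop body and stored back each iteration. Assuming the kernel runs as written, the result derived by hand from the generator (jump = 2, values 0..7 per row) would be:

```mlir
// Expected contents of %val3 (hedged; derived by hand, not a captured run):
//   val3[2*i] = sum_{j=0}^{7} j*j = 140.0   for i = 0..7
//   val3[k]   = 0.0                         for all other k
```

Note that %sum is allocated and zero-initialized but never accumulated into; the test prints %val3 instead.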