Skip to content

Commit

Permalink
hmm reduction working
Browse files Browse the repository at this point in the history
  • Loading branch information
nullplay committed Oct 25, 2024
1 parent ea89589 commit 8b485ab
Show file tree
Hide file tree
Showing 3 changed files with 332 additions and 10 deletions.
2 changes: 1 addition & 1 deletion include/Finch/FinchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def Finch_GetLevelOp : Finch_Op<"getlevel", [Pure]> {
}];
}

def Finch_AssignOp : Finch_Op<"assign", [AllTypesMatch<["in","out"]>]> {
def Finch_AssignOp : Finch_Op<"assign", [MemRefsNormalizable, AllTypesMatch<["in","out"]>]> {
let summary = "Finch getlevel op";
let description = [{
```mlir
Expand Down
44 changes: 35 additions & 9 deletions lib/Finch/FinchPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,20 +144,46 @@ class FinchAssignRewriter : public OpRewritePattern<finch::AssignOp> {
PatternRewriter &rewriter) const {
Value out = op.getOperand(0);
Value in = op.getOperand(1);

Operation* defOp = out.getDefiningOp();
if (isa<finch::AccessOp>(defOp)) {

// finch.assign %out = %in
if (in == out) {
rewriter.eraseOp(op);
return success();
}

Operation* loadOp = out.getDefiningOp();
if (isa<finch::AccessOp>(loadOp)) {
// wait until looplet pass finishes lowering
return failure();
}

assert(isa<memref::LoadOp>(defOp) && "Currently Assign can only convert memref.load after elementlevel");
assert(defOp->getNumOperands() == 2 && "Currently only accept non-scalar tensor");
assert(isa<memref::LoadOp>(loadOp) && "Currently Assign can only convert memref.load after elementlevel");
assert(loadOp->getNumOperands() == 2 && "Currently only accept non-scalar tensor");

// Value "in" is dependent to "out"
// e.g.,
// %in = arith.addf %out, %1
// finch.assign %out = %in
bool isReduction = false;
for (Operation *user : loadOp->getUsers()) {
if (in.getDefiningOp() == user) {
isReduction = true;
break;
}
}

auto sourceMemref = defOp->getOperand(0);
auto sourcePos = defOp->getOperand(1);
auto sourceMemref = loadOp->getOperand(0);
auto sourcePos = loadOp->getOperand(1);
auto storeOp = rewriter.replaceOpWithNewOp<memref::StoreOp>(
op, in, sourceMemref, sourcePos);

rewriter.replaceOpWithNewOp<memref::StoreOp>(
op, in, sourceMemref, sourcePos);

// seriously consider replacing this with finch.assign %out += %in
if (isReduction) {
rewriter.setInsertionPointToStart(storeOp->getBlock());
Operation* newLoadOp = rewriter.clone(*loadOp);
rewriter.replaceOpUsesWithinBlock(loadOp, newLoadOp->getResult(0), storeOp->getBlock());
}

return success();
}
Expand Down
296 changes: 296 additions & 0 deletions test14.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,296 @@
// RUN: finch-opt %s | finch-opt | FileCheck %s
//./bin/finch-opt ../test12.mlir --inline --finch-simplifier --finch-instantiate --finch-looplet-pass --finch-simplifier --finch-instantiate --finch-looplet-pass --finch-simplifier --finch-instantiate --finch-looplet-pass --sparsifier | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=/Users/jaeyeonwon/llvm-project/build/lib/libmlir_runner_utils.dylib,/Users/jaeyeonwon/llvm-project/build/lib/libmlir_c_runner_utils.dylib

#CSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>
module {

func.func private @printMemrefF32(%ptr:memref<*xf32>) attributes {llvm.emit_c_interface}
func.func private @printMemrefInd(%ptr:memref<*xindex>) attributes {llvm.emit_c_interface}

// Builds a 32x32 CSR sparse matrix whose nonzeros form an 8x8 grid at
// coordinates (i*%jump, j*%jump), with value j (the inner loop index) cast
// to f32, then returns the raw CSR level buffers:
//   (positions(lvl0), coordinates(lvl0), positions(lvl1), coordinates(lvl1), values).
func.func @buffers_from_sparsematrix(%jump : index) -> (memref<?xindex>, memref<?xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c8 = arith.constant 8 : index
// NOTE(review): %f1 is only used by the commented-out insert below.
%f1 = arith.constant 1.0 : f32

// Insert the 8x8 grid of values into an empty CSR tensor, threading the
// tensor through both loops via iter_args.
%6 = tensor.empty() : tensor<32x32xf32, #CSR>
%7 = scf.for %i = %c0 to %c8 step %c1 iter_args(%vin = %6) -> tensor<32x32xf32, #CSR> {
%ii = arith.muli %i, %jump : index
%tmp = scf.for %j = %c0 to %c8 step %c1 iter_args(%vin2 = %vin) -> tensor<32x32xf32, #CSR> {
%jj = arith.muli %j, %jump : index
// Stored value is the inner index j as f32 (0.0, 1.0, ..., 7.0 per row).
%intj = arith.index_castui %j : index to i32
%fj = arith.uitofp %intj : i32 to f32
//%vout = tensor.insert %f1 into %vin2[%ii, %jj] : tensor<32x32xf32, #CSR>
%vout = tensor.insert %fj into %vin2[%ii, %jj] : tensor<32x32xf32, #CSR>
scf.yield %vout : tensor<32x32xf32, #CSR>
}
scf.yield %tmp : tensor<32x32xf32, #CSR>
}

// Finalize pending insertions and print the tensor for debugging.
%8 = sparse_tensor.load %7 hasInserts : tensor<32x32xf32, #CSR>
sparse_tensor.print %8 : tensor<32x32xf32, #CSR>

// Extract the underlying storage buffers of both levels.
%9 = sparse_tensor.positions %8 {level = 0 :index} : tensor<32x32xf32, #CSR> to memref<?xindex>
%10 = sparse_tensor.coordinates %8 {level = 0 :index} : tensor<32x32xf32, #CSR> to memref<?xindex>

%11 = sparse_tensor.positions %8 {level = 1 :index} : tensor<32x32xf32, #CSR> to memref<?xindex>
%12 = sparse_tensor.coordinates %8 {level = 1 :index} : tensor<32x32xf32, #CSR> to memref<?xindex>

%13 = sparse_tensor.values %8 : tensor<32x32xf32, #CSR> to memref<?xf32>
return %9,%10 ,%11,%12 ,%13: memref<?xindex>, memref<?xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>
}

// Dense level as a lookup looplet: index %idx inside slice %pos maps to
// child position %pos * %shape + %idx (standard dense-level offset math).
func.func private @dense_level(%pos: index, %shape: index) -> !finch.looplet {
%l0 = finch.lookup
body = {
^bb(%idx : index) :
// child_pos = pos * shape + idx
%0 = arith.muli %pos, %shape : index
%1 = arith.addi %0, %idx : index
%2 = finch.nextlevel %1 : (index) -> (!finch.looplet)
finch.return %2 : !finch.looplet
}
return %l0 : !finch.looplet
}


// Compressed (sparse) level as a stepper looplet over the coordinate segment
// crd[ptr[%pos] .. ptr[%pos+1]). Each stored coordinate becomes a
// (zero-run, nonzero) sequence; everything past the last stored coordinate
// is covered by a trailing zero run.
func.func private @sparse_level(%pos: index, %ptr:memref<?xindex>, %crd:memref<?xindex>) -> !finch.looplet {
// -0.0 is the implicit "zero" fill value for the runs below.
%fp_0 = arith.constant -0.0 : f32
// NOTE(review): %fp_1 appears unused in this function.
%fp_1 = arith.constant 1.0 : f32
%c1 = arith.constant 1 : index

%l0 = finch.stepper
seek={
// Find the first stored position in this segment whose coordinate >= idx.
^bb0(%idx : index):
%firstpos = func.call @binarysearch(%pos, %idx, %ptr, %crd) : (index, index, memref<?xindex>, memref<?xindex>) -> (index)
finch.return %firstpos : index
}
stop={
// The current step covers indices up to and including crd[pos1],
// i.e. its exclusive upper bound is crd[pos1] + 1.
^bb(%pos1 : index):
%currcrd = memref.load %crd[%pos1] : memref<?xindex>
%stopub = arith.addi %currcrd, %c1 : index
finch.return %stopub : index
}
body={
// Within one step: zeros before crd[pos1], then the stored element at
// crd[pos1] (descending into the next level at position pos1).
^bb(%pos1 : index):
%currcrd = memref.load %crd[%pos1] : memref<?xindex>

%zero_run = finch.run %fp_0 : (f32) -> (!finch.looplet)
%nonzero_run = finch.nextlevel %pos1 : (index) -> (!finch.looplet)
%seq = finch.sequence %currcrd, %zero_run, %nonzero_run : (index, !finch.looplet, !finch.looplet) -> (!finch.looplet)
finch.return %seq : !finch.looplet
}
next={
// Advance to the next stored position.
^bb0(%pos1 : index):
%nextpos = arith.addi %pos1, %c1 : index
finch.return %nextpos : index
}

// Compute the exclusive upper bound of the last stored coordinate in this
// segment (0 if the segment is empty), so the stepper can be capped there:
// nextpos = pos+1
// nextoffset = ptr[pos+1]
// curroffset = ptr[pos]
// empty = curroffset == nextoffset
// if (empty) {
//   return 0
// } else {
//   lastoffset = nextoffset - 1
//   last_nnz_crd = crd[lastoffset]
//   last_nnz_ub = last_nnz_crd + 1
//   return last_nnz_ub
// }
%nextpos = arith.addi %pos, %c1 : index
%nextoffset = memref.load %ptr[%nextpos] : memref<?xindex>
%curroffset = memref.load %ptr[%pos] : memref<?xindex>
%empty = arith.cmpi eq, %curroffset, %nextoffset : index
%zero_ub = scf.if %empty -> (index) {
%c0 = arith.constant 0 : index
scf.yield %c0 : index
} else {
%lastoffset = arith.subi %nextoffset, %c1 : index
%last_nnz_crd = memref.load %crd[%lastoffset] : memref<?xindex>
%last_nnz_ub = arith.addi %last_nnz_crd, %c1 : index
scf.yield %last_nnz_ub : index
}

// Stepper up to %zero_ub, then an all-zeros tail for the rest of the level.
%zero_run = finch.run %fp_0 : (f32) -> (!finch.looplet)
%l1 = finch.sequence %zero_ub, %l0, %zero_run : (index, !finch.looplet, !finch.looplet) -> (!finch.looplet)

return %l1 : !finch.looplet
}


// Element (leaf) level: the value at position %pos, exposed as a constant run.
func.func private @element_level(%pos: index, %val : memref<?xf32>) -> !finch.looplet {
%currval = memref.load %val[%pos] : memref<?xf32>
%run = finch.run %currval : (f32) -> (!finch.looplet)
return %run : !finch.looplet
}

// Returns the first offset i in [ptr[pos], ptr[pos+1]) with crd[i] >= idx,
// or ptr[pos+1] if no such offset exists.
// NOTE(review): despite the name, this is a linear scan (i += 1 per step),
// not a binary search.
func.func private @binarysearch(%pos: index, %idx : index, %ptr : memref<?xindex>, %crd : memref<?xindex>) -> index {
// i = ptr[pos];
// while(i<ptr[pos+1] && crd[i] < idx) {
//   i += 1;
// }

%c1 = arith.constant 1 : index
%offset = memref.load %ptr[%pos] : memref<?xindex>

%nextpos = arith.addi %pos, %c1 : index
%nextoffset = memref.load %ptr[%nextpos] : memref<?xindex>

%search = scf.while (%i = %offset) : (index) -> (index) {
// Short-circuit the bounds check: only load crd[i] when i is in range.
%cmp1 = arith.cmpi ult, %i, %nextoffset : index
%cmp2 = scf.if %cmp1 -> (i1) {
%currcrd = memref.load %crd[%i] : memref<?xindex>
%cmp = arith.cmpi ult, %currcrd, %idx : index
scf.yield %cmp : i1
} else {
%false = arith.constant 0 : i1
scf.yield %false : i1
}
scf.condition(%cmp2) %i : index
} do {
^bb(%i:index) :
%next = arith.addi %i, %c1 : index
scf.yield %next : index
}
return %search : index
}

func.func @main() {
// Pipeline: build a sparse tensor -> extract its (pos, crd, val) memrefs ->
// wrap them as looplet level definitions -> compute with the finch dialect
// -> lower the finch loops to LLVM via the finch passes -> run.

/////////////////////////////////
// Defining 2D Tensor with Looplet
/////////////////////////////////

%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%c3 = arith.constant 3 : index
// 32x32 CSR matrix with nonzeros every 2 rows/columns (jump = 2).
%buff:5 = call @buffers_from_sparsematrix(%c2) : (index) -> (memref<?xindex>, memref<?xindex>, memref<?xindex>, memref<?xindex>, memref<?xf32>)

// NOTE(review): %bx..%bv appear unused below — presumably kept for
// debug printing; verify before removing.
%bx = memref.cast %buff#0 : memref<?xindex> to memref<*xindex>
%by = memref.cast %buff#1 : memref<?xindex> to memref<*xindex>
%bz = memref.cast %buff#2 : memref<?xindex> to memref<*xindex>
%bw = memref.cast %buff#3 : memref<?xindex> to memref<*xindex>
%bv = memref.cast %buff#4 : memref<?xf32> to memref<*xf32>

/////////////////////////////////
// Wrap memrefs to Looplets
/////////////////////////////////
// These casts are identity casts; they just rebind names for readability.
%ptr1A = memref.cast %buff#0 : memref<?xindex> to memref<?xindex>
%crd1A = memref.cast %buff#1 : memref<?xindex> to memref<?xindex>
%ptr2A = memref.cast %buff#2 : memref<?xindex> to memref<?xindex>
%crd2A = memref.cast %buff#3 : memref<?xindex> to memref<?xindex>
%valA = memref.cast %buff#4 : memref<?xf32> to memref<?xf32>

// A: dense (rows) -> compressed (cols) -> element values.
%shape = arith.constant 32 : index
%ALvl0 = finch.definelevel {
^bb0(%pos : index) :
%l = func.call @dense_level(%pos, %shape): (index, index) -> !finch.looplet
finch.return %l : !finch.looplet
}

%ALvl1 = finch.definelevel {
^bb0(%pos : index) :
%l = func.call @sparse_level(%pos, %ptr2A, %crd2A): (index, memref<?xindex>, memref<?xindex>) -> !finch.looplet
finch.return %l : !finch.looplet
}

%ALvl2 = finch.definelevel {
^bb(%pos:index):
%l = func.call @element_level(%pos, %valA): (index, memref<?xf32>) -> !finch.looplet
finch.return %l : !finch.looplet
}


// B: built over the SAME buffers as A (%ptr2A, %crd2A, %valA), so B
// aliases A — presumably intentional for this test (B == A).
%BLvl0 = finch.definelevel {
^bb0(%pos : index) :
%l = func.call @dense_level(%pos, %shape): (index, index) -> !finch.looplet
finch.return %l : !finch.looplet
}

%BLvl1 = finch.definelevel {
^bb0(%pos : index) :
%l = func.call @sparse_level(%pos, %ptr2A, %crd2A): (index, memref<?xindex>, memref<?xindex>) -> !finch.looplet
finch.return %l : !finch.looplet
}

%BLvl2 = finch.definelevel {
^bb(%pos:index):
%l = func.call @element_level(%pos, %valA): (index, memref<?xf32>) -> !finch.looplet
finch.return %l : !finch.looplet
}


// C: dense output vector of length 32, zero-initialized.
%c1 = arith.constant 1 : index
%val3 = memref.alloc(%shape) : memref<?xf32>
// NOTE(review): named fp_m1 but holds 0.0 — the init value, not -1.
%fp_m1 = arith.constant 0.0 : f32

scf.for %j = %c0 to %shape step %c1 {
memref.store %fp_m1, %val3[%j] : memref<?xf32>
}

%CLvl0 = finch.definelevel {
^bb0(%pos : index) :
%l = func.call @dense_level(%pos, %shape): (index, index) -> !finch.looplet
finch.return %l : !finch.looplet
}

%CLvl1 = finch.definelevel {
^bb(%pos:index):
%l = func.call @element_level(%pos, %val3): (index, memref<?xf32>) -> !finch.looplet
finch.return %l : !finch.looplet
}


/////////////////////////////////////
////// Main Code
/////////////////////////////////////

// NOTE(review): %sum is allocated and zeroed but never read afterwards —
// presumably leftover scaffolding; verify before removing.
%fp_0 = arith.constant 0.0 : f32
%sum = memref.alloc() : memref<f32>
memref.store %fp_0, %sum[] : memref<f32>

%b0 = arith.constant 0 : index
%bb = arith.constant 2 : index
%b1 = arith.constant 32 : index


%l0a = finch.getlevel %ALvl0, %c0 : (!finch.looplet, index) -> (!finch.looplet)
%l0b = finch.getlevel %BLvl0, %c0 : (!finch.looplet, index) -> (!finch.looplet)
%l0c = finch.getlevel %CLvl0, %c0 : (!finch.looplet, index) -> (!finch.looplet)

// For each row i: C[i] = sum over j of A[i][j] * B[i][j], accumulated via
// finch.assign on the repeatedly-read %vc (the reduction this commit adds).
scf.for %i = %b0 to %b1 step %c1 {
%p1a = finch.access %l0a, %i : index
%p1b = finch.access %l0b, %i : index
%p1c = finch.access %l0c, %i : index

%l1c = finch.getlevel %CLvl1, %p1c : (!finch.looplet, index) -> (!finch.looplet)
%vc = finch.access %l1c, %i : f32

scf.for %j = %b0 to %b1 step %c1 {
%l1a = finch.getlevel %ALvl1, %p1a : (!finch.looplet, index) -> (!finch.looplet)
%p2a = finch.access %l1a, %j : index
%l1b = finch.getlevel %BLvl1, %p1b : (!finch.looplet, index) -> (!finch.looplet)
%p2b = finch.access %l1b, %j : index


%l2a = finch.getlevel %ALvl2, %p2a : (!finch.looplet, index) -> (!finch.looplet)
%va = finch.access %l2a, %j : f32
%l2b = finch.getlevel %BLvl2, %p2b : (!finch.looplet, index) -> (!finch.looplet)
%vb = finch.access %l2b, %j : f32

// vc += va * vb  (in == a function of out => reduction case in the pass)
%vab = arith.mulf %va, %vb : f32
%vabc = arith.addf %vc, %vab : f32
finch.assign %vc = %vabc : f32
}
}


//// Print the result vector %val3 (per-row reduction results)
%z = memref.cast %val3 : memref<?xf32> to memref<*xf32>
call @printMemrefF32(%z): (memref<*xf32>) -> ()

return
}
}

0 comments on commit 8b485ab

Please sign in to comment.