From c81e9fa5ae5692b77bf52b247a0c00a8f3966a6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 7 Feb 2024 23:18:56 +0100 Subject: [PATCH 1/6] Init linalg example --- examples/linalg.jl | 53 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 examples/linalg.jl diff --git a/examples/linalg.jl b/examples/linalg.jl new file mode 100644 index 00000000..a7346ff4 --- /dev/null +++ b/examples/linalg.jl @@ -0,0 +1,53 @@ +using LLVM +using MLIR.IR +using MLIR.Dialects: arith, linalg, func + +n = 128 +a = rand(Float64, n, n) +b = rand(Float64, n, n) + +fptr = IR.context!(IR.Context()) do + mod = IR.Module(Location()) + body = IR.get_body(mod) + + # Create a function + block = IR.Block() + op = linalg.matmul(...) # TODO + push!(block, op) + + region = IR.Region() + push!(region, block) + + ftype = IR.FunctionType( # TODO + inputs=MLIRType[...], + results=MLIRType[...], + ) + f = func.func_(; + sym_name=IR.Attribute("matmul_demo"), + function_type=IR.Attribute(...), # TODO + owned_regions=Region[region], + result_inference=false, + ) + push!(body, f) + + pm = IR.PassManager() + opm = IR.OpPassManager(pm) + + IR.enable_ir_printing!(pm) + IR.enable_verifier!(pm, true) + + MLIR.API.mlirRegisterAllPasses() + MLIR.API.mlirRegisterAllLLVMTranslations(IR.context()) + IR.add_pipeline!(opm, "convert-linalg-to-loops,convert-func-to-llvm") + + IR.run!(pm, mod) + + jit = if LLVM.version() >= v"16" + MLIR.API.mlirExecutionEngineCreate(mod, 0, 0, C_NULL, false) + else + MLIR.API.mlirExecutionEngineCreate(mod, 0, 0, C_NULL) + end + MLIR.API.mlirExecutionEngineLookup(jit, "matmul_demo") +end + +@test ccall(fptr, Ptr{Float64}, (Ptr{Float64}, Ptr{Float64}), a, b) ≈ a * b From c35bc719ce9533d4cc0e39050c12193db825879f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 8 Feb 2024 18:15:25 +0100 Subject: [PATCH 2/6] Update code --- examples/linalg.jl | 53 +++++++++++++++++++++++++++++----------------- 1 file changed, 33 insertions(+), 20 deletions(-) diff --git a/examples/linalg.jl b/examples/linalg.jl index a7346ff4..ef8742d6 100644 --- a/examples/linalg.jl +++ b/examples/linalg.jl @@ -1,32 +1,45 @@ -using LLVM using MLIR.IR +using MLIR.API using MLIR.Dialects: arith, linalg, func n = 128 a = rand(Float64, n, n) b = rand(Float64, n, n) +function linalg_matmul(a::IR.Value, b::IR.Value; result::Union{Nothing,IR.MLIRType}=nothing, location=IR.Location()) + IR.create_operation( + "linalg.matmul", location; + operands=Value[a, b], + owned_regions=IR.Region[], + successors=IR.Block[], + results=isnothing(result) ? nothing : MLIRType[result], + attributes=IR.NamedAttribute[], + ) +end + fptr = IR.context!(IR.Context()) do + IR.enable_multithreading!(false) + mod = IR.Module(Location()) body = IR.get_body(mod) - # Create a function + # Create a function + mattype = MLIRType(API.mlirRankedTensorTypeGet(2, [n, n], MLIRType(Float64), API.mlirAttributeGetNull())) + block = IR.Block() - op = linalg.matmul(...) # TODO + a_ir = IR.push_argument!(block, mattype, IR.Location()) + b_ir = IR.push_argument!(block, mattype, IR.Location()) + op = linalg_matmul(a_ir, b_ir; result=mattype) # TODO refactor to `linalg.matmul` when bindings are regenerated push!(block, op) region = IR.Region() push!(region, block) - ftype = IR.FunctionType( # TODO - inputs=MLIRType[...], - results=MLIRType[...], - ) + ftype = MLIRType(API.mlirFunctionTypeGet(IR.context(), 2, [mattype, mattype], 1, [mattype])) f = func.func_(; sym_name=IR.Attribute("matmul_demo"), - function_type=IR.Attribute(...), # TODO - owned_regions=Region[region], - result_inference=false, + function_type=IR.Attribute(ftype), + body=region, ) push!(body, f) @@ -36,18 +49,18 @@ fptr = IR.context!(IR.Context()) do IR.enable_ir_printing!(pm) IR.enable_verifier!(pm, true) - MLIR.API.mlirRegisterAllPasses() - MLIR.API.mlirRegisterAllLLVMTranslations(IR.context()) - IR.add_pipeline!(opm, "convert-linalg-to-loops,convert-func-to-llvm") + API.mlirRegisterAllPasses() + API.mlirRegisterAllLLVMTranslations(IR.context()) + IR.add_pipeline!(opm, "convert-func-to-llvm") IR.run!(pm, mod) - jit = if LLVM.version() >= v"16" - MLIR.API.mlirExecutionEngineCreate(mod, 0, 0, C_NULL, false) - else - MLIR.API.mlirExecutionEngineCreate(mod, 0, 0, C_NULL) - end - MLIR.API.mlirExecutionEngineLookup(jit, "matmul_demo") + # jit = if LLVM.version() >= v"16" + # API.mlirExecutionEngineCreate(mod, 0, 0, C_NULL, false) + # else + # API.mlirExecutionEngineCreate(mod, 0, 0, C_NULL) + # end + # API.mlirExecutionEngineLookup(jit, "matmul_demo") end -@test ccall(fptr, Ptr{Float64}, (Ptr{Float64}, Ptr{Float64}), a, b) ≈ a * b +# @test ccall(fptr, Ptr{Float64}, (Ptr{Float64}, Ptr{Float64}), a, b) ≈ a * b From 4e6a0d01be9c126009818fc22d3cb2883e6e1664 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 8 Feb 2024 21:57:10 +0100 Subject: [PATCH 3/6] Update code --- examples/linalg.jl | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/examples/linalg.jl b/examples/linalg.jl index ef8742d6..d903cd9c 100644 --- a/examples/linalg.jl +++ b/examples/linalg.jl @@ -6,11 +6,12 @@ n = 128 a = rand(Float64, n, n) b = rand(Float64, n, n) +# TODO remove this when bindings are regenerated function linalg_matmul(a::IR.Value, b::IR.Value; result::Union{Nothing,IR.MLIRType}=nothing, location=IR.Location()) IR.create_operation( "linalg.matmul", location; operands=Value[a, b], - owned_regions=IR.Region[], + owned_regions=IR.Region[IR.Region()], successors=IR.Block[], results=isnothing(result) ? nothing : MLIRType[result], attributes=IR.NamedAttribute[], @@ -20,6 +21,10 @@ end fptr = IR.context!(IR.Context()) do IR.enable_multithreading!(false) + for dialect in ["func", "linalg"] + IR.get_or_load_dialect!(dialect) + end + mod = IR.Module(Location()) body = IR.get_body(mod) @@ -32,6 +37,8 @@ fptr = IR.context!(IR.Context()) do op = linalg_matmul(a_ir, b_ir; result=mattype) # TODO refactor to `linalg.matmul` when bindings are regenerated push!(block, op) + push!(block, func.return_([IR.get_result(op)])) + region = IR.Region() push!(region, block) @@ -41,6 +48,7 @@ fptr = IR.context!(IR.Context()) do function_type=IR.Attribute(ftype), body=region, ) + IR.verifyall(f) push!(body, f) pm = IR.PassManager() From f1bcfb8e307d21971ab8e2b57b551e3858211d6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 12 Feb 2024 12:55:15 +0000 Subject: [PATCH 4/6] Update code --- examples/linalg.jl | 44 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 36 insertions(+), 8 deletions(-) diff --git a/examples/linalg.jl b/examples/linalg.jl index d903cd9c..304fefa3 100644 --- a/examples/linalg.jl +++ b/examples/linalg.jl @@ -1,20 +1,32 @@ using MLIR.IR using MLIR.API -using MLIR.Dialects: arith, linalg, func +using MLIR.Dialects: arith, linalg, func, operandsegmentsizes n = 128 a = rand(Float64, n, n) b = rand(Float64, n, n) +function linalg_yield(x::IR.Value; location=IR.Location()) + IR.create_operation( + "linalg.yield", location; + operands=[x], + owned_regions=IR.Region[], + successors=IR.Block[], + results=nothing, + result_inference=false, + attributes=IR.NamedAttribute[operandsegmentsizes([1])], + ) +end + # TODO remove this when bindings are regenerated -function linalg_matmul(a::IR.Value, b::IR.Value; result::Union{Nothing,IR.MLIRType}=nothing, location=IR.Location()) +function linalg_matmul(c::IR.Value, a::IR.Value, b::IR.Value; region, result::Union{Nothing,IR.MLIRType}=nothing, location=IR.Location()) IR.create_operation( "linalg.matmul", location; - operands=Value[a, b], - owned_regions=IR.Region[IR.Region()], + operands=[a, b, c], + owned_regions=IR.Region[region], successors=IR.Block[], results=isnothing(result) ? nothing : MLIRType[result], - attributes=IR.NamedAttribute[], + attributes=IR.NamedAttribute[operandsegmentsizes([2, 1])], ) end @@ -29,12 +41,28 @@ fptr = IR.context!(IR.Context()) do body = IR.get_body(mod) # Create a function - mattype = MLIRType(API.mlirRankedTensorTypeGet(2, [n, n], MLIRType(Float64), API.mlirAttributeGetNull())) + scalartype = MLIRType(Float64) + mattype = MLIRType(API.mlirRankedTensorTypeGet(2, [n, n], scalartype, API.mlirAttributeGetNull())) + # mattype = MLIRType(API.mlirMemRefTypeContiguousGet(MLIRType(Float64), 2, [n, n], API.mlirAttributeGetNull())) + + linalg_block = IR.Block() + arg0 = IR.push_argument!(linalg_block, scalartype, IR.Location()) + arg1 = IR.push_argument!(linalg_block, scalartype, IR.Location()) + op = arith.addf(arg0, arg1; result=scalartype) + push!(linalg_block, op) + + op = linalg_yield(IR.get_result(op)) + push!(linalg_block, op) + + linalg_region = IR.Region() + push!(linalg_region, linalg_block) block = IR.Block() + c_ir = IR.push_argument!(block, mattype, IR.Location()) a_ir = IR.push_argument!(block, mattype, IR.Location()) b_ir = IR.push_argument!(block, mattype, IR.Location()) - op = linalg_matmul(a_ir, b_ir; result=mattype) # TODO refactor to `linalg.matmul` when bindings are regenerated + + op = linalg_matmul(c_ir, a_ir, b_ir; region=linalg_region, result=mattype) # TODO refactor to `linalg.matmul` when bindings are regenerated push!(block, op) push!(block, func.return_([IR.get_result(op)])) @@ -59,7 +87,7 @@ fptr = IR.context!(IR.Context()) do API.mlirRegisterAllPasses() API.mlirRegisterAllLLVMTranslations(IR.context()) - IR.add_pipeline!(opm, "convert-func-to-llvm") + # IR.add_pipeline!(opm, "convert-func-to-llvm") IR.run!(pm, mod) From 5d8791c0d068f50dfeed3c2570c1fcfd02cc2b14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 12 Feb 2024 15:42:39 +0000 Subject: [PATCH 5/6] Fix `linalg.matmul` construction --- examples/linalg.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/linalg.jl b/examples/linalg.jl index 304fefa3..ffdd6561 100644 --- a/examples/linalg.jl +++ b/examples/linalg.jl @@ -48,7 +48,11 @@ fptr = IR.context!(IR.Context()) do linalg_block = IR.Block() arg0 = IR.push_argument!(linalg_block, scalartype, IR.Location()) arg1 = IR.push_argument!(linalg_block, scalartype, IR.Location()) - op = arith.addf(arg0, arg1; result=scalartype) + arg2 = IR.push_argument!(linalg_block, scalartype, IR.Location()) + op = arith.mulf(arg0, arg1; result=scalartype) + push!(linalg_block, op) + + op = arith.addf(arg2, IR.get_result(op)) push!(linalg_block, op) op = linalg_yield(IR.get_result(op)) From 3bf61ff127d937c329af8039ff459d342382f7ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 13 Feb 2024 16:11:58 +0000 Subject: [PATCH 6/6] Update passes --- examples/linalg.jl | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/examples/linalg.jl b/examples/linalg.jl index ffdd6561..9c6a4c2e 100644 --- a/examples/linalg.jl +++ b/examples/linalg.jl @@ -1,3 +1,4 @@ +using LLVM using MLIR.IR using MLIR.API using MLIR.Dialects: arith, linalg, func, operandsegmentsizes @@ -74,7 +75,7 @@ fptr = IR.context!(IR.Context()) do region = IR.Region() push!(region, block) - ftype = MLIRType(API.mlirFunctionTypeGet(IR.context(), 2, [mattype, mattype], 1, [mattype])) + ftype = MLIRType(API.mlirFunctionTypeGet(IR.context(), 3, [mattype, mattype, mattype], 1, [mattype])) f = func.func_(; sym_name=IR.Attribute("matmul_demo"), function_type=IR.Attribute(ftype), @@ -91,16 +92,22 @@ fptr = IR.context!(IR.Context()) do API.mlirRegisterAllPasses() API.mlirRegisterAllLLVMTranslations(IR.context()) - # IR.add_pipeline!(opm, "convert-func-to-llvm") + IR.add_pipeline!(opm, "any(convert-func-to-llvm)") + IR.add_pipeline!(opm, "any(linalg-bufferize)") + IR.add_pipeline!(opm, "any(buffer-deallocation)") + IR.add_pipeline!(opm, "any(convert-linalg-to-loops)") + IR.add_pipeline!(opm, "any(convert-scf-to-cf)") + IR.add_pipeline!(opm, "any(convert-cf-to-llvm)") + IR.add_pipeline!(opm, "any(convert-arith-to-llvm)") IR.run!(pm, mod) - # jit = if LLVM.version() >= v"16" - # API.mlirExecutionEngineCreate(mod, 0, 0, C_NULL, false) - # else - # API.mlirExecutionEngineCreate(mod, 0, 0, C_NULL) - # end - # API.mlirExecutionEngineLookup(jit, "matmul_demo") + jit = if LLVM.version() >= v"16" + API.mlirExecutionEngineCreate(mod, 0, 0, C_NULL, false) + else + API.mlirExecutionEngineCreate(mod, 0, 0, C_NULL) + end + API.mlirExecutionEngineLookup(jit, "matmul_demo") end # @test ccall(fptr, Ptr{Float64}, (Ptr{Float64}, Ptr{Float64}), a, b) ≈ a * b