Skip to content

Commit

Permalink
Fixes for nightly (#42)
Browse files Browse the repository at this point in the history
* fix operand_segment_sizes for LLVM > 15

DenseI32

* fix io conversion to void ptr

* handle additional param to mlirExecutionEngineCreate

* fix executionengine tests

* fix version checks
  • Loading branch information
Pangoraw authored Jan 17, 2024
1 parent 67749a7 commit 176e4f0
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 13 deletions.
8 changes: 6 additions & 2 deletions examples/brutus.jl
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ end

# ---

using Test
using Test, LLVM
using MLIR.IR, MLIR

fptr = IR.context!(IR.Context()) do
Expand All @@ -292,7 +292,11 @@ fptr = IR.context!(IR.Context()) do

IR.run!(pm, mod)

jit = MLIR.API.mlirExecutionEngineCreate(mod, 0, 0, C_NULL)
jit = if LLVM.version() >= v"16"
MLIR.API.mlirExecutionEngineCreate(mod, 0, 0, C_NULL, false)
else
MLIR.API.mlirExecutionEngineCreate(mod, 0, 0, C_NULL)
end
MLIR.API.mlirExecutionEngineLookup(jit, "pow")
end

Expand Down
10 changes: 8 additions & 2 deletions src/Dialects.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module Dialects

import LLVM
import ..IR: Attribute, NamedAttribute, context
import ..IR: Attribute, NamedAttribute, DenseArrayAttribute, context
import ..API

namedattribute(name, val) = namedattribute(name, Attribute(val))
Expand All @@ -11,7 +11,13 @@ function namedattribute(name, val::NamedAttribute)
return val
end

operandsegmentsizes(segments) = namedattribute("operand_segment_sizes", Attribute(Int32.(segments)))
operandsegmentsizes(segments) =
namedattribute("operand_segment_sizes",
LLVM.version() >= v"16" ?
DenseArrayAttribute(Int32.(segments)) :
Attribute(Int32.(segments))
)


let
ver = string(LLVM.version().major)
Expand Down
5 changes: 5 additions & 0 deletions src/IR/IR.jl
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,11 @@ function ArrayAttribute(attributes::Vector{Attribute})
API.mlirArrayAttrGet(context(), length(attributes), attributes),
)
end
function DenseArrayAttribute(values::AbstractVector{Int32})
Attribute(
API.mlirDenseI32ArrayGet(context(), length(values), collect(values))
)
end
function DenseArrayAttribute(values::AbstractVector{Int})
Attribute(
API.mlirDenseI64ArrayGet(context(), length(values), collect(values))
Expand Down
4 changes: 2 additions & 2 deletions src/IR/Pass.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ function add_pipeline!(op_pass::OpPassManager, pipeline)
@static if isdefined(API, :mlirOpPassManagerAddPipeline)
io = IOBuffer()
c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any))
result = GC.@preserve io API.mlirOpPassManagerAddPipeline(op_pass, pipeline, c_print_callback, io)
result = API.mlirOpPassManagerAddPipeline(op_pass, pipeline, c_print_callback, Ref(io))
if mlirLogicalResultIsFailure(result)
exc = AddPipelineException(String(take!(io)))
throw(exc)
Expand Down Expand Up @@ -171,4 +171,4 @@ end
mlir_pass
end

end
end
25 changes: 18 additions & 7 deletions test/executionengine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ function registerAllUpstreamDialects!(ctx)
return nothing
end

# TODO: Fix for LLVM 15
function lowerModuleToLLVM(ctx, mod)
pm = MLIR.API.mlirPassManagerCreate(ctx)
if LLVM.version() >= v"15"
Expand All @@ -26,17 +25,24 @@ function lowerModuleToLLVM(ctx, mod)
end
opm = MLIR.API.mlirPassManagerGetNestedUnder(pm, op)
if LLVM.version() >= v"15"
MLIR.API.mlirPassManagerAddOwnedPass(pm,
MLIR.API.mlirPassManagerAddOwnedPass(pm,
MLIR.API.mlirCreateConversionConvertFuncToLLVM()
)
else
MLIR.API.mlirPassManagerAddOwnedPass(pm,
MLIR.API.mlirCreateConversionConvertStandardToLLVM()
)
end
MLIR.API.mlirOpPassManagerAddOwnedPass(opm,
MLIR.API.mlirCreateConversionConvertArithmeticToLLVM()
)

if LLVM.version() >= v"16"
MLIR.API.mlirOpPassManagerAddOwnedPass(opm,
MLIR.API.mlirCreateConversionArithToLLVMConversionPass()
)
else
MLIR.API.mlirOpPassManagerAddOwnedPass(opm,
MLIR.API.mlirCreateConversionConvertArithmeticToLLVM()
)
end
status = MLIR.API.mlirPassManagerRun(pm, mod)
# undefined symbol: mlirLogicalResultIsFailure
if status.value == 0
Expand Down Expand Up @@ -74,8 +80,13 @@ MLIR.API.mlirRegisterAllLLVMTranslations(ctx)

# TODO add C-API for translateModuleToLLVMIR

jit = MLIR.API.mlirExecutionEngineCreate(
mod, #=optLevel=# 2, #=numPaths=# 0, #=sharedLibPaths=# C_NULL)
jit = if LLVM.version() >= v"16"
MLIR.API.mlirExecutionEngineCreate(
mod, #=optLevel=# 2, #=numPaths=# 0, #=sharedLibPaths=# C_NULL, #= enableObjectDump =# false)
else
MLIR.API.mlirExecutionEngineCreate(
mod, #=optLevel=# 2, #=numPaths=# 0, #=sharedLibPaths=# C_NULL)
end

if jit == C_NULL
error("Execution engine creation failed")
Expand Down

0 comments on commit 176e4f0

Please sign in to comment.