Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor make_mlir_fn #100

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 150 additions & 0 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,156 @@ for T in [Int, AbstractFloat, AbstractString, Nothing, Type, Symbol]
@eval create_result(tocopy::$T, path, result_stores) = Meta.quot(tocopy)
end

function make_mlir_fn(f, args, kwargs, name="main", concretein=true; toscalar=false)
if sizeof(typeof(f)) != 0 || f isa BroadcastFunction
return (true, make_mlir_fn(apply, (f, args...), kwargs, name, concretein)[2:end]...)
end

N = length(args)
seen_args = IdDict()
traced_args = ntuple(N) do i
return make_tracer(
seen_args,
args[i],
(:args, i),
concretein ? ConcreteToTraced : TracedSetPath;
toscalar,
)
end

linear_args = TracedRArray[v for v in values(seen_args) if v isa TracedRArray]

in_tys = if toscalar
[MLIR.IR.TensorType((), MLIR.IR.Type(eltype(arg))) for arg in linear_args]
else
[transpose_ty(mlir_type(arg)) for arg in linear_args]
end

sym_visibility = nothing
if !concretein
sym_visibility = MLIR.IR.Attribute("private")
end

mod = MLIR.IR.mmodule()
func = MLIR.IR.block!(MLIR.IR.body(mod)) do
return MLIR.Dialects.func.func_(;
sym_name=name * "_tmp",
function_type=MLIR.IR.FunctionType(in_tys, []),
body=MLIR.IR.Region(),
)
end

fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for arg in linear_args])
push!(MLIR.IR.region(func, 1), fnbody)

@assert MLIR.IR._has_block()

result = MLIR.IR.block!(fnbody) do
for (i, arg) in enumerate(linear_args)
raw_arg = MLIR.IR.argument(fnbody, i)
row_maj_arg = transpose_val(raw_arg)
arg.mlir_data = row_maj_arg
end

# NOTE an `AbstractInterpreter` cannot process methods with more recent world-ages than it
# solution is to use a new interpreter, but we reuse the `code_cache` to minimize comptime in Julia <= 1.10
@static if !HAS_INTEGRATED_CACHE
interp = ReactantInterpreter(; code_cache=REACTANT_CACHE)
else
interp = ReactantInterpreter()
end

# TODO replace with `Base.invoke_within` if julia#52964 lands
ir = first(only(
# TODO fix it for kwargs
Base.code_ircode(f, map(typeof, traced_args); interp),
))

# NOTE on Julia 1.9, it appends a ghost argument at the end
# solution: manually specify argument types
@static if VERSION < v"1.10"
empty!(ir.argtypes)
if f === Reactant.apply
append!(
ir.argtypes,
Any[
Core.Const(f),
typeof(traced_args[1]),
Tuple{typeof.(traced_args[2:end])...},
],
)
else
append!(ir.argtypes, Any[Core.Const(f), typeof.(traced_args)...])
end
end

oc = Core.OpaqueClosure(ir)

if f === Reactant.apply
oc(traced_args[1], (traced_args[2:end]...,))
else
oc(traced_args...)
end
end

seen_results = IdDict()

traced_result = make_tracer(
seen_results, result, (:result,), concretein ? TracedTrack : TracedSetPath
)

# marks buffers to be donated
for i in 1:N
make_tracer(
seen_results, traced_args[i], concretein ? (:resargs, i) : (), TracedTrack
)
end

linear_results = TracedRArray[v for v in values(seen_results) if v isa TracedRArray]

out_tys = [transpose_ty(mlir_type(arg)) for arg in linear_results]

ret = MLIR.IR.block!(fnbody) do
vals = MLIR.IR.Value[]
for res in linear_results
col_maj = transpose_val(res.mlir_data)
push!(vals, col_maj)
end
@assert length(vals) == length(linear_results)
return MLIR.Dialects.func.return_(vals)
end

name2 = name

tab = MLIR.IR.SymbolTable(MLIR.IR.Operation(mod))
for i in 0:10000
name2 = if i == 0
name
else
name * string(i)
end
if MLIR.IR.mlirIsNull(MLIR.API.mlirSymbolTableLookup(tab, name2))
break
end
end

func2 = MLIR.IR.block!(MLIR.IR.body(mod)) do
return MLIR.Dialects.func.func_(;
sym_name=name2,
function_type=MLIR.IR.FunctionType(in_tys, out_tys),
body=MLIR.IR.Region(),
sym_visibility,
)
end
MLIR.API.mlirRegionTakeBody(MLIR.IR.region(func2, 1), MLIR.IR.region(func, 1))

MLIR.API.mlirOperationDestroy(func.operation)
func.operation = MLIR.API.MlirOperation(C_NULL)
return false,
func2, traced_result, result, seen_args, ret, linear_args, in_tys,
linear_results
end

function compile(f, args; pipeline_options="", client=nothing)
N = length(args)
ctx = MLIR.IR.Context()
Expand Down
164 changes: 0 additions & 164 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,167 +20,3 @@ end
function apply(f, args...; kwargs...)
return f(args...; kwargs...)
end

function make_mlir_fn(f, args, kwargs, name="main", concretein=true; toscalar=false)
if sizeof(typeof(f)) != 0 || f isa BroadcastFunction
return (true, make_mlir_fn(apply, (f, args...), kwargs, name, concretein)[2:end]...)
end

N = length(args)
seen_args = IdDict()
traced_args = ntuple(N) do i
return make_tracer(
seen_args,
args[i],
(:args, i),
concretein ? ConcreteToTraced : TracedSetPath;
toscalar,
)
end

linear_args = TracedRArray[]
for (k, v) in seen_args
if !(v isa TracedRArray)
continue
end
push!(linear_args, v)
end

in_tys = if toscalar
[MLIR.IR.TensorType((), MLIR.IR.Type(eltype(arg))) for arg in linear_args]
else
[transpose_ty(mlir_type(arg)) for arg in linear_args]
end

sym_visibility = nothing
if !concretein
sym_visibility = MLIR.IR.Attribute("private")
end

mod = MLIR.IR.mmodule()
func = MLIR.IR.block!(MLIR.IR.body(mod)) do
return MLIR.Dialects.func.func_(;
sym_name=name * "_tmp",
function_type=MLIR.IR.FunctionType(in_tys, []),
body=MLIR.IR.Region(),
)
end

fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for arg in linear_args])
push!(MLIR.IR.region(func, 1), fnbody)

@assert MLIR.IR._has_block()

result = MLIR.IR.block!(fnbody) do
for (i, arg) in enumerate(linear_args)
raw_arg = MLIR.IR.argument(fnbody, i)
row_maj_arg = transpose_val(raw_arg)
arg.mlir_data = row_maj_arg
end

# NOTE an `AbstractInterpreter` cannot process methods with more recent world-ages than it
# solution is to use a new interpreter, but we reuse the `code_cache` to minimize comptime in Julia <= 1.10
@static if !HAS_INTEGRATED_CACHE
interp = ReactantInterpreter(; code_cache=REACTANT_CACHE)
else
interp = ReactantInterpreter()
end

# TODO replace with `Base.invoke_within` if julia#52964 lands
ir = first(only(
# TODO fix it for kwargs
Base.code_ircode(f, map(typeof, traced_args); interp),
))

# NOTE on Julia 1.9, it appends a ghost argument at the end
# solution: manually specify argument types
@static if VERSION < v"1.10"
empty!(ir.argtypes)
if f === Reactant.apply
append!(
ir.argtypes,
Any[
Core.Const(f),
typeof(traced_args[1]),
Tuple{typeof.(traced_args[2:end])...},
],
)
else
append!(ir.argtypes, Any[Core.Const(f), typeof.(traced_args)...])
end
end

oc = Core.OpaqueClosure(ir)

if f === Reactant.apply
oc(traced_args[1], (traced_args[2:end]...,))
else
oc(traced_args...)
end
end

seen_results = IdDict()

traced_result = make_tracer(
seen_results, result, (:result,), concretein ? TracedTrack : TracedSetPath
)

# marks buffers to be donated
for i in 1:N
make_tracer(
seen_results, traced_args[i], concretein ? (:resargs, i) : (), TracedTrack
)
end

linear_results = TracedRArray[]

for (k, v) in seen_results
if !(v isa TracedRArray)
continue
end

push!(linear_results, v)
end

out_tys = [transpose_ty(mlir_type(arg)) for arg in linear_results]

ret = MLIR.IR.block!(fnbody) do
vals = MLIR.IR.Value[]
for res in linear_results
col_maj = transpose_val(res.mlir_data)
push!(vals, col_maj)
end
@assert length(vals) == length(linear_results)
return MLIR.Dialects.func.return_(vals)
end

name2 = name

tab = MLIR.IR.SymbolTable(MLIR.IR.Operation(mod))
for i in 0:10000
name2 = if i == 0
name
else
name * string(i)
end
if MLIR.IR.mlirIsNull(MLIR.API.mlirSymbolTableLookup(tab, name2))
break
end
end

func2 = MLIR.IR.block!(MLIR.IR.body(mod)) do
return MLIR.Dialects.func.func_(;
sym_name=name2,
function_type=MLIR.IR.FunctionType(in_tys, out_tys),
body=MLIR.IR.Region(),
sym_visibility,
)
end
MLIR.API.mlirRegionTakeBody(MLIR.IR.region(func2, 1), MLIR.IR.region(func, 1))

MLIR.API.mlirOperationDestroy(func.operation)
func.operation = MLIR.API.MlirOperation(C_NULL)
return false,
func2, traced_result, result, seen_args, ret, linear_args, in_tys,
linear_results
end
Loading