Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change to dialect generator #691

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
594 changes: 462 additions & 132 deletions deps/ReactantExtra/tblgen/jl-generators.cc

Large diffs are not rendered by default.

40 changes: 22 additions & 18 deletions ext/ReactantAbstractFFTsExt.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
module ReactantAbstractFFTsExt

using Reactant.MLIR.Dialects: stablehlo
using AbstractFFTs: AbstractFFTs
using Reactant: Reactant, MLIR, Ops, TracedRArray

Expand Down Expand Up @@ -31,58 +31,62 @@ function compute_correct_pdims(x::AbstractArray, dims)
end
end

for op in (:rfft, :fft, :ifft)
mode = uppercase(string(op))
@eval function AbstractFFTs.$(op)(x::TracedRArray, dims)
for op in (stablehlo.FftType.RFFT, stablehlo.FftType.FFT, stablehlo.FftType.IFFT)
name = Symbol(lowercase(string(op)))
@eval function AbstractFFTs.$(name)(x::TracedRArray, dims)
@assert maximum(dims) ≤ ndims(x) "dims out of range"
if dims isa Integer
if dims != 1
pdims = compute_correct_pdims(x, dims)
return permutedims(
AbstractFFTs.$(op)(permutedims(x, pdims), 1), invperm(pdims)
AbstractFFTs.$(name)(permutedims(x, pdims), 1), invperm(pdims)
)
end
return generalized_fft(x, $(mode), nothing, length(dims))
return generalized_fft(x, $(op), nothing, length(dims))
end
if !check_contiguous_innermost_dims(dims, ndims(x))
pdims = compute_correct_pdims(x, dims)
return permutedims(
AbstractFFTs.$(op)(permutedims(x, pdims), 1:length(dims)), invperm(pdims)
AbstractFFTs.$(name)(permutedims(x, pdims), 1:length(dims)), invperm(pdims)
)
end
return generalized_fft(x, $(mode), nothing, length(dims))
return generalized_fft(x, $(op), nothing, length(dims))
end
end

for op in (:irfft,)
mode = uppercase(string(op))
@eval function AbstractFFTs.$(op)(x::TracedRArray, d::Int, dims)
for op in (stablehlo.FftType.IRFFT,)
name = Symbol(lowercase(string(op)))
@eval function AbstractFFTs.$(name)(x::TracedRArray, d::Int, dims)
@assert maximum(dims) ≤ ndims(x) "dims out of range"
if dims isa Integer
if dims != 1
pdims = compute_correct_pdims(x, dims)
return permutedims(
AbstractFFTs.$(op)(permutedims(x, pdims), d, 1), invperm(pdims)
AbstractFFTs.$(name)(permutedims(x, pdims), d, 1), invperm(pdims)
)
end
return generalized_fft(x, $(mode), d, length(dims))
return generalized_fft(x, $(op), d, length(dims))
end
if !check_contiguous_innermost_dims(dims, ndims(x))
pdims = compute_correct_pdims(x, dims)
return permutedims(
AbstractFFTs.$(op)(permutedims(x, pdims), d, 1:length(dims)), invperm(pdims)
AbstractFFTs.$(name)(permutedims(x, pdims), d, 1:length(dims)),
invperm(pdims),
)
end
return generalized_fft(x, $(mode), d, length(dims))
return generalized_fft(x, $(op), d, length(dims))
end
end

function generalized_fft(x::TracedRArray{T,N}, mode::String, d, first_n::Int) where {T,N}
function generalized_fft(
x::TracedRArray{T,N}, mode::stablehlo.FftType.T, d, first_n::Int
) where {T,N}
if d === nothing
@assert mode ∈ ("RFFT", "FFT", "IFFT")
@assert mode ∈
(stablehlo.FftType.RFFT, stablehlo.FftType.FFT, stablehlo.FftType.IFFT)
fft_length = [size(x, i) for i in 1:first_n]
else
@assert mode == "IRFFT"
@assert mode == stablehlo.FftType.IRFFT
fft_length = [i == 1 ? d : size(x, i) for i in 1:first_n]
end

Expand Down
17 changes: 6 additions & 11 deletions ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -828,7 +828,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
wrapfunc = MLIR.IR.block!(MLIR.IR.body(mod)) do
return MLIR.Dialects.llvm.func(;
sym_name,
sym_visibility=MLIR.IR.Attribute("private"),
sym_visibility="private",
function_type=wrapftype,
body=MLIR.IR.Region(),
CConv,
Expand Down Expand Up @@ -884,10 +884,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
1,
)
alloc = MLIR.IR.result(
MLIR.Dialects.llvm.alloca(
c1; elem_type=MLIR.IR.Attribute(argty), res=llvmptr
),
1,
MLIR.Dialects.llvm.alloca(c1; elem_type=argty, res=llvmptr), 1
)
push!(allocs, (alloc, argty))

Expand Down Expand Up @@ -948,7 +945,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
MLIR.IR.Value[];
res=llvmptr,
elem_type=i8,
rawConstantIndices=MLIR.IR.Attribute([Int32(offset)]),
rawConstantIndices=[Int32(offset)],
),
1,
)
Expand All @@ -971,13 +968,11 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
wrapargs,
MLIR.IR.Value[];
callee=MLIR.IR.FlatSymbolRefAttribute(Base.String(fname)),
op_bundle_sizes=MLIR.IR.Attribute(Int32[]),
op_bundle_sizes=Int32[],
)
MLIR.Dialects.llvm.return_(nothing)
end

output_operand_aliases = MLIR.IR.Attribute(aliases)

blk_operands = MLIR.IR.Value[]
for idx in
(blockdim.x, blockdim.y, blockdim.z, threaddim.x, threaddim.y, threaddim.z, shmem)
Expand All @@ -992,9 +987,9 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
call = MLIR.Dialects.enzymexla.kernel_call(
blk_operands...,
mlir_args;
result_0=restys,
result=restys,
fn=MLIR.IR.FlatSymbolRefAttribute(sym_name),
output_operand_aliases=MLIR.IR.Attribute(output_operand_aliases),
output_operand_aliases=aliases,
)

argidx = 1
Expand Down
17 changes: 9 additions & 8 deletions ext/ReactantNNlibExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module ReactantNNlibExt
using NNlib
using GPUArraysCore: @allowscalar
using Reactant: Reactant, Ops, TracedRArray, AnyTracedRArray, MLIR, TracedRNumber

using Reactant.MLIR.Dialects: stablehlo
using Reactant.TracedUtils:
TracedUtils, materialize_traced_array, get_mlir_data, set_mlir_data!

Expand Down Expand Up @@ -94,7 +94,7 @@ function NNlib.conv!(
Int64(output_batch_dim - 1),
Int64(output_feature_dim - 1),
length(output_spatial_dims), Int64[i - 1 for i in output_spatial_dims],
)
) # TODO: deal with this using a custom parser in the Julia code generation
#! format: on

padding = Reactant.MLIR.IR.DenseElementsAttribute(
Expand All @@ -110,11 +110,11 @@ function NNlib.conv!(
conv = Reactant.MLIR.Dialects.stablehlo.convolution(
get_mlir_data(x),
get_mlir_data(weight);
result_0=result_type,
result=result_type,
window_strides=collect(stride),
padding,
dimension_numbers,
lhs_dilation=1,
lhs_dilation=[1 for _ in dilation],
rhs_dilation=collect(dilation),
feature_group_count,
batch_group_count=1,
Expand Down Expand Up @@ -176,13 +176,14 @@ function reduce_window(f, x::AnyTracedRArray{T,N}, pdims; init) where {T,N}
end

attr = fill(Reactant.MLIR.IR.Attribute(init), unranked)

init_value = Reactant.MLIR.IR.result(
Reactant.MLIR.Dialects.stablehlo.constant(; value=attr)
)
reduction = Reactant.MLIR.Dialects.stablehlo.reduce_window(
[get_mlir_data(x)],
[init_value];
result_0=[result_type],
result=[result_type],
window_dimensions,
window_strides,
window_dilations,
Expand Down Expand Up @@ -415,7 +416,7 @@ function NNlib.∇conv_filter!(
conv = MLIR.Dialects.stablehlo.convolution(
get_mlir_data(x),
get_mlir_data(dy);
result_0=result_type,
result=result_type,
window_strides=collect(dilation),
padding,
dimension_numbers,
Expand Down Expand Up @@ -532,8 +533,8 @@ function NNlib.∇conv_data!(
conv = MLIR.Dialects.stablehlo.convolution(
get_mlir_data(dy),
get_mlir_data(w);
result_0=result_type,
window_strides=1,
result=result_type,
window_strides=[1 for _ in dilation],
padding,
lhs_dilation=collect(stride),
rhs_dilation=collect(dilation),
Expand Down
9 changes: 5 additions & 4 deletions ext/ReactantRandom123Ext.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ module ReactantRandom123Ext

using Random123: Threefry4x, Threefry2x, Philox4x, Philox2x
using Reactant: TracedRandom
using Reactant.MLIR.Dialects: stablehlo

TracedRandom.rng_algorithm(::Threefry4x) = "THREE_FRY"
TracedRandom.rng_algorithm(::Threefry2x) = "THREE_FRY"
TracedRandom.rng_algorithm(::Philox4x) = "PHILOX"
TracedRandom.rng_algorithm(::Philox2x) = "PHILOX"
# Select the StableHLO RNG algorithm enum value for each supported Random123
# generator type: both Threefry variants map to `THREE_FRY`, and both Philox
# variants map to `PHILOX`.
TracedRandom.rng_algorithm(::Threefry4x) = stablehlo.RngAlgorithm.THREE_FRY
TracedRandom.rng_algorithm(::Threefry2x) = stablehlo.RngAlgorithm.THREE_FRY
TracedRandom.rng_algorithm(::Philox4x) = stablehlo.RngAlgorithm.PHILOX
TracedRandom.rng_algorithm(::Philox2x) = stablehlo.RngAlgorithm.PHILOX

end
10 changes: 2 additions & 8 deletions src/Interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -369,20 +369,14 @@ function overload_autodiff(
end
end

function act_attr(val)
val = @ccall MLIR.API.mlir_c.enzymeActivityAttrGet(
MLIR.IR.context()::MLIR.API.MlirContext, val::Int32
)::MLIR.API.MlirAttribute
return MLIR.IR.Attribute(val)
end
fname = TracedUtils.get_attribute_by_name(func2, "sym_name")
fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname))
res = (reverse ? MLIR.Dialects.enzyme.autodiff : MLIR.Dialects.enzyme.fwddiff)(
[TracedUtils.transpose_val(v) for v in ad_inputs];
outputs=outtys,
fn=fname,
activity=MLIR.IR.Attribute([act_attr(a) for a in activity]),
ret_activity=MLIR.IR.Attribute([act_attr(a) for a in ret_activity]),
activity=[MLIR.Dialects.enzyme.Activity.T(a) for a in activity],
ret_activity=[MLIR.Dialects.enzyme.Activity.T(a) for a in ret_activity],
)

residx = 1
Expand Down
Loading
Loading