Skip to content

Commit

Permalink
fix: improve generated mlir for wrapped arrays (#732)
Browse files Browse the repository at this point in the history
* fix: improve generated mlir for wrapped arrays

* test: add test for no gather

* fix: handle scalar index correctly
  • Loading branch information
avik-pal authored Feb 12, 2025
1 parent 904b789 commit 95f6074
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 5 deletions.
29 changes: 26 additions & 3 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module TracedRArrayOverrides

using Adapt: WrappedReshapedArray
using Base.Broadcast
using Base.Broadcast: BroadcastStyle, Broadcasted, AbstractArrayStyle, instantiate

Expand Down Expand Up @@ -87,7 +88,7 @@ end
function generate_index_list(i1, is...)
list = reshape(i1, :, 1) .- 1
for i in is
i = reshape(i, :, 1)
i = TracedUtils.broadcast_to_size(i, (length(i), 1))
lorig = size(list, 1)
list = repeat(list, size(i, 1), 1)
i = repeat(i; inner=(lorig, 1)) .- 1
Expand Down Expand Up @@ -196,8 +197,12 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
if any(i -> unwrapped_eltype(i) <: Bool, indices)
error("Boolean indexing with TracedRArrays isn't fully supported yet.")
end
indices, integer_indices, result_size, _ = TracedUtils.traced_indices(indices...)
res = Ops.gather_getindex(a, generate_index_list(indices...))
indices, integer_indices, result_size, preddim_result_size, _ = TracedUtils.traced_indices(
indices...
)
res = Ops.reshape(
Ops.gather_getindex(a, generate_index_list(indices...)), preddim_result_size
)
isempty(integer_indices) ||
(res = materialize_traced_array(dropdims(res; dims=integer_indices)))
return Ops.reshape(res, result_size)
Expand Down Expand Up @@ -228,6 +233,24 @@ function Base.getindex(a::WrappedTracedRArray, indices...)
return getindex(ancestor(a), TracedUtils.get_ancestor_indices(a, indices...)...)
end

## Specialize certain dispatches for better codegen
for aType in (
WrappedReshapedArray{TracedRNumber{T},N,TracedRArray{T,M}} where {T,N,M},
PermutedDimsArray{
TracedRNumber{T},N,perm,iperm,TracedRArray{T,N}
} where {T,N,perm,iperm},
)
@eval begin
function Base.getindex(a::$aType, indices::Union{Int,TracedRNumber{Int}}...)
return getindex(materialize_traced_array(a), indices...)
end

function Base.getindex(a::$aType, indices...)
return getindex(materialize_traced_array(a), indices...)
end
end
end

function maybe_assert_scalar_setindexing(
::TracedRArray{T,N}, ::Vararg{Union{Int,TracedRNumber{Int}},N}
) where {T,N}
Expand Down
15 changes: 13 additions & 2 deletions src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ function get_ancestor_indices(
@assert length(indices) == N "Expected $N indices, got $(length(indices))"
indices = normalize_indices(x, indices...)
if any(is_traced, indices)
indices, integer_indices, result_size, flattened_size = traced_indices(indices...)
indices, integer_indices, result_size, _, flattened_size = traced_indices(
indices...
)
linear_indices = mapreduce(+, enumerate(indices)) do (i, idx)
bcasted_idxs = Ops.broadcast_in_dim(
idx, ndims(idx) == 0 ? Int64[] : Int64[i], flattened_size
Expand Down Expand Up @@ -704,18 +706,27 @@ end
function traced_indices(indices...)
integer_indices = Int64[]
result_size = Int64[]
preddim_result_size = Int64[]
flattened_size = Int64[length(idx) for idx in indices]
new_indices = map(enumerate(indices)) do (i, idx)
if idx isa Number
push!(preddim_result_size, 1)
push!(integer_indices, i)
idx isa TracedRNumber && return idx
return promote_to(TracedRNumber{Int}, idx)
end
append!(preddim_result_size, [size(idx)...])
append!(result_size, [size(idx)...])
idx isa TracedRArray && return materialize_traced_array(vec(idx))
return promote_to(TracedRArray{Int,1}, vec(idx))
end
return new_indices, Tuple(integer_indices), result_size, flattened_size
return (
new_indices,
Tuple(integer_indices),
result_size,
preddim_result_size,
flattened_size,
)
end

end
21 changes: 21 additions & 0 deletions test/wrapped_arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -243,3 +243,24 @@ end
x_ra = Reactant.to_rarray(rand(3, 4, 3))
@test @jit(fn(x_ra)) == fn(Array(x_ra))
end

function reshape_getindex(x)
x = reshape(x, 2, 4, 3)
return x[1, :, :]
end

function permutedims_getindex(x)
x = PermutedDimsArray(x, (2, 1))
return x[1, :]
end

@testset "no gather getindex" begin
x = ones(8, 3)
x_ra = Reactant.to_rarray(x)

hlo = repr(@code_hlo(reshape_getindex(x_ra)))
@test !occursin("stablehlo.gather", hlo)

hlo = repr(@code_hlo(permutedims_getindex(x_ra)))
@test !occursin("stablehlo.gather", hlo)
end

0 comments on commit 95f6074

Please sign in to comment.