diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 6105bca78..c2e2e0cb9 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -1,5 +1,6 @@ module TracedRArrayOverrides +using Adapt: WrappedReshapedArray using Base.Broadcast using Base.Broadcast: BroadcastStyle, Broadcasted, AbstractArrayStyle, instantiate @@ -87,7 +88,7 @@ end function generate_index_list(i1, is...) list = reshape(i1, :, 1) .- 1 for i in is - i = reshape(i, :, 1) + i = TracedUtils.broadcast_to_size(i, (length(i), 1)) lorig = size(list, 1) list = repeat(list, size(i, 1), 1) i = repeat(i; inner=(lorig, 1)) .- 1 @@ -196,8 +197,12 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N} if any(i -> unwrapped_eltype(i) <: Bool, indices) error("Boolean indexing with TracedRArrays isn't fully supported yet.") end - indices, integer_indices, result_size, _ = TracedUtils.traced_indices(indices...) - res = Ops.gather_getindex(a, generate_index_list(indices...)) + indices, integer_indices, result_size, preddim_result_size, _ = TracedUtils.traced_indices( + indices... + ) + res = Ops.reshape( + Ops.gather_getindex(a, generate_index_list(indices...)), preddim_result_size + ) isempty(integer_indices) || (res = materialize_traced_array(dropdims(res; dims=integer_indices))) return Ops.reshape(res, result_size) @@ -228,6 +233,24 @@ function Base.getindex(a::WrappedTracedRArray, indices...) return getindex(ancestor(a), TracedUtils.get_ancestor_indices(a, indices...)...) end +## Specialize certain dispatches for better codegen +for aType in ( + WrappedReshapedArray{TracedRNumber{T},N,TracedRArray{T,M}} where {T,N,M}, + PermutedDimsArray{ + TracedRNumber{T},N,perm,iperm,TracedRArray{T,N} + } where {T,N,perm,iperm}, +) + @eval begin + function Base.getindex(a::$aType, indices::Union{Int,TracedRNumber{Int}}...) + return getindex(materialize_traced_array(a), indices...) + end + + function Base.getindex(a::$aType, indices...) + return getindex(materialize_traced_array(a), indices...) + end + end +end + function maybe_assert_scalar_setindexing( ::TracedRArray{T,N}, ::Vararg{Union{Int,TracedRNumber{Int}},N} ) where {T,N} diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 66d7ac665..84a69b0c8 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -67,7 +67,9 @@ function get_ancestor_indices( @assert length(indices) == N "Expected $N indices, got $(length(indices))" indices = normalize_indices(x, indices...) if any(is_traced, indices) - indices, integer_indices, result_size, flattened_size = traced_indices(indices...) + indices, integer_indices, result_size, _, flattened_size = traced_indices( + indices... + ) linear_indices = mapreduce(+, enumerate(indices)) do (i, idx) bcasted_idxs = Ops.broadcast_in_dim( idx, ndims(idx) == 0 ? Int64[] : Int64[i], flattened_size @@ -704,18 +706,27 @@ end function traced_indices(indices...) integer_indices = Int64[] result_size = Int64[] + preddim_result_size = Int64[] flattened_size = Int64[length(idx) for idx in indices] new_indices = map(enumerate(indices)) do (i, idx) if idx isa Number + push!(preddim_result_size, 1) push!(integer_indices, i) idx isa TracedRNumber && return idx return promote_to(TracedRNumber{Int}, idx) end + append!(preddim_result_size, [size(idx)...]) append!(result_size, [size(idx)...]) idx isa TracedRArray && return materialize_traced_array(vec(idx)) return promote_to(TracedRArray{Int,1}, vec(idx)) end - return new_indices, Tuple(integer_indices), result_size, flattened_size + return ( + new_indices, + Tuple(integer_indices), + result_size, + preddim_result_size, + flattened_size, + ) end end diff --git a/test/wrapped_arrays.jl b/test/wrapped_arrays.jl index da8ae34eb..87915ffe0 100644 --- a/test/wrapped_arrays.jl +++ b/test/wrapped_arrays.jl @@ -243,3 +243,24 @@ end x_ra = Reactant.to_rarray(rand(3, 4, 3)) @test @jit(fn(x_ra)) == fn(Array(x_ra)) end + +function reshape_getindex(x) + x = reshape(x, 2, 4, 3) + return x[1, :, :] +end + +function permutedims_getindex(x) + x = PermutedDimsArray(x, (2, 1)) + return x[1, :] +end + +@testset "no gather getindex" begin + x = ones(8, 3) + x_ra = Reactant.to_rarray(x) + + hlo = repr(@code_hlo(reshape_getindex(x_ra))) + @test !occursin("stablehlo.gather", hlo) + + hlo = repr(@code_hlo(permutedims_getindex(x_ra))) + @test !occursin("stablehlo.gather", hlo) +end