From b1f8d93842056c7bf28c75b016eaf738e5dac9be Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Thu, 1 Feb 2024 12:11:49 +0530 Subject: [PATCH 1/3] Vector indexing for OneElement --- src/oneelement.jl | 17 +++++++++++ test/runtests.jl | 78 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 95 insertions(+) diff --git a/src/oneelement.jl b/src/oneelement.jl index c72bcd2d..7d1514c4 100644 --- a/src/oneelement.jl +++ b/src/oneelement.jl @@ -42,10 +42,27 @@ OneElement{T}(inds::Int, sz::Int) where T = OneElement(one(T), inds, sz) Base.size(A::OneElement) = map(length, A.axes) Base.axes(A::OneElement) = A.axes +Base.getindex(A::OneElement{T,0}) where {T} = getindex_value(A) Base.@propagate_inbounds function Base.getindex(A::OneElement{T,N}, kj::Vararg{Int,N}) where {T,N} @boundscheck checkbounds(A, kj...) ifelse(kj == A.ind, A.val, zero(T)) end +const VectorIndsWithColon = Union{AbstractRange{Int}, Colon, Int} +const VectorInds = Union{AbstractRange{Int}, Int} +# retain the values from Ainds corresponding to the vector indices in inds +_index_shape(Ainds, inds::Tuple{Integer, Vararg{Any}}) = _index_shape(Base.tail(Ainds), Base.tail(inds)) +_index_shape(Ainds, inds::Tuple{AbstractVector, Vararg{Any}}) = (Ainds[1], _index_shape(Base.tail(Ainds), Base.tail(inds))...) +_index_shape(::Tuple{}, ::Tuple{}) = () +@inline function Base.getindex(A::OneElement{T,N}, inds::Vararg{VectorInds,N}) where {T,N} + @boundscheck checkbounds(A, inds...) + shape = _index_shape(inds, inds) + nzind = _index_shape(A.ind, inds) .- first.(shape) .+ firstindex.(shape) + containsval = all(in.(A.ind, inds)) + OneElement(getindex_value(A), containsval ? Int.(nzind) : Int.(lastindex.(shape,1)).+1, axes.(shape,1)) +end +Base.@propagate_inbounds function Base.getindex(A::OneElement{T,N}, inds::Vararg{VectorIndsWithColon,N}) where {T,N} + getindex(A, Base.to_indices(A, inds)...) +end """ nzind(A::OneElement{T,N}) -> CartesianIndex{N} diff --git a/test/runtests.jl b/test/runtests.jl index d4b2bb95..d08cf847 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2148,10 +2148,12 @@ end @test FillArrays.nzind(A) == CartesianIndex() @test A == Fill(2, ()) @test A[] === 2 + @test A[1] === A[1,1] === 2 e₁ = OneElement(2, 5) @test e₁ == [0,1,0,0,0] @test FillArrays.nzind(e₁) == CartesianIndex(2) + @test e₁[2] === e₁[2,1] === e₁[2,1,1] === 1 @test_throws BoundsError e₁[6] f₁ = AbstractArray{Float64}(e₁) @@ -2193,6 +2195,82 @@ end @test A[1,1] === A[1,2] === A[2,1] === zero(S) end + @testset "Vector indexing" begin + @testset "1D" begin + A = OneElement(2, 2, 4) + @test @inferred(A[:]) === @inferred(A[axes(A)...]) === A + @test @inferred(A[3:4]) isa OneElement{Int,1} + @test @inferred(A[3:4]) == Zeros(2) + @test @inferred(A[1:2]) === OneElement(2, 2, 2) + @test @inferred(A[2:3]) === OneElement(2, 1, 2) + @test @inferred(A[Base.IdentityUnitRange(2:3)]) isa OneElement{Int,1} + @test @inferred(A[Base.IdentityUnitRange(2:3)]) == OneElement(2,(2,),(Base.IdentityUnitRange(2:3),)) + @test A[:,:] == reshape(A, size(A)..., 1) + + B = OneElement(2, (2,), (Base.IdentityUnitRange(-1:4),)) + @test @inferred(A[:]) === @inferred(A[axes(A)...]) === A + @test @inferred(A[3:4]) isa OneElement{Int,1} + @test @inferred(A[3:4]) == Zeros(2) + @test @inferred(A[2:3]) === OneElement(2, 1, 2) + + C = OneElement(2, (2,), (Base.OneTo(big(4)),)) + @test @inferred(C[1:4]) === OneElement(2, 2, 4) + + D = OneElement(2, (2,), (InfiniteArrays.OneToInf(),)) + D2 = D[:] + @test axes(D2) == axes(D) + @test D2[2] == D[2] + D3 = D[axes(D)...] + @test axes(D3) == axes(D) + @test D3[2] == D[2] + end + @testset "2D" begin + A = OneElement(2, (2,3), (4,5)) + @test @inferred(A[:,:]) === @inferred(A[axes(A)...]) === A + @test @inferred(A[:,1]) isa OneElement{Int,1} + @test @inferred(A[:,1]) == Zeros(4) + @test @inferred(A[1,:]) isa OneElement{Int,1} + @test @inferred(A[1,:]) == Zeros(5) + @test @inferred(A[:,3]) === OneElement(2, 2, 4) + @test @inferred(A[2,:]) === OneElement(2, 3, 5) + @test @inferred(A[1:1,:]) isa OneElement{Int,2} + @test @inferred(A[1:1,:]) == Zeros(1,5) + @test @inferred(A[4:4,:]) isa OneElement{Int,2} + @test @inferred(A[4:4,:]) == Zeros(1,5) + @test @inferred(A[2:2,:]) === OneElement(2, (1,3), (1,5)) + @test @inferred(A[1:4,:]) === OneElement(2, (2,3), (4,5)) + @test @inferred(A[:,3:3]) === OneElement(2, (2,1), (4,1)) + @test @inferred(A[:,1:5]) === OneElement(2, (2,3), (4,5)) + @test @inferred(A[1:4,1:4]) === OneElement(2, (2,3), (4,4)) + @test @inferred(A[2:4,2:4]) === OneElement(2, (1,2), (3,3)) + @test @inferred(A[2:4,3:4]) === OneElement(2, (1,1), (3,2)) + @test @inferred(A[4:4,5:5]) isa OneElement{Int,2} + @test @inferred(A[4:4,5:5]) == Zeros(1,1) + @test @inferred(A[Base.IdentityUnitRange(2:4), :]) isa OneElement{Int,2} + @test axes(A[Base.IdentityUnitRange(2:4), :]) == (Base.IdentityUnitRange(2:4), axes(A,2)) + @test @inferred(A[:,:,:]) == reshape(A, size(A)...,1) + + B = OneElement(2, (2,3), (Base.IdentityUnitRange(2:4),Base.IdentityUnitRange(2:5))) + @test @inferred(B[:,:]) === @inferred(B[axes(B)...]) === B + @test @inferred(B[:,3]) === OneElement(2, (2,), (Base.IdentityUnitRange(2:4),)) + @test @inferred(B[3:4, 4:5]) isa OneElement{Int,2} + @test @inferred(B[3:4, 4:5]) == Zeros(2,2) + b = @inferred(B[Base.IdentityUnitRange(3:4), Base.IdentityUnitRange(4:5)]) + @test b == Zeros(axes(b)) + + C = OneElement(2, (2,3), (Base.OneTo(big(4)), Base.OneTo(big(5)))) + @test @inferred(C[1:4, 1:5]) === OneElement(2, (2,3), Int.(size(C))) + + D = OneElement(2, (2,3), (InfiniteArrays.OneToInf(), InfiniteArrays.OneToInf())) + D2 = @inferred D[:,:] + @test axes(D2) == axes(D) + @test D2[2,3] == D[2,3] + D3 = @inferred D[axes(D)...] + @test axes(D3) == axes(D) + @test D3[2,3] == D[2,3] + end + end + @testset "adjoint/transpose" begin A = OneElement(3im, (2,4), (4,6)) @test A' === OneElement(-3im, (4,2), (6,4)) From beb758d5ada8ac4bea7a789f591d9d9ffb95f81e Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Thu, 9 May 2024 11:56:53 +0530 Subject: [PATCH 2/3] Handle non-int Integer indices --- src/oneelement.jl | 15 ++++++++------- test/runtests.jl | 1 + 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/oneelement.jl b/src/oneelement.jl index 7d1514c4..9011a3c3 100644 --- a/src/oneelement.jl +++ b/src/oneelement.jl @@ -47,21 +47,22 @@ Base.@propagate_inbounds function Base.getindex(A::OneElement{T,N}, kj::Vararg{I @boundscheck checkbounds(A, kj...) ifelse(kj == A.ind, A.val, zero(T)) end -const VectorIndsWithColon = Union{AbstractRange{Int}, Colon, Int} -const VectorInds = Union{AbstractRange{Int}, Int} +const VectorInds = Union{AbstractRange{<:Integer}, Integer} +const VectorIndsWithColon = Union{VectorInds, Colon} # retain the values from Ainds corresponding to the vector indices in inds _index_shape(Ainds, inds::Tuple{Integer, Vararg{Any}}) = _index_shape(Base.tail(Ainds), Base.tail(inds)) _index_shape(Ainds, inds::Tuple{AbstractVector, Vararg{Any}}) = (Ainds[1], _index_shape(Base.tail(Ainds), Base.tail(inds))...) _index_shape(::Tuple{}, ::Tuple{}) = () @inline function Base.getindex(A::OneElement{T,N}, inds::Vararg{VectorInds,N}) where {T,N} - @boundscheck checkbounds(A, inds...) - shape = _index_shape(inds, inds) - nzind = _index_shape(A.ind, inds) .- first.(shape) .+ firstindex.(shape) - containsval = all(in.(A.ind, inds)) + I = to_indices(A, inds) # handle Bool, and convert to compatible index types (Int usually) + @boundscheck checkbounds(A, I...) + shape = _index_shape(I, I) + nzind = _index_shape(A.ind, I) .- first.(shape) .+ firstindex.(shape) + containsval = all(in.(A.ind, I)) OneElement(getindex_value(A), containsval ? Int.(nzind) : Int.(lastindex.(shape,1)).+1, axes.(shape,1)) end Base.@propagate_inbounds function Base.getindex(A::OneElement{T,N}, inds::Vararg{VectorIndsWithColon,N}) where {T,N} - getindex(A, Base.to_indices(A, inds)...) + getindex(A, to_indices(A, inds)...) end """ diff --git a/test/runtests.jl b/test/runtests.jl index d08cf847..8098e378 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2229,6 +2229,7 @@ end @test @inferred(A[:,:]) === @inferred(A[axes(A)...]) === A @test @inferred(A[:,1]) isa OneElement{Int,1} @test @inferred(A[:,1]) == Zeros(4) + @test A[:, Int64(1)] === A[:, Int32(1)] @test @inferred(A[1,:]) isa OneElement{Int,1} @test @inferred(A[1,:]) == Zeros(5) @test @inferred(A[:,3]) === OneElement(2, 2, 4) From 3e5c4a2d1921cf8e957d9c4eef3997b3e3c21464 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Thu, 9 May 2024 13:13:00 +0530 Subject: [PATCH 3/3] Restrict to AbstractUnitRanges to avoid repeated indices --- src/oneelement.jl | 9 +++------ test/runtests.jl | 6 ++++++ 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/oneelement.jl b/src/oneelement.jl index 9011a3c3..ec9791cb 100644 --- a/src/oneelement.jl +++ b/src/oneelement.jl @@ -47,23 +47,20 @@ Base.@propagate_inbounds function Base.getindex(A::OneElement{T,N}, kj::Vararg{I @boundscheck checkbounds(A, kj...) ifelse(kj == A.ind, A.val, zero(T)) end -const VectorInds = Union{AbstractRange{<:Integer}, Integer} +const VectorInds = Union{AbstractUnitRange{<:Integer}, Integer} # no index is repeated for these indices const VectorIndsWithColon = Union{VectorInds, Colon} # retain the values from Ainds corresponding to the vector indices in inds _index_shape(Ainds, inds::Tuple{Integer, Vararg{Any}}) = _index_shape(Base.tail(Ainds), Base.tail(inds)) _index_shape(Ainds, inds::Tuple{AbstractVector, Vararg{Any}}) = (Ainds[1], _index_shape(Base.tail(Ainds), Base.tail(inds))...) _index_shape(::Tuple{}, ::Tuple{}) = () -@inline function Base.getindex(A::OneElement{T,N}, inds::Vararg{VectorInds,N}) where {T,N} - I = to_indices(A, inds) # handle Bool, and convert to compatible index types (Int usually) +Base.@propagate_inbounds function Base.getindex(A::OneElement{T,N}, inds::Vararg{VectorIndsWithColon,N}) where {T,N} + I = to_indices(A, inds) # handle Bool, and convert to compatible index types @boundscheck checkbounds(A, I...) shape = _index_shape(I, I) nzind = _index_shape(A.ind, I) .- first.(shape) .+ firstindex.(shape) containsval = all(in.(A.ind, I)) OneElement(getindex_value(A), containsval ? Int.(nzind) : Int.(lastindex.(shape,1)).+1, axes.(shape,1)) end -Base.@propagate_inbounds function Base.getindex(A::OneElement{T,N}, inds::Vararg{VectorIndsWithColon,N}) where {T,N} - getindex(A, to_indices(A, inds)...) -end """ nzind(A::OneElement{T,N}) -> CartesianIndex{N} diff --git a/test/runtests.jl b/test/runtests.jl index 8098e378..ab39c8b1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2207,6 +2207,12 @@ end @test @inferred(A[Base.IdentityUnitRange(2:3)]) == OneElement(2,(2,),(Base.IdentityUnitRange(2:3),)) @test A[:,:] == reshape(A, size(A)..., 1) + @test A[reverse(axes(A,1))] == A[collect(reverse(axes(A,1)))] + + @testset "repeated indices" begin + @test A[StepRangeLen(2, 0, 3)] == A[fill(2, 3)] + end + B = OneElement(2, (2,), (Base.IdentityUnitRange(-1:4),)) @test @inferred(A[:]) === @inferred(A[axes(A)...]) === A @test @inferred(A[3:4]) isa OneElement{Int,1}