diff --git a/src/oneelement.jl b/src/oneelement.jl index c72bcd2d..ec9791cb 100644 --- a/src/oneelement.jl +++ b/src/oneelement.jl @@ -42,10 +42,25 @@ OneElement{T}(inds::Int, sz::Int) where T = OneElement(one(T), inds, sz) Base.size(A::OneElement) = map(length, A.axes) Base.axes(A::OneElement) = A.axes +Base.getindex(A::OneElement{T,0}) where {T} = getindex_value(A) Base.@propagate_inbounds function Base.getindex(A::OneElement{T,N}, kj::Vararg{Int,N}) where {T,N} @boundscheck checkbounds(A, kj...) ifelse(kj == A.ind, A.val, zero(T)) end +const VectorInds = Union{AbstractUnitRange{<:Integer}, Integer} # no index is repeated for these indices +const VectorIndsWithColon = Union{VectorInds, Colon} +# retain the values from Ainds corresponding to the vector indices in inds +_index_shape(Ainds, inds::Tuple{Integer, Vararg{Any}}) = _index_shape(Base.tail(Ainds), Base.tail(inds)) +_index_shape(Ainds, inds::Tuple{AbstractVector, Vararg{Any}}) = (Ainds[1], _index_shape(Base.tail(Ainds), Base.tail(inds))...) +_index_shape(::Tuple{}, ::Tuple{}) = () +Base.@propagate_inbounds function Base.getindex(A::OneElement{T,N}, inds::Vararg{VectorIndsWithColon,N}) where {T,N} + I = to_indices(A, inds) # handle Bool, and convert to compatible index types + @boundscheck checkbounds(A, I...) + shape = _index_shape(I, I) + nzind = _index_shape(A.ind, I) .- first.(shape) .+ firstindex.(shape) + containsval = all(in.(A.ind, I)) + OneElement(getindex_value(A), containsval ? Int.(nzind) : Int.(lastindex.(shape,1)).+1, axes.(shape,1)) +end """ nzind(A::OneElement{T,N}) -> CartesianIndex{N} diff --git a/test/runtests.jl b/test/runtests.jl index d4b2bb95..ab39c8b1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2148,10 +2148,12 @@ end @test FillArrays.nzind(A) == CartesianIndex() @test A == Fill(2, ()) @test A[] === 2 + @test A[1] === A[1,1] === 2 e₁ = OneElement(2, 5) @test e₁ == [0,1,0,0,0] @test FillArrays.nzind(e₁) == CartesianIndex(2) + @test e₁[2] === e₁[2,1] === e₁[2,1,1] === 1 @test_throws BoundsError e₁[6] f₁ = AbstractArray{Float64}(e₁) @@ -2193,6 +2195,89 @@ end @test A[1,1] === A[1,2] === A[2,1] === zero(S) end + @testset "Vector indexing" begin + @testset "1D" begin + A = OneElement(2, 2, 4) + @test @inferred(A[:]) === @inferred(A[axes(A)...]) === A + @test @inferred(A[3:4]) isa OneElement{Int,1} + @test @inferred(A[3:4]) == Zeros(2) + @test @inferred(A[1:2]) === OneElement(2, 2, 2) + @test @inferred(A[2:3]) === OneElement(2, 1, 2) + @test @inferred(A[Base.IdentityUnitRange(2:3)]) isa OneElement{Int,1} + @test @inferred(A[Base.IdentityUnitRange(2:3)]) == OneElement(2,(2,),(Base.IdentityUnitRange(2:3),)) + @test A[:,:] == reshape(A, size(A)..., 1) + + @test A[reverse(axes(A,1))] == A[collect(reverse(axes(A,1)))] + + @testset "repeated indices" begin + @test A[StepRangeLen(2, 0, 3)] == A[fill(2, 3)] + end + + B = OneElement(2, (2,), (Base.IdentityUnitRange(-1:4),)) + @test @inferred(A[:]) === @inferred(A[axes(A)...]) === A + @test @inferred(A[3:4]) isa OneElement{Int,1} + @test @inferred(A[3:4]) == Zeros(2) + @test @inferred(A[2:3]) === OneElement(2, 1, 2) + + C = OneElement(2, (2,), (Base.OneTo(big(4)),)) + @test @inferred(C[1:4]) === OneElement(2, 2, 4) + + D = OneElement(2, (2,), (InfiniteArrays.OneToInf(),)) + D2 = D[:] + @test axes(D2) == axes(D) + @test D2[2] == D[2] + D3 = D[axes(D)...] + @test axes(D3) == axes(D) + @test D3[2] == D[2] + end + @testset "2D" begin + A = OneElement(2, (2,3), (4,5)) + @test @inferred(A[:,:]) === @inferred(A[axes(A)...]) === A + @test @inferred(A[:,1]) isa OneElement{Int,1} + @test @inferred(A[:,1]) == Zeros(4) + @test A[:, Int64(1)] === A[:, Int32(1)] + @test @inferred(A[1,:]) isa OneElement{Int,1} + @test @inferred(A[1,:]) == Zeros(5) + @test @inferred(A[:,3]) === OneElement(2, 2, 4) + @test @inferred(A[2,:]) === OneElement(2, 3, 5) + @test @inferred(A[1:1,:]) isa OneElement{Int,2} + @test @inferred(A[1:1,:]) == Zeros(1,5) + @test @inferred(A[4:4,:]) isa OneElement{Int,2} + @test @inferred(A[4:4,:]) == Zeros(1,5) + @test @inferred(A[2:2,:]) === OneElement(2, (1,3), (1,5)) + @test @inferred(A[1:4,:]) === OneElement(2, (2,3), (4,5)) + @test @inferred(A[:,3:3]) === OneElement(2, (2,1), (4,1)) + @test @inferred(A[:,1:5]) === OneElement(2, (2,3), (4,5)) + @test @inferred(A[1:4,1:4]) === OneElement(2, (2,3), (4,4)) + @test @inferred(A[2:4,2:4]) === OneElement(2, (1,2), (3,3)) + @test @inferred(A[2:4,3:4]) === OneElement(2, (1,1), (3,2)) + @test @inferred(A[4:4,5:5]) isa OneElement{Int,2} + @test @inferred(A[4:4,5:5]) == Zeros(1,1) + @test @inferred(A[Base.IdentityUnitRange(2:4), :]) isa OneElement{Int,2} + @test axes(A[Base.IdentityUnitRange(2:4), :]) == (Base.IdentityUnitRange(2:4), axes(A,2)) + @test @inferred(A[:,:,:]) == reshape(A, size(A)...,1) + + B = OneElement(2, (2,3), (Base.IdentityUnitRange(2:4),Base.IdentityUnitRange(2:5))) + @test @inferred(B[:,:]) === @inferred(B[axes(B)...]) === B + @test @inferred(B[:,3]) === OneElement(2, (2,), (Base.IdentityUnitRange(2:4),)) + @test @inferred(B[3:4, 4:5]) isa OneElement{Int,2} + @test @inferred(B[3:4, 4:5]) == Zeros(2,2) + b = @inferred(B[Base.IdentityUnitRange(3:4), Base.IdentityUnitRange(4:5)]) + @test b == Zeros(axes(b)) + + C = OneElement(2, (2,3), (Base.OneTo(big(4)), Base.OneTo(big(5)))) + @test @inferred(C[1:4, 1:5]) === OneElement(2, (2,3), Int.(size(C))) + + D = OneElement(2, (2,3), (InfiniteArrays.OneToInf(), InfiniteArrays.OneToInf())) + D2 = @inferred D[:,:] + @test axes(D2) == axes(D) + @test D2[2,3] == D[2,3] + D3 = @inferred D[axes(D)...] + @test axes(D3) == axes(D) + @test D3[2,3] == D[2,3] + end + end + @testset "adjoint/transpose" begin A = OneElement(3im, (2,4), (4,6)) @test A' === OneElement(-3im, (4,2), (6,4))