Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Indexing SOneTo with IdentityUnitRange preserves indices #922

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/SOneTo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,16 @@ end
@boundscheck checkbounds(s, s2)
return s2
end
if isdefined(Base, :IdentityUnitRange)
@propagate_inbounds function Base.getindex(s::SOneTo, s2::Base.IdentityUnitRange{<:AbstractUnitRange{<:Integer}})
@boundscheck checkbounds(s, s2)
return s2
end
Base.axes(::Base.IdentityUnitRange{A}) where {A <: SOneTo} = (A(),)
Base.axes(r::Base.IdentityUnitRange{<:SOneTo}, d::Int) = d <= 1 ? axes(r)[d] : SOneTo(1)
Base.axes1(r::Base.IdentityUnitRange{A}) where {A <: SOneTo} = A()
Base.unsafe_indices(::Base.IdentityUnitRange{A}) where {A <: SOneTo} = (A(),)
end

Base.first(::SOneTo) = 1
Base.last(::SOneTo{n}) where {n} = n::Int
Expand Down
17 changes: 16 additions & 1 deletion src/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,12 @@ end
@inline index_size(::Size, a::StaticArray) = Size(a)
@inline index_size(s::Size, ::Colon) = s
@inline index_size(s::Size, a::SOneTo{n}) where n = Size(n,)
if isdefined(Base, :IdentityUnitRange)
@inline index_size(s::Size, a::Base.IdentityUnitRange{SOneTo{n}}) where n = Size(n,)
end

@inline index_sizes(::S, inds...) where {S<:Size} = map(index_size, unpack_size(S), inds)
@inline index_sizes(::S, inds) where {S<:Size} = map(index_size, map(Size, linear_index_size(S)), (inds,))

@inline index_sizes() = ()
@inline index_sizes(::Int, inds...) = (Size(), index_sizes(inds...)...)
Expand All @@ -96,6 +100,9 @@ _ind(i::Int, ::Int, ::Type{Int}) = :(inds[$i])
_ind(i::Int, j::Int, ::Type{<:StaticArray}) = :(inds[$i][$j])
_ind(i::Int, j::Int, ::Type{Colon}) = j
_ind(i::Int, j::Int, ::Type{<:SOneTo}) = j
if isdefined(Base, :IdentityUnitRange)
_ind(i::Int, j::Int, ::Type{<:Base.IdentityUnitRange{<:SOneTo}}) = j
end

################################
## Non-scalar linear indexing ##
Expand Down Expand Up @@ -223,7 +230,15 @@ end
# getindex

@propagate_inbounds function getindex(a::StaticArray, inds::Union{Int, StaticArray{<:Tuple, Int}, SOneTo, Colon}...)
_getindex(a, index_sizes(Size(a), inds...), inds)
ar = reshape(a, Val(length(inds)))
_getindex(ar, index_sizes(Size(ar), inds...), inds)
end

if isdefined(Base, :IdentityUnitRange)
@propagate_inbounds function getindex(a::StaticArray, inds::Union{Int, StaticArray{<:Tuple, Int}, SOneTo, Colon, Base.IdentityUnitRange{<:SOneTo}}...)
ar = reshape(a, Val(length(inds)))
_getindex(ar, index_sizes(Size(ar), inds...), inds)
end
end

function Base._getindex(::IndexStyle, A::AbstractArray, i1::StaticIndexing, I::StaticIndexing...)
Expand Down
50 changes: 47 additions & 3 deletions test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,17 @@ using StaticArrays, Test, LinearAlgebra
@test r == m[:, 2:3] * v[1:2] == Array(m)[:, 2:3] * Array(v)[1:2]
end

if isdefined(Base, :IdentityUnitRange)
@testset "indexing SOneTo with IdentityUnitRange" begin
s = SOneTo(4)
for r in Any[Base.IdentityUnitRange(2:3), Base.IdentityUnitRange(SOneTo(2))]
si = @inferred s[r]
@test si == r
@test axes(si,1) == axes(r,1)
end
@test_throws BoundsError s[Base.IdentityUnitRange(1:5)]
end
end

@testset "reshape" begin
@test @inferred(reshape(SVector(1,2,3,4), axes(SMatrix{2,2}(1,2,3,4)))) === SMatrix{2,2}(1,2,3,4)
Expand Down Expand Up @@ -199,6 +210,39 @@ using StaticArrays, Test, LinearAlgebra
unitlotri = UnitLowerTriangular(SA[1 0; 2 1])
@test_broken @inferred(convert(AbstractArray{Float64}, unitlotri)) isa UnitLowerTriangular{Float64,SMatrix{2,2,Float64,4}}
end

@testset "views" begin
for a in Any[SVector{2}(1:2), MVector{2}(1:2)]
v = view(a, :)
@test axes(v) === axes(a)
v2 = view(a, SOneTo(1))
@test axes(v2, 1) === SOneTo(1)
if isdefined(Base, :IdentityUnitRange)
v2 = view(a, Base.IdentityUnitRange(SOneTo(1)))
@test axes(v2, 1) === SOneTo(1)
end
end
for a in Any[SMatrix{2,2}(1:4), MMatrix{2,2}(1:4)]
v = view(a, :, :)
@test axes(v) === axes(a)
v2 = view(a, SOneTo(1), SOneTo(1))
@test axes(v2) === (SOneTo(1), SOneTo(1))
if isdefined(Base, :IdentityUnitRange)
v2 = view(a, Base.IdentityUnitRange(SOneTo(1)), Base.IdentityUnitRange(SOneTo(1)))
@test axes(v2) === (SOneTo(1), SOneTo(1))
end
end
end

@testset "SOneTo" begin
if isdefined(Base, :IdentityUnitRange)
s = Base.IdentityUnitRange(SOneTo(3))
@test axes(s) == (SOneTo(3),)
@test axes(s,1) == SOneTo(3)
@test Base.axes1(s) == axes(s,1)
@test Base.unsafe_indices(s) == axes(s)
end
end
end

@testset "vcat() and hcat()" begin
Expand Down Expand Up @@ -280,7 +324,7 @@ end
@test Base.rest(x) == x
a, b... = x
@test b == SA[2, 3]

x = SA[1 2; 3 4]
@test Base.rest(x) == vec(x)
a, b... = x
Expand All @@ -289,14 +333,14 @@ end
a, b... = SA[1]
@test b == []
@test b isa SVector{0}

for (Vec, Mat) in [(MVector, MMatrix), (SizedVector, SizedMatrix)]
x = Vec(1, 2, 3)
@test Base.rest(x) == x
@test pointer(Base.rest(x)) != pointer(x)
a, b... = x
@test b == Vec(2, 3)

x = Mat{2,2}(1, 2, 3, 4)
@test Base.rest(x) == vec(x)
@test pointer(Base.rest(x)) != pointer(x)
Expand Down
139 changes: 114 additions & 25 deletions test/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ using StaticArrays, Test

# SArray
@test (@inferred getindex(sv, SMatrix{2,2}(1,4,2,3))) === SMatrix{2,2}(4,7,5,6)

@test (@inferred getindex(sv, axes(sv, 1))) === sv
if isdefined(Base, :IdentityUnitRange)
@test (@inferred getindex(sv, Base.IdentityUnitRange(axes(sv, 1)))) === sv
end
end

@testset "Linear getindex() on SMatrix" begin
Expand All @@ -21,6 +26,14 @@ using StaticArrays, Test
# SVector
@test (@inferred getindex(sm, SVector(4,3,2,1))) === SVector((7,6,5,4))

# SOneTo
@test (@inferred getindex(sm, SOneTo(length(sm)))) === sv

# IdentityUnitRange{<:SOneTo}
if isdefined(Base, :IdentityUnitRange)
@test (@inferred getindex(sm, Base.IdentityUnitRange(SOneTo(length(sm))))) === sv
end

# Colon
@test (@inferred getindex(sm,:)) === sv

Expand All @@ -29,49 +42,77 @@ using StaticArrays, Test
end

@testset "Linear getindex()/setindex!() on MVector" begin
vec = @SVector [4,5,6,7]
sv = @SVector [4,5,6,7]

# SVector
mv = MVector{4,Int}(undef)
@test (mv[SVector(1,2,3,4)] = vec; (@inferred getindex(mv, SVector(4,3,2,1)))::MVector{4,Int} == MVector((7,6,5,4)))
@test setindex!(mv, vec, SVector(1,2,3,4)) === mv
mvec = MVector{4,Int}(undef)
@test (mvec[SVector(1,2,3,4)] = sv; (@inferred getindex(mvec, SVector(4,3,2,1)))::MVector{4,Int} == MVector((7,6,5,4)))
@test setindex!(mvec, sv, SVector(1,2,3,4)) === mvec

mv = MVector{4,Int}(undef)
@test (mv[SVector(1,2,3,4)] = [4, 5, 6, 7]; (@inferred getindex(mv, SVector(4,3,2,1)))::MVector{4,Int} == MVector((7,6,5,4)))
@test (mv[SVector(1,2,3,4)] = 2; (@inferred getindex(mv, SVector(4,3,2,1)))::MVector{4,Int} == MVector((2,2,2,2)))
mvec = MVector{4,Int}(undef)
@test (mvec[SVector(1,2,3,4)] = [4, 5, 6, 7]; (@inferred getindex(mvec, SVector(4,3,2,1)))::MVector{4,Int} == MVector((7,6,5,4)))
@test (mvec[SVector(1,2,3,4)] = 2; (@inferred getindex(mvec, SVector(4,3,2,1)))::MVector{4,Int} == MVector((2,2,2,2)))

mv = MVector(0,0,0)
@test (mv[SVector(1,3)] = [4, 5]; (@inferred mv == MVector(4,0,5)))
mvec = MVector(0,0,0)
@test (mvec[SVector(1,3)] = [4, 5]; (@inferred mvec == MVector(4,0,5)))

mv = MVector(0,0,0)
@test (mv[SVector(1,3)] = SVector(4, 5); (@inferred mv == MVector(4,0,5)))
mvec = MVector(0,0,0)
@test (mvec[SVector(1,3)] = SVector(4, 5); (@inferred mvec == MVector(4,0,5)))

mv = MVector(0,0,0)
@test (mv[SMatrix{2,1}(1,3)] = SMatrix{2,1}(4, 5); (@inferred mv == MVector(4,0,5)))
mvec = MVector(0,0,0)
@test (mvec[SMatrix{2,1}(1,3)] = SMatrix{2,1}(4, 5); (@inferred mvec == MVector(4,0,5)))

# Colon
mv = MVector{4,Int}(undef)
@test (mv[:] = vec; (@inferred getindex(mv, :))::MVector{4,Int} == MVector((4,5,6,7)))
@test (mv[:] = [4, 5, 6, 7]; (@inferred getindex(mv, :))::MVector{4,Int} == MVector((4,5,6,7)))
@test (mv[:] = 2; (@inferred getindex(mv, :))::MVector{4,Int} == MVector((2,2,2,2)))
@test setindex!(mv, 2, :) === mv
mvec = MVector{4,Int}(undef)
@test (mvec[:] = sv; (@inferred getindex(mvec, :))::MVector{4,Int} == MVector((4,5,6,7)))
@test (mvec[:] = [4, 5, 6, 7]; (@inferred getindex(mvec, :))::MVector{4,Int} == MVector((4,5,6,7)))
@test (mvec[:] = 2; (@inferred getindex(mvec, :))::MVector{4,Int} == MVector((2,2,2,2)))
@test setindex!(mvec, 2, :) === mvec

@test_throws DimensionMismatch setindex!(mv, SVector(1,2,3), SVector(1,2,3,4))
@test_throws DimensionMismatch setindex!(mv, SVector(1,2,3), :)
@test_throws DimensionMismatch setindex!(mv, view(ones(8), 1:5), :)
@test_throws DimensionMismatch setindex!(mv, [1,2,3], SVector(1,2,3,4))
# SOneTo
@test begin
mvec[SOneTo(length(mvec))] = sv
(@inferred mvec[SOneTo(length(mvec))]) == sv
end

# IdentityUnitRange{<:SOneTo}
if isdefined(Base, :IdentityUnitRange)
@test begin
mvec[Base.IdentityUnitRange(SOneTo(length(mvec)))] = sv
(@inferred mvec[Base.IdentityUnitRange(SOneTo(length(mvec)))]) == sv
end
end

@test_throws DimensionMismatch setindex!(mvec, SVector(1,2,3), SVector(1,2,3,4))
@test_throws DimensionMismatch setindex!(mvec, SVector(1,2,3), :)
@test_throws DimensionMismatch setindex!(mvec, view(ones(8), 1:5), :)
@test_throws DimensionMismatch setindex!(mvec, [1,2,3], SVector(1,2,3,4))
end

@testset "Linear getindex()/setindex!() on MMatrix" begin
vec = @SVector [4,5,6,7]
sv = @SVector [4,5,6,7]

# SVector
mm = MMatrix{2,2,Int}(undef)
@test (mm[SVector(1,2,3,4)] = vec; (@inferred getindex(mm, SVector(4,3,2,1)))::MVector{4,Int} == MVector((7,6,5,4)))
@test (mm[SVector(1,2,3,4)] = sv; (@inferred getindex(mm, SVector(4,3,2,1)))::MVector{4,Int} == MVector((7,6,5,4)))

# Colon
mm = MMatrix{2,2,Int}(undef)
@test (mm[:] = vec; (@inferred getindex(mm, :))::MVector{4,Int} == MVector((4,5,6,7)))
@test (mm[:] = sv; (@inferred getindex(mm, :))::MVector{4,Int} == MVector((4,5,6,7)))

# SOneTo
@test begin
mm[SOneTo(length(mm))] = sv
(@inferred mm[SOneTo(length(mm))]) == sv
end

# IdentityUnitRange{<:SOneTo}
if isdefined(Base, :IdentityUnitRange)
@test begin
mm[Base.IdentityUnitRange(SOneTo(length(mm)))] = sv
(@inferred mm[Base.IdentityUnitRange(SOneTo(length(mm)))]) == sv
end
end

# SMatrix
mm = MMatrix{2,2,Int}(undef)
Expand All @@ -96,6 +137,12 @@ using StaticArrays, Test
@test v[2,1] == 2
@test_throws BoundsError v[1,2]
@test_throws BoundsError v[3,1]

# SOneTo
@test (@inferred v[axes(v,1), SOneTo(1)]) === SMatrix{2,1}(v)
@test v[axes(v,1), SOneTo(1)] == v[Base.OneTo(length(v)), Base.OneTo(1)]
@test (@inferred v[axes(v,1), 1, SOneTo(1)]) === SMatrix{2,1}(v)
@test v[axes(v,1), 1, SOneTo(1)] == v[Base.OneTo(length(v)), 1, Base.OneTo(1)]
end

@testset "2D getindex() on SMatrix" begin
Expand All @@ -122,6 +169,13 @@ using StaticArrays, Test
# SOneTo
@testinf sm[SOneTo(1),:] === @SMatrix [1 3]
@testinf sm[:,SOneTo(1)] === @SMatrix [1;2]

# IdentityUnitRange{<:SOneTo}
if isdefined(Base, :IdentityUnitRange)
@test (@inferred sm[Base.IdentityUnitRange(axes(sm, 1)), :]) === sm
@test (@inferred sm[Base.IdentityUnitRange(axes(sm, 1)), SOneTo(1)]) === SMatrix{2,1}(sm[:,1])
@test (@inferred sm[Base.IdentityUnitRange(axes(sm, 1)), SVector{1}(SOneTo(1))]) === SMatrix{2,1}(sm[:,1])
end
end

@testset "2D getindex()/setindex! on MMatrix" begin
Expand All @@ -143,6 +197,18 @@ using StaticArrays, Test
@test (mm = MMatrix{2,2,Int}(undef); mm[SOneTo(1),:] = sm[SOneTo(1),:]; (@inferred getindex(mm, SOneTo(1), :))::MMatrix == @MMatrix [1 3])
@test (mm = MMatrix{2,2,Int}(undef); mm[:,SOneTo(1)] = sm[:,SOneTo(1)]; (@inferred getindex(mm, :, SOneTo(1)))::MMatrix == @MMatrix [1;2])

# IdentityUnitRange{<:SOneTo}
if isdefined(Base, :IdentityUnitRange)
@test begin
mm = MMatrix{2,2,Int}(undef);
mm[map(Base.IdentityUnitRange, axes(mm))...] = sm
(@inferred mm[map(Base.IdentityUnitRange, axes(mm))...]) == mm
(@inferred mm[Base.IdentityUnitRange(axes(mm,1)), :]) == mm
(@inferred mm[Base.IdentityUnitRange(axes(mm,1)), axes(mm,2)]) == mm
(@inferred mm[Base.IdentityUnitRange(axes(mm,1)), SVector{2}(axes(mm,2))]) == mm
end
end

# #866
@test_throws DimensionMismatch setindex!(MMatrix(SA[1 2; 3 4]), SA[3,4], 1, SA[1,2,3])
@test_throws DimensionMismatch setindex!(MMatrix(SA[1 2; 3 4]), [3,4], 1, SA[1,2,3])
Expand Down Expand Up @@ -189,6 +255,29 @@ using StaticArrays, Test
@test (@inferred getindex(a, SVector(1,2), 1, 1, 1)) == [24,48]
end

@testset "indexing with reshape for SMatrix/MMatrix" begin
sm = @SMatrix [1 3; 2 4]
mm = @MMatrix [1 3; 2 4]
for m in Any[sm, mm, view(sm, :, :), view(mm, :, :)]
sa = @inferred m[:, SOneTo(1), 1, SOneTo(1)]
a = m[:, Base.OneTo(1), 1, Base.OneTo(1)]
@test sa == a
@test sa == SArray{Tuple{2,1,1}}(a)
if m isa SArray
@test sa === SArray{Tuple{2,1,1}}(a)
end

if isdefined(Base, :IdentityUnitRange)
sa = @inferred m[:, Base.IdentityUnitRange(SOneTo(1)), 1, SOneTo(1)]
@test sa == a
@test sa == SArray{Tuple{2,1,1}}(a)
if m isa SArray
@test sa === SArray{Tuple{2,1,1}}(a)
end
end
end
end

@testset "Indexing with empty vectors" begin
a = [1.0 2.0; 3.0 4.0]
@test a[SVector{0,Int}()] == SVector{0,Float64}(())
Expand Down