diff --git a/src/SOneTo.jl b/src/SOneTo.jl index d2a4a7b2..e74cd9df 100644 --- a/src/SOneTo.jl +++ b/src/SOneTo.jl @@ -34,6 +34,16 @@ end @boundscheck checkbounds(s, s2) return s2 end +if isdefined(Base, :IdentityUnitRange) + @propagate_inbounds function Base.getindex(s::SOneTo, s2::Base.IdentityUnitRange{<:AbstractUnitRange{<:Integer}}) + @boundscheck checkbounds(s, s2) + return s2 + end + Base.axes(::Base.IdentityUnitRange{A}) where {A <: SOneTo} = (A(),) + Base.axes(r::Base.IdentityUnitRange{<:SOneTo}, d::Int) = d <= 1 ? axes(r)[d] : SOneTo(1) + Base.axes1(r::Base.IdentityUnitRange{A}) where {A <: SOneTo} = A() + Base.unsafe_indices(::Base.IdentityUnitRange{A}) where {A <: SOneTo} = (A(),) +end Base.first(::SOneTo) = 1 Base.last(::SOneTo{n}) where {n} = n::Int diff --git a/src/indexing.jl b/src/indexing.jl index dffb0cef..6085f9f4 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -77,8 +77,12 @@ end @inline index_size(::Size, a::StaticArray) = Size(a) @inline index_size(s::Size, ::Colon) = s @inline index_size(s::Size, a::SOneTo{n}) where n = Size(n,) +if isdefined(Base, :IdentityUnitRange) + @inline index_size(s::Size, a::Base.IdentityUnitRange{SOneTo{n}}) where n = Size(n,) +end @inline index_sizes(::S, inds...) where {S<:Size} = map(index_size, unpack_size(S), inds) +@inline index_sizes(::S, inds) where {S<:Size} = map(index_size, map(Size, linear_index_size(S)), (inds,)) @inline index_sizes() = () @inline index_sizes(::Int, inds...) = (Size(), index_sizes(inds...)...) @@ -96,6 +100,9 @@ _ind(i::Int, ::Int, ::Type{Int}) = :(inds[$i]) _ind(i::Int, j::Int, ::Type{<:StaticArray}) = :(inds[$i][$j]) _ind(i::Int, j::Int, ::Type{Colon}) = j _ind(i::Int, j::Int, ::Type{<:SOneTo}) = j +if isdefined(Base, :IdentityUnitRange) + _ind(i::Int, j::Int, ::Type{<:Base.IdentityUnitRange{<:SOneTo}}) = j +end ################################ ## Non-scalar linear indexing ## @@ -223,7 +230,15 @@ end # getindex @propagate_inbounds function getindex(a::StaticArray, inds::Union{Int, StaticArray{<:Tuple, Int}, SOneTo, Colon}...) - _getindex(a, index_sizes(Size(a), inds...), inds) + ar = reshape(a, Val(length(inds))) + _getindex(ar, index_sizes(Size(ar), inds...), inds) +end + +if isdefined(Base, :IdentityUnitRange) + @propagate_inbounds function getindex(a::StaticArray, inds::Union{Int, StaticArray{<:Tuple, Int}, SOneTo, Colon, Base.IdentityUnitRange{<:SOneTo}}...) + ar = reshape(a, Val(length(inds))) + _getindex(ar, index_sizes(Size(ar), inds...), inds) + end end function Base._getindex(::IndexStyle, A::AbstractArray, i1::StaticIndexing, I::StaticIndexing...) diff --git a/test/abstractarray.jl b/test/abstractarray.jl index 21d0cc8d..2fa97f68 100644 --- a/test/abstractarray.jl +++ b/test/abstractarray.jl @@ -104,6 +104,17 @@ using StaticArrays, Test, LinearAlgebra @test r == m[:, 2:3] * v[1:2] == Array(m)[:, 2:3] * Array(v)[1:2] end + if isdefined(Base, :IdentityUnitRange) + @testset "indexing SOneTo with IdentityUnitRange" begin + s = SOneTo(4) + for r in Any[Base.IdentityUnitRange(2:3), Base.IdentityUnitRange(SOneTo(2))] + si = @inferred s[r] + @test si == r + @test axes(si,1) == axes(r,1) + end + @test_throws BoundsError s[Base.IdentityUnitRange(1:5)] + end + end @testset "reshape" begin @test @inferred(reshape(SVector(1,2,3,4), axes(SMatrix{2,2}(1,2,3,4)))) === SMatrix{2,2}(1,2,3,4) @@ -199,6 +210,39 @@ using StaticArrays, Test, LinearAlgebra unitlotri = UnitLowerTriangular(SA[1 0; 2 1]) @test_broken @inferred(convert(AbstractArray{Float64}, unitlotri)) isa UnitLowerTriangular{Float64,SMatrix{2,2,Float64,4}} end + + @testset "views" begin + for a in Any[SVector{2}(1:2), MVector{2}(1:2)] + v = view(a, :) + @test axes(v) === axes(a) + v2 = view(a, SOneTo(1)) + @test axes(v2, 1) === SOneTo(1) + if isdefined(Base, :IdentityUnitRange) + v2 = view(a, Base.IdentityUnitRange(SOneTo(1))) + @test axes(v2, 1) === SOneTo(1) + end + end + for a in Any[SMatrix{2,2}(1:4), MMatrix{2,2}(1:4)] + v = view(a, :, :) + @test axes(v) === axes(a) + v2 = view(a, SOneTo(1), SOneTo(1)) + @test axes(v2) === (SOneTo(1), SOneTo(1)) + if isdefined(Base, :IdentityUnitRange) + v2 = view(a, Base.IdentityUnitRange(SOneTo(1)), Base.IdentityUnitRange(SOneTo(1))) + @test axes(v2) === (SOneTo(1), SOneTo(1)) + end + end + end + + @testset "SOneTo" begin + if isdefined(Base, :IdentityUnitRange) + s = Base.IdentityUnitRange(SOneTo(3)) + @test axes(s) == (SOneTo(3),) + @test axes(s,1) == SOneTo(3) + @test Base.axes1(s) == axes(s,1) + @test Base.unsafe_indices(s) == axes(s) + end + end end @testset "vcat() and hcat()" begin @@ -280,7 +324,7 @@ end @test Base.rest(x) == x a, b... = x @test b == SA[2, 3] - + x = SA[1 2; 3 4] @test Base.rest(x) == vec(x) a, b... = x @@ -289,14 +333,14 @@ end a, b... = SA[1] @test b == [] @test b isa SVector{0} - + for (Vec, Mat) in [(MVector, MMatrix), (SizedVector, SizedMatrix)] x = Vec(1, 2, 3) @test Base.rest(x) == x @test pointer(Base.rest(x)) != pointer(x) a, b... = x @test b == Vec(2, 3) - + x = Mat{2,2}(1, 2, 3, 4) @test Base.rest(x) == vec(x) @test pointer(Base.rest(x)) != pointer(x) diff --git a/test/indexing.jl b/test/indexing.jl index c7400018..0bd7dc3c 100644 --- a/test/indexing.jl +++ b/test/indexing.jl @@ -12,6 +12,11 @@ using StaticArrays, Test # SArray @test (@inferred getindex(sv, SMatrix{2,2}(1,4,2,3))) === SMatrix{2,2}(4,7,5,6) + + @test (@inferred getindex(sv, axes(sv, 1))) === sv + if isdefined(Base, :IdentityUnitRange) + @test (@inferred getindex(sv, Base.IdentityUnitRange(axes(sv, 1)))) === sv + end end @testset "Linear getindex() on SMatrix" begin @@ -21,6 +26,14 @@ using StaticArrays, Test # SVector @test (@inferred getindex(sm, SVector(4,3,2,1))) === SVector((7,6,5,4)) + # SOneTo + @test (@inferred getindex(sm, SOneTo(length(sm)))) === sv + + # IdentityUnitRange{<:SOneTo} + if isdefined(Base, :IdentityUnitRange) + @test (@inferred getindex(sm, Base.IdentityUnitRange(SOneTo(length(sm))))) === sv + end + # Colon @test (@inferred getindex(sm,:)) === sv @@ -29,49 +42,77 @@ using StaticArrays, Test end @testset "Linear getindex()/setindex!() on MVector" begin - vec = @SVector [4,5,6,7] + sv = @SVector [4,5,6,7] # SVector - mv = MVector{4,Int}(undef) - @test (mv[SVector(1,2,3,4)] = vec; (@inferred getindex(mv, SVector(4,3,2,1)))::MVector{4,Int} == MVector((7,6,5,4))) - @test setindex!(mv, vec, SVector(1,2,3,4)) === mv + mvec = MVector{4,Int}(undef) + @test (mvec[SVector(1,2,3,4)] = sv; (@inferred getindex(mvec, SVector(4,3,2,1)))::MVector{4,Int} == MVector((7,6,5,4))) + @test setindex!(mvec, sv, SVector(1,2,3,4)) === mvec - mv = MVector{4,Int}(undef) - @test (mv[SVector(1,2,3,4)] = [4, 5, 6, 7]; (@inferred getindex(mv, SVector(4,3,2,1)))::MVector{4,Int} == MVector((7,6,5,4))) - @test (mv[SVector(1,2,3,4)] = 2; (@inferred getindex(mv, SVector(4,3,2,1)))::MVector{4,Int} == MVector((2,2,2,2))) + mvec = MVector{4,Int}(undef) + @test (mvec[SVector(1,2,3,4)] = [4, 5, 6, 7]; (@inferred getindex(mvec, SVector(4,3,2,1)))::MVector{4,Int} == MVector((7,6,5,4))) + @test (mvec[SVector(1,2,3,4)] = 2; (@inferred getindex(mvec, SVector(4,3,2,1)))::MVector{4,Int} == MVector((2,2,2,2))) - mv = MVector(0,0,0) - @test (mv[SVector(1,3)] = [4, 5]; (@inferred mv == MVector(4,0,5))) + mvec = MVector(0,0,0) + @test (mvec[SVector(1,3)] = [4, 5]; (@inferred mvec == MVector(4,0,5))) - mv = MVector(0,0,0) - @test (mv[SVector(1,3)] = SVector(4, 5); (@inferred mv == MVector(4,0,5))) + mvec = MVector(0,0,0) + @test (mvec[SVector(1,3)] = SVector(4, 5); (@inferred mvec == MVector(4,0,5))) - mv = MVector(0,0,0) - @test (mv[SMatrix{2,1}(1,3)] = SMatrix{2,1}(4, 5); (@inferred mv == MVector(4,0,5))) + mvec = MVector(0,0,0) + @test (mvec[SMatrix{2,1}(1,3)] = SMatrix{2,1}(4, 5); (@inferred mvec == MVector(4,0,5))) # Colon - mv = MVector{4,Int}(undef) - @test (mv[:] = vec; (@inferred getindex(mv, :))::MVector{4,Int} == MVector((4,5,6,7))) - @test (mv[:] = [4, 5, 6, 7]; (@inferred getindex(mv, :))::MVector{4,Int} == MVector((4,5,6,7))) - @test (mv[:] = 2; (@inferred getindex(mv, :))::MVector{4,Int} == MVector((2,2,2,2))) - @test setindex!(mv, 2, :) === mv + mvec = MVector{4,Int}(undef) + @test (mvec[:] = sv; (@inferred getindex(mvec, :))::MVector{4,Int} == MVector((4,5,6,7))) + @test (mvec[:] = [4, 5, 6, 7]; (@inferred getindex(mvec, :))::MVector{4,Int} == MVector((4,5,6,7))) + @test (mvec[:] = 2; (@inferred getindex(mvec, :))::MVector{4,Int} == MVector((2,2,2,2))) + @test setindex!(mvec, 2, :) === mvec - @test_throws DimensionMismatch setindex!(mv, SVector(1,2,3), SVector(1,2,3,4)) - @test_throws DimensionMismatch setindex!(mv, SVector(1,2,3), :) - @test_throws DimensionMismatch setindex!(mv, view(ones(8), 1:5), :) - @test_throws DimensionMismatch setindex!(mv, [1,2,3], SVector(1,2,3,4)) + # SOneTo + @test begin + mvec[SOneTo(length(mvec))] = sv + (@inferred mvec[SOneTo(length(mvec))]) == sv + end + + # IdentityUnitRange{<:SOneTo} + if isdefined(Base, :IdentityUnitRange) + @test begin + mvec[Base.IdentityUnitRange(SOneTo(length(mvec)))] = sv + (@inferred mvec[Base.IdentityUnitRange(SOneTo(length(mvec)))]) == sv + end + end + + @test_throws DimensionMismatch setindex!(mvec, SVector(1,2,3), SVector(1,2,3,4)) + @test_throws DimensionMismatch setindex!(mvec, SVector(1,2,3), :) + @test_throws DimensionMismatch setindex!(mvec, view(ones(8), 1:5), :) + @test_throws DimensionMismatch setindex!(mvec, [1,2,3], SVector(1,2,3,4)) end @testset "Linear getindex()/setindex!() on MMatrix" begin - vec = @SVector [4,5,6,7] + sv = @SVector [4,5,6,7] # SVector mm = MMatrix{2,2,Int}(undef) - @test (mm[SVector(1,2,3,4)] = vec; (@inferred getindex(mm, SVector(4,3,2,1)))::MVector{4,Int} == MVector((7,6,5,4))) + @test (mm[SVector(1,2,3,4)] = sv; (@inferred getindex(mm, SVector(4,3,2,1)))::MVector{4,Int} == MVector((7,6,5,4))) # Colon mm = MMatrix{2,2,Int}(undef) - @test (mm[:] = vec; (@inferred getindex(mm, :))::MVector{4,Int} == MVector((4,5,6,7))) + @test (mm[:] = sv; (@inferred getindex(mm, :))::MVector{4,Int} == MVector((4,5,6,7))) + + # SOneTo + @test begin + mm[SOneTo(length(mm))] = sv + (@inferred mm[SOneTo(length(mm))]) == sv + end + + # IdentityUnitRange{<:SOneTo} + if isdefined(Base, :IdentityUnitRange) + @test begin + mm[Base.IdentityUnitRange(SOneTo(length(mm)))] = sv + (@inferred mm[Base.IdentityUnitRange(SOneTo(length(mm)))]) == sv + end + end # SMatrix mm = MMatrix{2,2,Int}(undef) @@ -96,6 +137,12 @@ using StaticArrays, Test @test v[2,1] == 2 @test_throws BoundsError v[1,2] @test_throws BoundsError v[3,1] + + # SOneTo + @test (@inferred v[axes(v,1), SOneTo(1)]) === SMatrix{2,1}(v) + @test v[axes(v,1), SOneTo(1)] == v[Base.OneTo(length(v)), Base.OneTo(1)] + @test (@inferred v[axes(v,1), 1, SOneTo(1)]) === SMatrix{2,1}(v) + @test v[axes(v,1), 1, SOneTo(1)] == v[Base.OneTo(length(v)), 1, Base.OneTo(1)] end @testset "2D getindex() on SMatrix" begin @@ -122,6 +169,13 @@ using StaticArrays, Test # SOneTo @testinf sm[SOneTo(1),:] === @SMatrix [1 3] @testinf sm[:,SOneTo(1)] === @SMatrix [1;2] + + # IdentityUnitRange{<:SOneTo} + if isdefined(Base, :IdentityUnitRange) + @test (@inferred sm[Base.IdentityUnitRange(axes(sm, 1)), :]) === sm + @test (@inferred sm[Base.IdentityUnitRange(axes(sm, 1)), SOneTo(1)]) === SMatrix{2,1}(sm[:,1]) + @test (@inferred sm[Base.IdentityUnitRange(axes(sm, 1)), SVector{1}(SOneTo(1))]) === SMatrix{2,1}(sm[:,1]) + end end @testset "2D getindex()/setindex! on MMatrix" begin @@ -143,6 +197,18 @@ using StaticArrays, Test @test (mm = MMatrix{2,2,Int}(undef); mm[SOneTo(1),:] = sm[SOneTo(1),:]; (@inferred getindex(mm, SOneTo(1), :))::MMatrix == @MMatrix [1 3]) @test (mm = MMatrix{2,2,Int}(undef); mm[:,SOneTo(1)] = sm[:,SOneTo(1)]; (@inferred getindex(mm, :, SOneTo(1)))::MMatrix == @MMatrix [1;2]) + # IdentityUnitRange{<:SOneTo} + if isdefined(Base, :IdentityUnitRange) + @test begin + mm = MMatrix{2,2,Int}(undef); + mm[map(Base.IdentityUnitRange, axes(mm))...] = sm + (@inferred mm[map(Base.IdentityUnitRange, axes(mm))...]) == mm + (@inferred mm[Base.IdentityUnitRange(axes(mm,1)), :]) == mm + (@inferred mm[Base.IdentityUnitRange(axes(mm,1)), axes(mm,2)]) == mm + (@inferred mm[Base.IdentityUnitRange(axes(mm,1)), SVector{2}(axes(mm,2))]) == mm + end + end + # #866 @test_throws DimensionMismatch setindex!(MMatrix(SA[1 2; 3 4]), SA[3,4], 1, SA[1,2,3]) @test_throws DimensionMismatch setindex!(MMatrix(SA[1 2; 3 4]), [3,4], 1, SA[1,2,3]) @@ -189,6 +255,29 @@ using StaticArrays, Test @test (@inferred getindex(a, SVector(1,2), 1, 1, 1)) == [24,48] end + @testset "indexing with reshape for SMatrix/MMatrix" begin + sm = @SMatrix [1 3; 2 4] + mm = @MMatrix [1 3; 2 4] + for m in Any[sm, mm, view(sm, :, :), view(mm, :, :)] + sa = @inferred m[:, SOneTo(1), 1, SOneTo(1)] + a = m[:, Base.OneTo(1), 1, Base.OneTo(1)] + @test sa == a + @test sa == SArray{Tuple{2,1,1}}(a) + if m isa SArray + @test sa === SArray{Tuple{2,1,1}}(a) + end + + if isdefined(Base, :IdentityUnitRange) + sa = @inferred m[:, Base.IdentityUnitRange(SOneTo(1)), 1, SOneTo(1)] + @test sa == a + @test sa == SArray{Tuple{2,1,1}}(a) + if m isa SArray + @test sa === SArray{Tuple{2,1,1}}(a) + end + end + end + end + @testset "Indexing with empty vectors" begin a = [1.0 2.0; 3.0 4.0] @test a[SVector{0,Int}()] == SVector{0,Float64}(())