Skip to content

Commit c8e2fa7

Browse files
committed
Introduce StaticallySizedArray
Now that we have wrapper types in Base that can be used in conjunction with StaticArrays, we should discern between actual `StaticArray`s and non-`StaticArray` subtypes for which we still know a static `Size`. To this effect, this commit adds the type aliases `StaticallySizedMatrix`, `StaticallySizedVecOrMat`, and `StaticallySizedArray`, and uses them to widen various type signatures. Fixes #561.
1 parent 3ccbd41 commit c8e2fa7

File tree

8 files changed

+62
-25
lines changed

8 files changed

+62
-25
lines changed

src/StaticArrays.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,19 @@ const StaticVector{N, T} = StaticArray{Tuple{N}, T, 1}
7878
const StaticMatrix{N, M, T} = StaticArray{Tuple{N, M}, T, 2}
7979
const StaticVecOrMat{T} = Union{StaticVector{<:Any, T}, StaticMatrix{<:Any, <:Any, T}}
8080

81+
# Being a member of StaticallySizedMatrix, StaticallySizedVecOrMat, or StaticallySizedArray implies that Size(A)
82+
# returns a static Size instance. The converse may not be true.
83+
const StaticallySizedMatrix{T} = Union{
84+
StaticMatrix{<:Any, <:Any, T},
85+
Transpose{T, <:StaticVecOrMat{T}},
86+
Adjoint{T, <:StaticVecOrMat{T}},
87+
Symmetric{T, <:StaticMatrix{T}},
88+
Hermitian{T, <:StaticMatrix{T}},
89+
Diagonal{T, <:StaticVector{<:Any, T}}
90+
}
91+
const StaticallySizedVecOrMat{T} = Union{StaticVector{<:Any, T}, StaticallySizedMatrix{T}}
92+
const StaticallySizedArray{T} = Union{StaticallySizedVecOrMat{T}, StaticArray{<:Any, T}}
93+
8194
const AbstractScalar{T} = AbstractArray{T, 0} # not exported, but useful none-the-less
8295
const StaticArrayNoEltype{S, N, T} = StaticArray{S, T, N}
8396

src/abstractarray.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
length(a::SA) where {SA <: StaticArray} = prod(Size(SA))
2-
length(a::Type{SA}) where {SA <: StaticArray} = prod(Size(SA))
1+
length(a::SA) where {SA <: StaticallySizedArray} = length(SA)
2+
length(a::Type{SA}) where {SA <: StaticallySizedArray} = prod(Size(SA))
33

4-
@pure size(::Type{<:StaticArray{S}}) where S = tuple(S.parameters...)
5-
@inline function size(t::Type{<:StaticArray}, d::Int)
4+
@pure size(::Type{SA}) where {SA <: StaticallySizedArray} = get(Size(SA))
5+
@inline function size(t::Type{<:StaticallySizedArray}, d::Int)
66
S = size(t)
77
d > length(S) ? 1 : S[d]
88
end
9-
@inline size(a::StaticArray) = size(typeof(a))
10-
@inline size(a::StaticArray, d::Int) = size(typeof(a), d)
9+
@inline size(a::StaticallySizedArray) = size(typeof(a))
10+
@inline size(a::StaticallySizedArray, d::Int) = size(typeof(a), d)
1111

1212
Base.axes(s::StaticArray) = _axes(Size(s))
1313
@pure function _axes(::Size{sizes}) where {sizes}

src/convert.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ end
2525
return SA(unroll_tuple(a, Length(SA)))
2626
end
2727

28-
length_val(a::T) where {T <: StaticArray} = length_val(Size(T))
29-
length_val(a::Type{T}) where {T<:StaticArray} = length_val(Size(T))
28+
length_val(a::T) where {T <: StaticallySizedArray} = length_val(Size(T))
29+
length_val(a::Type{T}) where {T<:StaticallySizedArray} = length_val(Size(T))
3030

3131
@generated function unroll_tuple(a::AbstractArray, ::Length{L}) where {L}
3232
exprs = [:(a[$j]) for j = 1:L]

src/linalg.jl

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -88,11 +88,11 @@ end
8888
end
8989
end
9090

91-
@inline vcat(a::StaticVecOrMat) = a
92-
@inline vcat(a::StaticVecOrMat, b::StaticVecOrMat) = _vcat(Size(a), Size(b), a, b)
93-
@inline vcat(a::StaticVecOrMat, b::StaticVecOrMat, c::StaticVecOrMat...) = vcat(vcat(a,b), vcat(c...))
91+
@inline vcat(a::StaticallySizedVecOrMat) = a
92+
@inline vcat(a::StaticallySizedVecOrMat, b::StaticallySizedVecOrMat) = _vcat(Size(a), Size(b), a, b)
93+
@inline vcat(a::StaticallySizedVecOrMat, b::StaticallySizedVecOrMat, c::StaticallySizedVecOrMat...) = vcat(vcat(a,b), vcat(c...))
9494

95-
@generated function _vcat(::Size{Sa}, ::Size{Sb}, a::StaticVecOrMat, b::StaticVecOrMat) where {Sa, Sb}
95+
@generated function _vcat(::Size{Sa}, ::Size{Sb}, a::StaticallySizedVecOrMat, b::StaticallySizedVecOrMat) where {Sa, Sb}
9696
if Size(Sa)[2] != Size(Sb)[2]
9797
throw(DimensionMismatch("Tried to vcat arrays of size $Sa and $Sb"))
9898
end
@@ -116,11 +116,11 @@ end
116116
end
117117

118118
@inline hcat(a::StaticVector) = similar_type(a, Size(Size(a)[1],1))(a)
119-
@inline hcat(a::StaticMatrix) = a
120-
@inline hcat(a::StaticVecOrMat, b::StaticVecOrMat) = _hcat(Size(a), Size(b), a, b)
121-
@inline hcat(a::StaticVecOrMat, b::StaticVecOrMat, c::StaticVecOrMat...) = hcat(hcat(a,b), hcat(c...))
119+
@inline hcat(a::StaticallySizedMatrix) = a
120+
@inline hcat(a::StaticallySizedVecOrMat, b::StaticallySizedVecOrMat) = _hcat(Size(a), Size(b), a, b)
121+
@inline hcat(a::StaticallySizedVecOrMat, b::StaticallySizedVecOrMat, c::StaticallySizedVecOrMat...) = hcat(hcat(a,b), hcat(c...))
122122

123-
@generated function _hcat(::Size{Sa}, ::Size{Sb}, a::StaticVecOrMat, b::StaticVecOrMat) where {Sa, Sb}
123+
@generated function _hcat(::Size{Sa}, ::Size{Sb}, a::StaticallySizedVecOrMat, b::StaticallySizedVecOrMat) where {Sa, Sb}
124124
if Sa[1] != Sb[1]
125125
throw(DimensionMismatch("Tried to hcat arrays of size $Sa and $Sb"))
126126
end
@@ -491,12 +491,6 @@ end
491491
end
492492
end
493493

494-
495-
@inline Size(::Type{<:Adjoint{T, SA}}) where {T, SA <: StaticVecOrMat} = Size(Size(SA)[2], Size(SA)[1])
496-
@inline Size(::Type{<:Transpose{T, SA}}) where {T, SA <: StaticVecOrMat} = Size(Size(SA)[2], Size(SA)[1])
497-
@inline Size(::Type{Symmetric{T, SA}}) where {T, SA<:StaticArray} = Size(SA)
498-
@inline Size(::Type{Hermitian{T, SA}}) where {T, SA<:StaticArray} = Size(SA)
499-
500494
# some micro-optimizations (TODO check these make sense for v0.6+)
501495
@inline LinearAlgebra.checksquare(::SM) where {SM<:StaticMatrix} = _checksquare(Size(SM))
502496
@inline LinearAlgebra.checksquare(::Type{SM}) where {SM<:StaticMatrix} = _checksquare(Size(SM))

src/traits.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,13 @@ end
8686
Size(a::T) where {T<:AbstractArray} = Size(T)
8787
Size(::Type{SA}) where {SA <: StaticArray} = missing_size_error(SA)
8888
Size(::Type{SA}) where {SA <: StaticArray{S}} where {S<:Tuple} = @isdefined(S) ? Size(S) : missing_size_error(SA)
89+
90+
Size(::Type{Adjoint{T, A}}) where {T, A <: AbstractVecOrMat{T}} = Size(Size(A)[2], Size(A)[1])
91+
Size(::Type{Transpose{T, A}}) where {T, A <: AbstractVecOrMat{T}} = Size(Size(A)[2], Size(A)[1])
92+
Size(::Type{Symmetric{T, A}}) where {T, A <: AbstractMatrix{T}} = Size(A)
93+
Size(::Type{Hermitian{T, A}}) where {T, A <: AbstractMatrix{T}} = Size(A)
94+
Size(::Type{Diagonal{T, A}}) where {T, A <: AbstractVector{T}} = Size(Size(A)[1], Size(A)[1])
95+
8996
@pure Size(::Type{<:AbstractArray{<:Any, N}}) where {N} = Size(ntuple(_ -> Dynamic(), N))
9097

9198
struct Length{L}

src/util.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,5 @@ TrivialView(a::AbstractArray{T,N}) where {T,N} = TrivialView{typeof(a),T,N}(a)
9090
# certain algorithms where the number of elements of the output is a lot larger
9191
# than the input.
9292
# """
93-
@inline drop_sdims(a::StaticArray) = TrivialView(a)
94-
@inline drop_sdims(a::Transpose{<:Number, <:StaticArray}) = TrivialView(a)
95-
@inline drop_sdims(a::Adjoint{<:Number, <:StaticArray}) = TrivialView(a)
93+
@inline drop_sdims(a::StaticallySizedArray) = TrivialView(a)
9694
@inline drop_sdims(a) = a

test/core.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,14 @@
149149
@test StaticArrays.check_length(2) == nothing
150150
@test StaticArrays.check_length(StaticArrays.Dynamic()) == nothing
151151

152+
@testset "Size" begin
153+
@test Size(zero(SMatrix{2, 3})) == Size(2, 3)
154+
@test Size(Transpose(zero(SMatrix{2, 3}))) == Size(3, 2)
155+
@test Size(Adjoint(zero(SMatrix{2, 3}))) == Size(3, 2)
156+
@test Size(Diagonal(SVector(1, 2, 3))) == Size(3, 3)
157+
@test Size(Transpose(Diagonal(SVector(1, 2, 3)))) == Size(3, 3)
158+
end
159+
152160
@testset "dimmatch" begin
153161
@test StaticArrays.dimmatch(3, 3)
154162
@test StaticArrays.dimmatch(3, StaticArrays.Dynamic())

test/linalg.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,23 @@ using StaticArrays, Test, LinearAlgebra
194194
@test allocs == 0
195195
end
196196
end
197+
198+
# issue #561
199+
let A = Diagonal(SVector(1, 2)), B = @SMatrix [3 4; 5 6]
200+
@test @inferred(hcat(A, B)) === SMatrix{2, 4}([Matrix(A) Matrix(B)])
201+
end
202+
203+
let A = Transpose(@SMatrix [1 2; 3 4]), B = Adjoint(@SMatrix [5 6; 7 8])
204+
@test @inferred(hcat(A, B)) === SMatrix{2, 4}([Matrix(A) Matrix(B)])
205+
end
206+
207+
let A = Diagonal(SVector(1, 2)), B = @SMatrix [3 4; 5 6]
208+
@test @inferred(vcat(A, B)) === SMatrix{4, 2}([Matrix(A); Matrix(B)])
209+
end
210+
211+
let A = Transpose(@SMatrix [1 2; 3 4]), B = Adjoint(@SMatrix [5 6; 7 8])
212+
@test @inferred(vcat(A, B)) === SMatrix{4, 2}([Matrix(A); Matrix(B)])
213+
end
197214
end
198215

199216
@testset "normalization" begin

0 commit comments

Comments
 (0)