From c8e2fa78b3b8290017812efe6be8d23f2a6d62cd Mon Sep 17 00:00:00 2001 From: Twan Koolen Date: Mon, 10 Dec 2018 18:47:35 -0500 Subject: [PATCH 1/2] Introduce StaticallySizedArray Now that we have wrapper types in Base that can be used in conjunction with StaticArrays, we should discern between actual `StaticArray`s and non-`StaticArray` subtypes for which we still know a static `Size`. To this effect, this commit adds the type aliases `StaticallySizedMatrix`, `StaticallySizedVecOrMat`, and `StaticallySizedArray`, and uses them to widen various type signatures. Fixes #561. --- src/StaticArrays.jl | 13 +++++++++++++ src/abstractarray.jl | 12 ++++++------ src/convert.jl | 4 ++-- src/linalg.jl | 22 ++++++++-------------- src/traits.jl | 7 +++++++ src/util.jl | 4 +--- test/core.jl | 8 ++++++++ test/linalg.jl | 17 +++++++++++++++++ 8 files changed, 62 insertions(+), 25 deletions(-) diff --git a/src/StaticArrays.jl b/src/StaticArrays.jl index ad852d1c..27885fee 100644 --- a/src/StaticArrays.jl +++ b/src/StaticArrays.jl @@ -78,6 +78,19 @@ const StaticVector{N, T} = StaticArray{Tuple{N}, T, 1} const StaticMatrix{N, M, T} = StaticArray{Tuple{N, M}, T, 2} const StaticVecOrMat{T} = Union{StaticVector{<:Any, T}, StaticMatrix{<:Any, <:Any, T}} +# Being a member of StaticallySizedMatrix, StaticallySizedVecOrMat, or StaticallySizedArray implies that Size(A) +# returns a static Size instance. The converse may not be true. +const StaticallySizedMatrix{T} = Union{ + StaticMatrix{<:Any, <:Any, T}, + Transpose{T, <:StaticVecOrMat{T}}, + Adjoint{T, <:StaticVecOrMat{T}}, + Symmetric{T, <:StaticMatrix{T}}, + Hermitian{T, <:StaticMatrix{T}}, + Diagonal{T, <:StaticVector{<:Any, T}} +} +const StaticallySizedVecOrMat{T} = Union{StaticVector{<:Any, T}, StaticallySizedMatrix{T}} +const StaticallySizedArray{T} = Union{StaticallySizedVecOrMat{T}, StaticArray{<:Any, T}} + const AbstractScalar{T} = AbstractArray{T, 0} # not exported, but useful none-the-less const StaticArrayNoEltype{S, N, T} = StaticArray{S, T, N} diff --git a/src/abstractarray.jl b/src/abstractarray.jl index 733ca011..2e628d39 100644 --- a/src/abstractarray.jl +++ b/src/abstractarray.jl @@ -1,13 +1,13 @@ -length(a::SA) where {SA <: StaticArray} = prod(Size(SA)) -length(a::Type{SA}) where {SA <: StaticArray} = prod(Size(SA)) +length(a::SA) where {SA <: StaticallySizedArray} = length(SA) +length(a::Type{SA}) where {SA <: StaticallySizedArray} = prod(Size(SA)) -@pure size(::Type{<:StaticArray{S}}) where S = tuple(S.parameters...) -@inline function size(t::Type{<:StaticArray}, d::Int) +@pure size(::Type{SA}) where {SA <: StaticallySizedArray} = get(Size(SA)) +@inline function size(t::Type{<:StaticallySizedArray}, d::Int) S = size(t) d > length(S) ? 1 : S[d] end -@inline size(a::StaticArray) = size(typeof(a)) -@inline size(a::StaticArray, d::Int) = size(typeof(a), d) +@inline size(a::StaticallySizedArray) = size(typeof(a)) +@inline size(a::StaticallySizedArray, d::Int) = size(typeof(a), d) Base.axes(s::StaticArray) = _axes(Size(s)) @pure function _axes(::Size{sizes}) where {sizes} diff --git a/src/convert.jl b/src/convert.jl index f30570cb..6995c84d 100644 --- a/src/convert.jl +++ b/src/convert.jl @@ -25,8 +25,8 @@ end return SA(unroll_tuple(a, Length(SA))) end -length_val(a::T) where {T <: StaticArray} = length_val(Size(T)) -length_val(a::Type{T}) where {T<:StaticArray} = length_val(Size(T)) +length_val(a::T) where {T <: StaticallySizedArray} = length_val(Size(T)) +length_val(a::Type{T}) where {T<:StaticallySizedArray} = length_val(Size(T)) @generated function unroll_tuple(a::AbstractArray, ::Length{L}) where {L} exprs = [:(a[$j]) for j = 1:L] diff --git a/src/linalg.jl b/src/linalg.jl index d4645012..c915eb3c 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -88,11 +88,11 @@ end end end -@inline vcat(a::StaticVecOrMat) = a -@inline vcat(a::StaticVecOrMat, b::StaticVecOrMat) = _vcat(Size(a), Size(b), a, b) -@inline vcat(a::StaticVecOrMat, b::StaticVecOrMat, c::StaticVecOrMat...) = vcat(vcat(a,b), vcat(c...)) +@inline vcat(a::StaticallySizedVecOrMat) = a +@inline vcat(a::StaticallySizedVecOrMat, b::StaticallySizedVecOrMat) = _vcat(Size(a), Size(b), a, b) +@inline vcat(a::StaticallySizedVecOrMat, b::StaticallySizedVecOrMat, c::StaticallySizedVecOrMat...) = vcat(vcat(a,b), vcat(c...)) -@generated function _vcat(::Size{Sa}, ::Size{Sb}, a::StaticVecOrMat, b::StaticVecOrMat) where {Sa, Sb} +@generated function _vcat(::Size{Sa}, ::Size{Sb}, a::StaticallySizedVecOrMat, b::StaticallySizedVecOrMat) where {Sa, Sb} if Size(Sa)[2] != Size(Sb)[2] throw(DimensionMismatch("Tried to vcat arrays of size $Sa and $Sb")) end @@ -116,11 +116,11 @@ end end @inline hcat(a::StaticVector) = similar_type(a, Size(Size(a)[1],1))(a) -@inline hcat(a::StaticMatrix) = a -@inline hcat(a::StaticVecOrMat, b::StaticVecOrMat) = _hcat(Size(a), Size(b), a, b) -@inline hcat(a::StaticVecOrMat, b::StaticVecOrMat, c::StaticVecOrMat...) = hcat(hcat(a,b), hcat(c...)) +@inline hcat(a::StaticallySizedMatrix) = a +@inline hcat(a::StaticallySizedVecOrMat, b::StaticallySizedVecOrMat) = _hcat(Size(a), Size(b), a, b) +@inline hcat(a::StaticallySizedVecOrMat, b::StaticallySizedVecOrMat, c::StaticallySizedVecOrMat...) = hcat(hcat(a,b), hcat(c...)) -@generated function _hcat(::Size{Sa}, ::Size{Sb}, a::StaticVecOrMat, b::StaticVecOrMat) where {Sa, Sb} +@generated function _hcat(::Size{Sa}, ::Size{Sb}, a::StaticallySizedVecOrMat, b::StaticallySizedVecOrMat) where {Sa, Sb} if Sa[1] != Sb[1] throw(DimensionMismatch("Tried to hcat arrays of size $Sa and $Sb")) end @@ -491,12 +491,6 @@ end end end - -@inline Size(::Type{<:Adjoint{T, SA}}) where {T, SA <: StaticVecOrMat} = Size(Size(SA)[2], Size(SA)[1]) -@inline Size(::Type{<:Transpose{T, SA}}) where {T, SA <: StaticVecOrMat} = Size(Size(SA)[2], Size(SA)[1]) -@inline Size(::Type{Symmetric{T, SA}}) where {T, SA<:StaticArray} = Size(SA) -@inline Size(::Type{Hermitian{T, SA}}) where {T, SA<:StaticArray} = Size(SA) - # some micro-optimizations (TODO check these make sense for v0.6+) @inline LinearAlgebra.checksquare(::SM) where {SM<:StaticMatrix} = _checksquare(Size(SM)) @inline LinearAlgebra.checksquare(::Type{SM}) where {SM<:StaticMatrix} = _checksquare(Size(SM)) diff --git a/src/traits.jl b/src/traits.jl index c6308cc9..4aada665 100644 --- a/src/traits.jl +++ b/src/traits.jl @@ -86,6 +86,13 @@ end Size(a::T) where {T<:AbstractArray} = Size(T) Size(::Type{SA}) where {SA <: StaticArray} = missing_size_error(SA) Size(::Type{SA}) where {SA <: StaticArray{S}} where {S<:Tuple} = @isdefined(S) ? Size(S) : missing_size_error(SA) + +Size(::Type{Adjoint{T, A}}) where {T, A <: AbstractVecOrMat{T}} = Size(Size(A)[2], Size(A)[1]) +Size(::Type{Transpose{T, A}}) where {T, A <: AbstractVecOrMat{T}} = Size(Size(A)[2], Size(A)[1]) +Size(::Type{Symmetric{T, A}}) where {T, A <: AbstractMatrix{T}} = Size(A) +Size(::Type{Hermitian{T, A}}) where {T, A <: AbstractMatrix{T}} = Size(A) +Size(::Type{Diagonal{T, A}}) where {T, A <: AbstractVector{T}} = Size(Size(A)[1], Size(A)[1]) + @pure Size(::Type{<:AbstractArray{<:Any, N}}) where {N} = Size(ntuple(_ -> Dynamic(), N)) struct Length{L} diff --git a/src/util.jl b/src/util.jl index ad86fe34..1e076154 100644 --- a/src/util.jl +++ b/src/util.jl @@ -90,7 +90,5 @@ TrivialView(a::AbstractArray{T,N}) where {T,N} = TrivialView{typeof(a),T,N}(a) # certain algorithms where the number of elements of the output is a lot larger # than the input. # """ -@inline drop_sdims(a::StaticArray) = TrivialView(a) -@inline drop_sdims(a::Transpose{<:Number, <:StaticArray}) = TrivialView(a) -@inline drop_sdims(a::Adjoint{<:Number, <:StaticArray}) = TrivialView(a) +@inline drop_sdims(a::StaticallySizedArray) = TrivialView(a) @inline drop_sdims(a) = a diff --git a/test/core.jl b/test/core.jl index 8d795fbe..5b48ff46 100644 --- a/test/core.jl +++ b/test/core.jl @@ -149,6 +149,14 @@ @test StaticArrays.check_length(2) == nothing @test StaticArrays.check_length(StaticArrays.Dynamic()) == nothing + @testset "Size" begin + @test Size(zero(SMatrix{2, 3})) == Size(2, 3) + @test Size(Transpose(zero(SMatrix{2, 3}))) == Size(3, 2) + @test Size(Adjoint(zero(SMatrix{2, 3}))) == Size(3, 2) + @test Size(Diagonal(SVector(1, 2, 3))) == Size(3, 3) + @test Size(Transpose(Diagonal(SVector(1, 2, 3)))) == Size(3, 3) + end + @testset "dimmatch" begin @test StaticArrays.dimmatch(3, 3) @test StaticArrays.dimmatch(3, StaticArrays.Dynamic()) diff --git a/test/linalg.jl b/test/linalg.jl index 477268ff..dd16d175 100644 --- a/test/linalg.jl +++ b/test/linalg.jl @@ -194,6 +194,23 @@ using StaticArrays, Test, LinearAlgebra @test allocs == 0 end end + + # issue #561 + let A = Diagonal(SVector(1, 2)), B = @SMatrix [3 4; 5 6] + @test @inferred(hcat(A, B)) === SMatrix{2, 4}([Matrix(A) Matrix(B)]) + end + + let A = Transpose(@SMatrix [1 2; 3 4]), B = Adjoint(@SMatrix [5 6; 7 8]) + @test @inferred(hcat(A, B)) === SMatrix{2, 4}([Matrix(A) Matrix(B)]) + end + + let A = Diagonal(SVector(1, 2)), B = @SMatrix [3 4; 5 6] + @test @inferred(vcat(A, B)) === SMatrix{4, 2}([Matrix(A); Matrix(B)]) + end + + let A = Transpose(@SMatrix [1 2; 3 4]), B = Adjoint(@SMatrix [5 6; 7 8]) + @test @inferred(vcat(A, B)) === SMatrix{4, 2}([Matrix(A); Matrix(B)]) + end end @testset "normalization" begin From 7a4bc2efacb6e44247927c685d785d49f01f82b0 Mon Sep 17 00:00:00 2001 From: Twan Koolen Date: Tue, 11 Dec 2018 03:17:40 -0500 Subject: [PATCH 2/2] Address comments (rename, ref. SizedArray). --- src/StaticArrays.jl | 12 +++++++----- src/abstractarray.jl | 12 ++++++------ src/convert.jl | 4 ++-- src/linalg.jl | 16 ++++++++-------- src/util.jl | 2 +- 5 files changed, 24 insertions(+), 22 deletions(-) diff --git a/src/StaticArrays.jl b/src/StaticArrays.jl index 27885fee..59824cf0 100644 --- a/src/StaticArrays.jl +++ b/src/StaticArrays.jl @@ -78,9 +78,11 @@ const StaticVector{N, T} = StaticArray{Tuple{N}, T, 1} const StaticMatrix{N, M, T} = StaticArray{Tuple{N, M}, T, 2} const StaticVecOrMat{T} = Union{StaticVector{<:Any, T}, StaticMatrix{<:Any, <:Any, T}} -# Being a member of StaticallySizedMatrix, StaticallySizedVecOrMat, or StaticallySizedArray implies that Size(A) -# returns a static Size instance. The converse may not be true. -const StaticallySizedMatrix{T} = Union{ +# Being a member of StaticMatrixLike, StaticVecOrMatLike, or StaticArrayLike implies that Size(A) +# returns a static Size instance (none of the dimensions are Dynamic). The converse may not be true. +# These are akin to aliases like StridedArray and in similarly bad taste, but the current approach +# in Base necessitates their existence. +const StaticMatrixLike{T} = Union{ StaticMatrix{<:Any, <:Any, T}, Transpose{T, <:StaticVecOrMat{T}}, Adjoint{T, <:StaticVecOrMat{T}}, @@ -88,8 +90,8 @@ const StaticallySizedMatrix{T} = Union{ Hermitian{T, <:StaticMatrix{T}}, Diagonal{T, <:StaticVector{<:Any, T}} } -const StaticallySizedVecOrMat{T} = Union{StaticVector{<:Any, T}, StaticallySizedMatrix{T}} -const StaticallySizedArray{T} = Union{StaticallySizedVecOrMat{T}, StaticArray{<:Any, T}} +const StaticVecOrMatLike{T} = Union{StaticVector{<:Any, T}, StaticMatrixLike{T}} +const StaticArrayLike{T} = Union{StaticVecOrMatLike{T}, StaticArray{<:Any, T}} const AbstractScalar{T} = AbstractArray{T, 0} # not exported, but useful none-the-less const StaticArrayNoEltype{S, N, T} = StaticArray{S, T, N} diff --git a/src/abstractarray.jl b/src/abstractarray.jl index 2e628d39..4c572e88 100644 --- a/src/abstractarray.jl +++ b/src/abstractarray.jl @@ -1,13 +1,13 @@ -length(a::SA) where {SA <: StaticallySizedArray} = length(SA) -length(a::Type{SA}) where {SA <: StaticallySizedArray} = prod(Size(SA)) +length(a::SA) where {SA <: StaticArrayLike} = length(SA) +length(a::Type{SA}) where {SA <: StaticArrayLike} = prod(Size(SA)) -@pure size(::Type{SA}) where {SA <: StaticallySizedArray} = get(Size(SA)) -@inline function size(t::Type{<:StaticallySizedArray}, d::Int) +@pure size(::Type{SA}) where {SA <: StaticArrayLike} = get(Size(SA)) +@inline function size(t::Type{<:StaticArrayLike}, d::Int) S = size(t) d > length(S) ? 1 : S[d] end -@inline size(a::StaticallySizedArray) = size(typeof(a)) -@inline size(a::StaticallySizedArray, d::Int) = size(typeof(a), d) +@inline size(a::StaticArrayLike) = size(typeof(a)) +@inline size(a::StaticArrayLike, d::Int) = size(typeof(a), d) Base.axes(s::StaticArray) = _axes(Size(s)) @pure function _axes(::Size{sizes}) where {sizes} diff --git a/src/convert.jl b/src/convert.jl index 6995c84d..710f9693 100644 --- a/src/convert.jl +++ b/src/convert.jl @@ -25,8 +25,8 @@ end return SA(unroll_tuple(a, Length(SA))) end -length_val(a::T) where {T <: StaticallySizedArray} = length_val(Size(T)) -length_val(a::Type{T}) where {T<:StaticallySizedArray} = length_val(Size(T)) +length_val(a::T) where {T <: StaticArrayLike} = length_val(Size(T)) +length_val(a::Type{T}) where {T<:StaticArrayLike} = length_val(Size(T)) @generated function unroll_tuple(a::AbstractArray, ::Length{L}) where {L} exprs = [:(a[$j]) for j = 1:L] diff --git a/src/linalg.jl b/src/linalg.jl index c915eb3c..4887c0e2 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -88,11 +88,11 @@ end end end -@inline vcat(a::StaticallySizedVecOrMat) = a -@inline vcat(a::StaticallySizedVecOrMat, b::StaticallySizedVecOrMat) = _vcat(Size(a), Size(b), a, b) -@inline vcat(a::StaticallySizedVecOrMat, b::StaticallySizedVecOrMat, c::StaticallySizedVecOrMat...) = vcat(vcat(a,b), vcat(c...)) +@inline vcat(a::StaticVecOrMatLike) = a +@inline vcat(a::StaticVecOrMatLike, b::StaticVecOrMatLike) = _vcat(Size(a), Size(b), a, b) +@inline vcat(a::StaticVecOrMatLike, b::StaticVecOrMatLike, c::StaticVecOrMatLike...) = vcat(vcat(a,b), vcat(c...)) -@generated function _vcat(::Size{Sa}, ::Size{Sb}, a::StaticallySizedVecOrMat, b::StaticallySizedVecOrMat) where {Sa, Sb} +@generated function _vcat(::Size{Sa}, ::Size{Sb}, a::StaticVecOrMatLike, b::StaticVecOrMatLike) where {Sa, Sb} if Size(Sa)[2] != Size(Sb)[2] throw(DimensionMismatch("Tried to vcat arrays of size $Sa and $Sb")) end @@ -116,11 +116,11 @@ end end @inline hcat(a::StaticVector) = similar_type(a, Size(Size(a)[1],1))(a) -@inline hcat(a::StaticallySizedMatrix) = a -@inline hcat(a::StaticallySizedVecOrMat, b::StaticallySizedVecOrMat) = _hcat(Size(a), Size(b), a, b) -@inline hcat(a::StaticallySizedVecOrMat, b::StaticallySizedVecOrMat, c::StaticallySizedVecOrMat...) = hcat(hcat(a,b), hcat(c...)) +@inline hcat(a::StaticMatrixLike) = a +@inline hcat(a::StaticVecOrMatLike, b::StaticVecOrMatLike) = _hcat(Size(a), Size(b), a, b) +@inline hcat(a::StaticVecOrMatLike, b::StaticVecOrMatLike, c::StaticVecOrMatLike...) = hcat(hcat(a,b), hcat(c...)) -@generated function _hcat(::Size{Sa}, ::Size{Sb}, a::StaticallySizedVecOrMat, b::StaticallySizedVecOrMat) where {Sa, Sb} +@generated function _hcat(::Size{Sa}, ::Size{Sb}, a::StaticVecOrMatLike, b::StaticVecOrMatLike) where {Sa, Sb} if Sa[1] != Sb[1] throw(DimensionMismatch("Tried to hcat arrays of size $Sa and $Sb")) end diff --git a/src/util.jl b/src/util.jl index 1e076154..e46f44cb 100644 --- a/src/util.jl +++ b/src/util.jl @@ -90,5 +90,5 @@ TrivialView(a::AbstractArray{T,N}) where {T,N} = TrivialView{typeof(a),T,N}(a) # certain algorithms where the number of elements of the output is a lot larger # than the input. # """ -@inline drop_sdims(a::StaticallySizedArray) = TrivialView(a) +@inline drop_sdims(a::StaticArrayLike) = TrivialView(a) @inline drop_sdims(a) = a