Skip to content

Commit 84e0f54

Browse files
authored
Sized AbstractArray (#783)
* SizedArray for AbstractArray pt 1 * tests for SizedArray of a view * A few tests for SizedArray * SizedArray bugfixes * disabling one tests that fails on nightly * SizedVector and SizedMatric constant changes * one more test for SizedArray view * views of SizedArray and MArray * view of MArray moved to MArray.jl * view(x, :) fixed * fixed view test * vec and parent for SizedArray * fixing the change of parent for SizedArray
1 parent c7f01b6 commit 84e0f54

File tree

6 files changed

+305
-71
lines changed

6 files changed

+305
-71
lines changed

src/MArray.jl

+9
Original file line numberDiff line numberDiff line change
@@ -269,3 +269,12 @@ end
269269
function promote_rule(::Type{<:MArray{S,T,N,L}}, ::Type{<:MArray{S,U,N,L}}) where {S,T,U,N,L}
270270
MArray{S,promote_type(T,U),N,L}
271271
end
272+
273+
function Base.view(
274+
a::MArray{S},
275+
indices::Union{Integer, Colon, StaticVector, Base.Slice, SOneTo}...,
276+
) where {S}
277+
new_size = new_out_size(S, indices...)
278+
view_from_invoke = invoke(view, Tuple{AbstractArray, typeof(indices).parameters...}, a, indices...)
279+
return SizedArray{new_size}(view_from_invoke)
280+
end

src/SizedArray.jl

+171-45
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"""
33
SizedArray{Tuple{dims...}}(array)
44
5-
Wraps an `Array` with a static size, so to take advantage of the (faster)
5+
Wraps an `AbstractArray` with a static size, so to take advantage of the (faster)
66
methods defined by the static array package. The size is checked once upon
77
construction to determine if the number of elements (`length`) match, but the
88
array may be reshaped.
@@ -11,37 +11,48 @@ The aliases `SizedVector{N}` and `SizedMatrix{N,M}` are provided as more
1111
convenient names for one and two dimensional `SizedArray`s. For example, to
1212
wrap a 2x3 array `a` in a `SizedArray`, use `SizedMatrix{2,3}(a)`.
1313
"""
14-
struct SizedArray{S <: Tuple, T, N, M} <: StaticArray{S, T, N}
15-
data::Array{T, M}
14+
struct SizedArray{S<:Tuple,T,N,M,TData<:AbstractArray{T,M}} <: StaticArray{S,T,N}
15+
data::TData
1616

17-
function SizedArray{S, T, N, M}(a::Array) where {S, T, N, M}
18-
if length(a) != tuple_prod(S)
17+
function SizedArray{S,T,N,M,TData}(a::TData) where {S,T,N,M,TData<:AbstractArray{T,M}}
18+
if size(a) != size_to_tuple(S) && size(a) != (tuple_prod(S),)
1919
throw(DimensionMismatch("Dimensions $(size(a)) don't match static size $S"))
2020
end
21-
if size(a) != size_to_tuple(S)
22-
Base.depwarn("Construction of `SizedArray` with an `Array` of a different
23-
size is deprecated. If you need this functionality report it at
24-
https://github.com/JuliaArrays/StaticArrays.jl/pull/666 .
25-
Calling `sa = reshape(a::Array, s::Size)` will actually reshape
26-
array `a` in the future and converting `sa` back to `Array` will
27-
return an `Array` of shape `s`.", :SizedArray)
28-
end
29-
new{S,T,N,M}(a)
21+
return new{S,T,N,M,TData}(a)
3022
end
3123

32-
function SizedArray{S, T, N, M}(::UndefInitializer) where {S, T, N, M}
33-
new{S, T, N, M}(Array{T, M}(undef, size_to_tuple(S)...))
24+
function SizedArray{S,T,N,1,TData}(::UndefInitializer) where {S,T,N,TData<:AbstractArray{T,1}}
25+
return new{S,T,N,1,TData}(TData(undef, tuple_prod(S)))
26+
end
27+
function SizedArray{S,T,N,N,TData}(::UndefInitializer) where {S,T,N,TData<:AbstractArray{T,N}}
28+
return new{S,T,N,N,TData}(TData(undef, size_to_tuple(S)...))
3429
end
3530
end
3631

37-
@inline SizedArray{S,T,N}(a::Array{T,M}) where {S,T,N,M} = SizedArray{S,T,N,M}(a)
38-
@inline SizedArray{S,T}(a::Array{T,M}) where {S,T,M} = SizedArray{S,T,tuple_length(S),M}(a)
39-
@inline SizedArray{S}(a::Array{T,M}) where {S,T,M} = SizedArray{S,T,tuple_length(S),M}(a)
40-
41-
@inline SizedArray{S,T,N}(::UndefInitializer) where {S,T,N} = SizedArray{S,T,N,N}(undef)
42-
@inline SizedArray{S,T}(::UndefInitializer) where {S,T} = SizedArray{S,T,tuple_length(S),tuple_length(S)}(undef)
43-
44-
@generated function SizedArray{S,T,N,M}(x::NTuple{L,Any}) where {S,T,N,M,L}
32+
@inline function SizedArray{S,T,N}(
33+
a::TData,
34+
) where {S,T,N,M,TData<:AbstractArray{T,M}}
35+
return SizedArray{S,T,N,M,TData}(a)
36+
end
37+
@inline function SizedArray{S,T}(a::TData) where {S,T,M,TData<:AbstractArray{T,M}}
38+
return SizedArray{S,T,tuple_length(S),M,TData}(a)
39+
end
40+
@inline function SizedArray{S}(a::TData) where {S,T,M,TData<:AbstractArray{T,M}}
41+
return SizedArray{S,T,tuple_length(S),M,TData}(a)
42+
end
43+
function SizedArray{S,T,N,N}(::UndefInitializer) where {S,T,N}
44+
return SizedArray{S,T,N,N,Array{T,N}}(undef)
45+
end
46+
function SizedArray{S,T,N,1}(::UndefInitializer) where {S,T,N}
47+
return SizedArray{S,T,N,1,Vector{T}}(undef)
48+
end
49+
@inline function SizedArray{S,T,N}(::UndefInitializer) where {S,T,N}
50+
return SizedArray{S,T,N,N}(undef)
51+
end
52+
@inline function SizedArray{S,T}(::UndefInitializer) where {S,T}
53+
return SizedArray{S,T,tuple_length(S)}(undef)
54+
end
55+
@generated function (::Type{SizedArray{S,T,N,M,TData}})(x::NTuple{L,Any}) where {S,T,N,M,TData<:AbstractArray{T,M},L}
4556
if L != tuple_prod(S)
4657
error("Dimension mismatch")
4758
end
@@ -53,43 +64,158 @@ end
5364
return a
5465
end
5566
end
56-
57-
@inline SizedArray{S,T,N}(x::Tuple) where {S,T,N} = SizedArray{S,T,N,N}(x)
58-
@inline SizedArray{S,T}(x::Tuple) where {S,T} = SizedArray{S,T,tuple_length(S),tuple_length(S)}(x)
59-
@inline SizedArray{S}(x::NTuple{L,T}) where {S,T,L} = SizedArray{S,T,tuple_length(S),tuple_length(S)}(x)
67+
@inline function SizedArray{S,T,N,M}(x::Tuple) where {S,T,N,M}
68+
return SizedArray{S,T,N,M,Array{T,M}}(x)
69+
end
70+
@inline function SizedArray{S,T,N}(x::Tuple) where {S,T,N}
71+
return SizedArray{S,T,N,N,Array{T,N}}(x)
72+
end
73+
@inline function SizedArray{S,T}(x::Tuple) where {S,T}
74+
return SizedArray{S,T,tuple_length(S)}(x)
75+
end
76+
@inline function SizedArray{S}(x::NTuple{L,T}) where {S,T,L}
77+
return SizedArray{S,T}(x)
78+
end
6079

6180
# Overide some problematic default behaviour
6281
@inline convert(::Type{SA}, sa::SizedArray) where {SA<:SizedArray} = SA(sa.data)
6382
@inline convert(::Type{SA}, sa::SA) where {SA<:SizedArray} = sa
6483

6584
# Back to Array (unfortunately need both convert and construct to overide other methods)
66-
@inline Array(sa::SizedArray) = Array(sa.data)
67-
@inline Array{T}(sa::SizedArray{S,T}) where {T,S} = Array{T}(sa.data)
68-
@inline Array{T,N}(sa::SizedArray{S,T,N}) where {T,S,N} = Array{T,N}(sa.data)
85+
@inline function Base.Array(sa::SizedArray{S}) where {S}
86+
return Array(reshape(sa.data, size_to_tuple(S)))
87+
end
88+
@inline function Base.Array{T}(sa::SizedArray{S,T}) where {T,S}
89+
return Array(reshape(sa.data, size_to_tuple(S)))
90+
end
91+
@inline function Base.Array{T,N}(sa::SizedArray{S,T,N}) where {T,S,N}
92+
return Array(reshape(sa.data, size_to_tuple(S)))
93+
end
6994

70-
@inline convert(::Type{Array}, sa::SizedArray) = sa.data
71-
@inline convert(::Type{Array{T}}, sa::SizedArray{S,T}) where {T,S} = sa.data
72-
@inline convert(::Type{Array{T,N}}, sa::SizedArray{S,T,N}) where {T,S,N} = sa.data
95+
@inline function convert(::Type{Array}, sa::SizedArray{S}) where {S}
96+
return Array(reshape(sa.data, size_to_tuple(S)))
97+
end
98+
@inline function convert(::Type{Array}, sa::SizedArray{S,T,N,M,Array{T,M}}) where {S,T,N,M}
99+
return sa.data
100+
end
101+
@inline function convert(::Type{Array{T}}, sa::SizedArray{S,T}) where {T,S}
102+
return Array(reshape(sa.data, size_to_tuple(S)))
103+
end
104+
@inline function convert(::Type{Array{T}}, sa::SizedArray{S,T,N,M,Array{T,M}}) where {S,T,N,M}
105+
return sa.data
106+
end
107+
@inline function convert(
108+
::Type{Array{T,N}},
109+
sa::SizedArray{S,T,N},
110+
) where {T,S,N}
111+
return Array(reshape(sa.data, size_to_tuple(S)))
112+
end
113+
@inline function convert(::Type{Array{T,N}}, sa::SizedArray{S,T,N,N,Array{T,N}}) where {S,T,N}
114+
return sa.data
115+
end
73116

74117
@propagate_inbounds getindex(a::SizedArray, i::Int) = getindex(a.data, i)
75118
@propagate_inbounds setindex!(a::SizedArray, v, i::Int) = setindex!(a.data, v, i)
76119

77-
SizedVector{S,T,M} = SizedArray{Tuple{S},T,1,M}
78-
@inline SizedVector{S}(a::Array{T,M}) where {S,T,M} = SizedArray{Tuple{S},T,1,M}(a)
79-
@inline SizedVector{S}(x::NTuple{L,T}) where {S,T,L} = SizedArray{Tuple{S},T,1,1}(x)
120+
Base.parent(sa::SizedArray) = sa.data
80121

81-
SizedMatrix{S1,S2,T,M} = SizedArray{Tuple{S1,S2},T,2,M}
82-
@inline SizedMatrix{S1,S2}(a::Array{T,M}) where {S1,S2,T,M} = SizedArray{Tuple{S1,S2},T,2,M}(a)
83-
@inline SizedMatrix{S1,S2}(x::NTuple{L,T}) where {S1,S2,T,L} = SizedArray{Tuple{S1,S2},T,2,2}(x)
122+
const SizedVector{S,T} = SizedArray{Tuple{S},T,1,1}
123+
124+
@inline function SizedVector{S}(a::TData) where {S,T,TData<:AbstractVector{T}}
125+
return SizedArray{Tuple{S},T,1,1,TData}(a)
126+
end
127+
@inline function SizedVector(x::NTuple{S,T}) where {S,T}
128+
return SizedArray{Tuple{S},T,1,1,Vector{T}}(x)
129+
end
130+
@inline function SizedVector{S}(x::NTuple{S,T}) where {S,T}
131+
return SizedArray{Tuple{S},T,1,1,Vector{T}}(x)
132+
end
133+
@inline function SizedVector{S,T}(x::NTuple{S}) where {S,T}
134+
return SizedArray{Tuple{S},T,1,1,Vector{T}}(x)
135+
end
136+
# disambiguation
137+
@inline function SizedVector{S}(a::StaticVector{S,T}) where {S,T}
138+
return SizedVector{S,T}(a.data)
139+
end
140+
141+
const SizedMatrix{S1,S2,T} = SizedArray{Tuple{S1,S2},T,2}
142+
143+
@inline function SizedMatrix{S1,S2}(
144+
a::TData,
145+
) where {S1,S2,T,M,TData<:AbstractArray{T,M}}
146+
return SizedArray{Tuple{S1,S2},T,2,M,TData}(a)
147+
end
148+
@inline function SizedMatrix{S1,S2}(x::NTuple{L,T}) where {S1,S2,T,L}
149+
return SizedArray{Tuple{S1,S2},T,2,2,Matrix{T}}(x)
150+
end
151+
@inline function SizedMatrix{S1,S2,T}(x::NTuple{L}) where {S1,S2,T,L}
152+
return SizedArray{Tuple{S1,S2},T,2,2,Matrix{T}}(x)
153+
end
154+
# disambiguation
155+
@inline function SizedMatrix{S1,S2}(a::StaticMatrix{S1,S2,T}) where {S1,S2,T}
156+
return SizedMatrix{S1,S2,T}(a.data)
157+
end
84158

85159
Base.dataids(sa::SizedArray) = Base.dataids(sa.data)
86160

87-
function (::Size{S})(a::Array) where {S}
88-
Base.depwarn("`Size{S}(a::Array)` is deprecated, use `SizedVector{N}(a)`, `SizedMatrix{N,M}(a)` or `SizedArray{Tuple{S}}(a)` instead", :Size)
89-
SizedArray{Tuple{S...}}(a)
161+
function promote_rule(
162+
::Type{SizedArray{S,T,N,M,TDataA}},
163+
::Type{SizedArray{S,U,N,M,TDataB}},
164+
) where {S,T,U,N,M,TDataA,TDataB}
165+
TU = promote_type(T, U)
166+
return SizedArray{S, TU, N, M, promote_type(TDataA, TDataB)}
90167
end
91168

169+
function promote_rule(
170+
::Type{SizedArray{S,T,N,M}},
171+
::Type{SizedArray{S,U,N,M}},
172+
) where {S,T,U,N,M,}
173+
TU = promote_type(T, U)
174+
return SizedArray{S, TU, N, M}
175+
end
176+
177+
function promote_rule(
178+
::Type{SizedArray{S,T,N}},
179+
::Type{SizedArray{S,U,N}},
180+
) where {S,T,U,N}
181+
TU = promote_type(T, U)
182+
return SizedArray{S, TU, N}
183+
end
184+
185+
186+
### Code that makes views of statically sized arrays also statically sized (where possible)
187+
188+
@generated function new_out_size(::Type{Size}, inds...) where Size
189+
os = []
190+
map(Size.parameters, inds) do s, i
191+
if i <: Integer
192+
# dimension is fixed
193+
elseif i <: StaticVector
194+
push!(os, i.parameters[1].parameters[1])
195+
elseif i == Colon || i <: Base.Slice
196+
push!(os, s)
197+
elseif i <: SOneTo
198+
push!(os, i.parameters[1])
199+
else
200+
error("Unknown index type: $i")
201+
end
202+
end
203+
return Tuple{os...}
204+
end
205+
206+
@generated function new_out_size(::Type{Size}, ::Colon) where Size
207+
prod_size = tuple_prod(Size)
208+
return Tuple{prod_size}
209+
end
210+
211+
function Base.view(
212+
a::SizedArray{S},
213+
indices::Union{Integer, Colon, StaticVector, Base.Slice, SOneTo}...,
214+
) where {S}
215+
new_size = new_out_size(S, indices...)
216+
return SizedArray{new_size}(view(a.data, indices...))
217+
end
92218

93-
function promote_rule(::Type{<:SizedArray{S,T,N,M}}, ::Type{<:SizedArray{S,U,N,M}}) where {S,T,U,N,M}
94-
SizedArray{S,promote_type(T,U),N,M}
219+
function Base.vec(a::SizedArray{S}) where {S}
220+
return SizedVector{tuple_prod(S)}(vec(a.data))
95221
end

src/matrix_multiply_add.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -40,18 +40,18 @@ Base.transpose(::TSize{S,T}) where {S,T} = TSize{reverse(S),!T}()
4040

4141
# Get the parent of transposed arrays, or the array itself if it has no parent
4242
# QUESTION: maybe call this something else?
43-
Base.parent(A::Union{<:Transpose{<:Any,<:StaticArray}, <:Adjoint{<:Any,<:StaticArray}}) = A.parent
44-
Base.parent(A::StaticArray) = A
43+
mul_parent(A) = parent(A)
44+
mul_parent(A::StaticArray) = A
4545

4646
# 5-argument matrix multiplication
4747
# To avoid allocations, strip away Transpose type and store tranpose info in Size
4848
@inline LinearAlgebra.mul!(dest::StaticVecOrMatLike, A::StaticVecOrMatLike, B::StaticVecOrMatLike,
49-
α::Real, β::Real) = _mul!(TSize(dest), parent(dest), TSize(A), TSize(B), parent(A), parent(B),
49+
α::Real, β::Real) = _mul!(TSize(dest), mul_parent(dest), TSize(A), TSize(B), mul_parent(A), mul_parent(B),
5050
AlphaBeta(α,β))
5151

5252
@inline LinearAlgebra.mul!(dest::StaticVecOrMatLike, A::StaticVecOrMatLike{T},
5353
B::StaticVecOrMatLike{T}) where T =
54-
_mul!(TSize(dest), parent(dest), TSize(A), TSize(B), parent(A), parent(B), NoMulAdd{T}())
54+
_mul!(TSize(dest), mul_parent(dest), TSize(A), TSize(B), mul_parent(A), mul_parent(B), NoMulAdd{T}())
5555

5656

5757
"Calculate the product of the dimensions being multiplied. Useful as a heuristic for unrolling."

0 commit comments

Comments
 (0)