Skip to content

Commit c52ac6a

Browse files
authored
Define mapfoldl/foldl for static arrays (#750)
1 parent c808bdd commit c52ac6a

File tree

3 files changed

+88
-72
lines changed

3 files changed

+88
-72
lines changed

src/StaticArrays.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module StaticArrays
33
import Base: @_inline_meta, @_propagate_inbounds_meta, @_pure_meta, @propagate_inbounds, @pure
44

55
import Base: getindex, setindex!, size, similar, vec, show, length, convert, promote_op,
6-
promote_rule, map, map!, reduce, mapreduce, broadcast,
6+
promote_rule, map, map!, reduce, mapreduce, foldl, mapfoldl, broadcast,
77
broadcast!, conj, hcat, vcat, ones, zeros, one, reshape, fill, fill!, inv,
88
iszero, sum, prod, count, any, all, minimum, maximum, extrema,
99
copy, read, read!, write, reverse

src/mapreduce.jl

+78-71
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,23 @@
1+
"""
2+
_InitialValue
3+
4+
A singleton type for representing "universal" initial value (identity element).
5+
6+
The idea is that, given `op` for `mapfoldl`, virtually, we define an "extended"
7+
version of it by
8+
9+
op′(::_InitialValue, x) = x
10+
op′(acc, x) = op(acc, x)
11+
12+
This is just a conceptually useful model to have in mind and we don't actually
13+
define `op′` here (yet?). But see `Base.BottomRF` for how it might work in
14+
action.
15+
16+
(It is related to that you can always turn a semigroup without an identity into
17+
a monoid by "adjoining" an element that acts as the identity.)
18+
"""
19+
struct _InitialValue end
20+
121
@inline _first(a1, as...) = a1
222

323
################
@@ -86,28 +106,21 @@ end
86106
## mapreduce ##
87107
###############
88108

89-
@inline function mapreduce(f, op, a::StaticArray, b::StaticArray...; dims=:,kw...)
90-
_mapreduce(f, op, dims, kw.data, same_size(a, b...), a, b...)
109+
@inline function mapreduce(f, op, a::StaticArray, b::StaticArray...; dims=:, init = _InitialValue())
110+
_mapreduce(f, op, dims, init, same_size(a, b...), a, b...)
91111
end
92112

93-
@generated function _mapreduce(f, op, dims::Colon, nt::NamedTuple{()},
94-
::Size{S}, a::StaticArray...) where {S}
113+
@inline _mapreduce(args::Vararg{Any,N}) where N = _mapfoldl(args...)
114+
115+
@generated function _mapfoldl(f, op, dims::Colon, init, ::Size{S}, a::StaticArray...) where {S}
95116
tmp = [:(a[$j][1]) for j 1:length(a)]
96117
expr = :(f($(tmp...)))
97-
for i 2:prod(S)
98-
tmp = [:(a[$j][$i]) for j 1:length(a)]
99-
expr = :(op($expr, f($(tmp...))))
100-
end
101-
return quote
102-
@_inline_meta
103-
@inbounds return $expr
118+
if init === _InitialValue
119+
expr = :(Base.reduce_first(op, $expr))
120+
else
121+
expr = :(op(init, $expr))
104122
end
105-
end
106-
107-
@generated function _mapreduce(f, op, dims::Colon, nt::NamedTuple{(:init,)},
108-
::Size{S}, a::StaticArray...) where {S}
109-
expr = :(nt.init)
110-
for i 1:prod(S)
123+
for i 2:prod(S)
111124
tmp = [:(a[$j][$i]) for j 1:length(a)]
112125
expr = :(op($expr, f($(tmp...))))
113126
end
@@ -117,24 +130,24 @@ end
117130
end
118131
end
119132

120-
@inline function _mapreduce(f, op, D::Int, nt::NamedTuple, sz::Size{S}, a::StaticArray) where {S}
133+
@inline function _mapreduce(f, op, D::Int, init, sz::Size{S}, a::StaticArray) where {S}
121134
# Body of this function is split because constant propagation (at least
122135
# as of Julia 1.2) can't always correctly propagate here and
123136
# as a result the function is not type stable and very slow.
124137
# This makes it at least fast for three dimensions but people should use
125138
# for example any(a; dims=Val(1)) instead of any(a; dims=1) anyway.
126139
if D == 1
127-
return _mapreduce(f, op, Val(1), nt, sz, a)
140+
return _mapreduce(f, op, Val(1), init, sz, a)
128141
elseif D == 2
129-
return _mapreduce(f, op, Val(2), nt, sz, a)
142+
return _mapreduce(f, op, Val(2), init, sz, a)
130143
elseif D == 3
131-
return _mapreduce(f, op, Val(3), nt, sz, a)
144+
return _mapreduce(f, op, Val(3), init, sz, a)
132145
else
133-
return _mapreduce(f, op, Val(D), nt, sz, a)
146+
return _mapreduce(f, op, Val(D), init, sz, a)
134147
end
135148
end
136149

137-
@generated function _mapreduce(f, op, dims::Val{D}, nt::NamedTuple{()},
150+
@generated function _mapfoldl(f, op, dims::Val{D}, init,
138151
::Size{S}, a::StaticArray) where {S,D}
139152
N = length(S)
140153
Snew = ([n==D ? 1 : S[n] for n = 1:N]...,)
@@ -143,32 +156,12 @@ end
143156
itr = [1:n for n Snew]
144157
for i Base.product(itr...)
145158
expr = :(f(a[$(i...)]))
146-
for k = 2:S[D]
147-
ik = collect(i)
148-
ik[D] = k
149-
expr = :(op($expr, f(a[$(ik...)])))
159+
if init === _InitialValue
160+
expr = :(Base.reduce_first(op, $expr))
161+
else
162+
expr = :(op(init, $expr))
150163
end
151-
152-
exprs[i...] = expr
153-
end
154-
155-
return quote
156-
@_inline_meta
157-
@inbounds elements = tuple($(exprs...))
158-
@inbounds return similar_type(a, eltype(elements), Size($Snew))(elements)
159-
end
160-
end
161-
162-
@generated function _mapreduce(f, op, dims::Val{D}, nt::NamedTuple{(:init,)},
163-
::Size{S}, a::StaticArray) where {S,D}
164-
N = length(S)
165-
Snew = ([n==D ? 1 : S[n] for n = 1:N]...,)
166-
167-
exprs = Array{Expr}(undef, Snew)
168-
itr = [1:n for n = Snew]
169-
for i Base.product(itr...)
170-
expr = :(nt.init)
171-
for k = 1:S[D]
164+
for k = 2:S[D]
172165
ik = collect(i)
173166
ik[D] = k
174167
expr = :(op($expr, f(a[$(ik...)])))
@@ -188,20 +181,37 @@ end
188181
## reduce ##
189182
############
190183

191-
@inline reduce(op, a::StaticArray; dims=:, kw...) = _reduce(op, a, dims, kw.data)
184+
@inline reduce(op, a::StaticArray; dims = :, init = _InitialValue()) =
185+
_reduce(op, a, dims, init)
192186

193187
# disambiguation
194188
reduce(::typeof(vcat), A::StaticArray{<:Tuple,<:AbstractVecOrMat}) =
195189
Base._typed_vcat(mapreduce(eltype, promote_type, A), A)
196190
reduce(::typeof(vcat), A::StaticArray{<:Tuple,<:StaticVecOrMatLike}) =
197-
_reduce(vcat, A, :, NamedTuple())
191+
_reduce(vcat, A, :, _InitialValue())
198192

199193
reduce(::typeof(hcat), A::StaticArray{<:Tuple,<:AbstractVecOrMat}) =
200194
Base._typed_hcat(mapreduce(eltype, promote_type, A), A)
201195
reduce(::typeof(hcat), A::StaticArray{<:Tuple,<:StaticVecOrMatLike}) =
202-
_reduce(hcat, A, :, NamedTuple())
196+
_reduce(hcat, A, :, _InitialValue())
197+
198+
@inline _reduce(op, a::StaticArray, dims, init = _InitialValue()) =
199+
_mapreduce(identity, op, dims, init, Size(a), a)
203200

204-
@inline _reduce(op, a::StaticArray, dims, kw::NamedTuple=NamedTuple()) = _mapreduce(identity, op, dims, kw, Size(a), a)
201+
################
202+
## (map)foldl ##
203+
################
204+
205+
# Using `where {R}` to force specialization. See:
206+
# https://docs.julialang.org/en/v1.5-dev/manual/performance-tips/#Be-aware-of-when-Julia-avoids-specializing-1
207+
# https://github.com/JuliaLang/julia/pull/33917
208+
209+
@inline mapfoldl(f::F, op::R, a::StaticArray; init = _InitialValue()) where {F,R} =
210+
_mapfoldl(f, op, :, init, Size(a), a)
211+
@inline foldl(op::R, a::StaticArray; init = _InitialValue()) where {R} =
212+
_foldl(op, a, :, init)
213+
@inline _foldl(op::R, a, dims, init = _InitialValue()) where {R} =
214+
_mapfoldl(identity, op, dims, init, Size(a), a)
205215

206216
#######################
207217
## related functions ##
@@ -227,37 +237,37 @@ reduce(::typeof(hcat), A::StaticArray{<:Tuple,<:StaticVecOrMatLike}) =
227237
@inline iszero(a::StaticArray{<:Tuple,T}) where {T} = reduce((x,y) -> x && iszero(y), a, init=true)
228238

229239
@inline sum(a::StaticArray{<:Tuple,T}; dims=:) where {T} = _reduce(+, a, dims)
230-
@inline sum(f, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, +, dims, NamedTuple(), Size(a), a)
231-
@inline sum(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, +, dims, NamedTuple(), Size(a), a) # avoid ambiguity
240+
@inline sum(f, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, +, dims, _InitialValue(), Size(a), a)
241+
@inline sum(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, +, dims, _InitialValue(), Size(a), a) # avoid ambiguity
232242

233243
@inline prod(a::StaticArray{<:Tuple,T}; dims=:) where {T} = _reduce(*, a, dims)
234-
@inline prod(f, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, *, dims, NamedTuple(), Size(a), a)
235-
@inline prod(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, *, dims, NamedTuple(), Size(a), a)
244+
@inline prod(f, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, *, dims, _InitialValue(), Size(a), a)
245+
@inline prod(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, *, dims, _InitialValue(), Size(a), a)
236246

237247
@inline count(a::StaticArray{<:Tuple,Bool}; dims=:) = _reduce(+, a, dims)
238-
@inline count(f, a::StaticArray; dims=:) = _mapreduce(x->f(x)::Bool, +, dims, NamedTuple(), Size(a), a)
248+
@inline count(f, a::StaticArray; dims=:) = _mapreduce(x->f(x)::Bool, +, dims, _InitialValue(), Size(a), a)
239249

240-
@inline all(a::StaticArray{<:Tuple,Bool}; dims=:) = _reduce(&, a, dims, (init=true,)) # non-branching versions
241-
@inline all(f::Function, a::StaticArray; dims=:) = _mapreduce(x->f(x)::Bool, &, dims, (init=true,), Size(a), a)
250+
@inline all(a::StaticArray{<:Tuple,Bool}; dims=:) = _reduce(&, a, dims, true) # non-branching versions
251+
@inline all(f::Function, a::StaticArray; dims=:) = _mapreduce(x->f(x)::Bool, &, dims, true, Size(a), a)
242252

243-
@inline any(a::StaticArray{<:Tuple,Bool}; dims=:) = _reduce(|, a, dims, (init=false,)) # (benchmarking needed)
244-
@inline any(f::Function, a::StaticArray; dims=:) = _mapreduce(x->f(x)::Bool, |, dims, (init=false,), Size(a), a) # (benchmarking needed)
253+
@inline any(a::StaticArray{<:Tuple,Bool}; dims=:) = _reduce(|, a, dims, false) # (benchmarking needed)
254+
@inline any(f::Function, a::StaticArray; dims=:) = _mapreduce(x->f(x)::Bool, |, dims, false, Size(a), a) # (benchmarking needed)
245255

246-
@inline Base.in(x, a::StaticArray) = _mapreduce(==(x), |, :, (init=false,), Size(a), a)
256+
@inline Base.in(x, a::StaticArray) = _mapreduce(==(x), |, :, false, Size(a), a)
247257

248258
_mean_denom(a, dims::Colon) = length(a)
249259
_mean_denom(a, dims::Int) = size(a, dims)
250260
_mean_denom(a, ::Val{D}) where {D} = size(a, D)
251261
_mean_denom(a, ::Type{Val{D}}) where {D} = size(a, D)
252262

253263
@inline mean(a::StaticArray; dims=:) = _reduce(+, a, dims) / _mean_denom(a, dims)
254-
@inline mean(f::Function, a::StaticArray; dims=:) = _mapreduce(f, +, dims, NamedTuple(), Size(a), a) / _mean_denom(a, dims)
264+
@inline mean(f::Function, a::StaticArray; dims=:) = _mapreduce(f, +, dims, _InitialValue(), Size(a), a) / _mean_denom(a, dims)
255265

256266
@inline minimum(a::StaticArray; dims=:) = _reduce(min, a, dims) # base has mapreduce(idenity, scalarmin, a)
257-
@inline minimum(f::Function, a::StaticArray; dims=:) = _mapreduce(f, min, dims, NamedTuple(), Size(a), a)
267+
@inline minimum(f::Function, a::StaticArray; dims=:) = _mapreduce(f, min, dims, _InitialValue(), Size(a), a)
258268

259269
@inline maximum(a::StaticArray; dims=:) = _reduce(max, a, dims) # base has mapreduce(idenity, scalarmax, a)
260-
@inline maximum(f::Function, a::StaticArray; dims=:) = _mapreduce(f, max, dims, NamedTuple(), Size(a), a)
270+
@inline maximum(f::Function, a::StaticArray; dims=:) = _mapreduce(f, max, dims, _InitialValue(), Size(a), a)
261271

262272
# Diff is slightly different
263273
@inline diff(a::StaticArray; dims) = _diff(Size(a), a, dims)
@@ -286,8 +296,6 @@ end
286296
end
287297
end
288298

289-
struct _InitialValue end
290-
291299
_maybe_val(dims::Integer) = Val(Int(dims))
292300
_maybe_val(dims) = dims
293301
_valof(::Val{D}) where D = D
@@ -299,19 +307,18 @@ _valof(::Val{D}) where D = D
299307
_accumulate(op, a, _maybe_val(dims), init)
300308

301309
@inline function _accumulate(op::F, a::StaticArray, dims::Union{Val,Colon}, init) where {F}
302-
# Adjoin the initial value to `op`:
310+
# Adjoin the initial value to `op` (one-line version of `Base.BottomRF`):
303311
rf(x, y) = x isa _InitialValue ? Base.reduce_first(op, y) : op(x, y)
304312

305313
if isempty(a)
306314
T = return_type(rf, Tuple{typeof(init), eltype(a)})
307315
return similar_type(a, T)()
308316
end
309317

310-
# StaticArrays' `reduce` is `foldl`:
311-
results = _reduce(
318+
results = _foldl(
312319
a,
313320
dims,
314-
(init = (similar_type(a, Union{}, Size(0))(), init),),
321+
(similar_type(a, Union{}, Size(0))(), init),
315322
) do (ys, acc), x
316323
y = rf(acc, x)
317324
# Not using `push(ys, y)` here since we need to widen element type as

test/mapreduce.jl

+9
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,15 @@ using Statistics: mean
4949
@test mapreduce(x->x^2, max, sa; dims=2, init=-1.) == SMatrix{I,1}(mapreduce(x->x^2, max, a, dims=2, init=-1.))
5050
end
5151

52+
@testset "[map]foldl" begin
53+
a = rand(4,3)
54+
v1 = [2,4,6,8]; sv1 = SVector{4}(v1)
55+
@test foldl(+, sv1) === foldl(+, v1)
56+
@test foldl(+, sv1; init=0) === foldl(+, v1; init=0)
57+
@test mapfoldl(-, +, sv1) === mapfoldl(-, +, v1)
58+
@test mapfoldl(-, +, sv1; init=0) === mapfoldl(-, +, v1, init=0)
59+
end
60+
5261
@testset "implemented by [map]reduce and [map]reducedim" begin
5362
I, J, K = 2, 2, 2
5463
OSArray = SArray{Tuple{I,J,K}} # original

0 commit comments

Comments
 (0)