1
+ """
2
+ _InitialValue
3
+
4
+ A singleton type for representing "universal" initial value (identity element).
5
+
6
+ The idea is that, given `op` for `mapfoldl`, virtually, we define an "extended"
7
+ version of it by
8
+
9
+ op′(::_InitialValue, x) = x
10
+ op′(acc, x) = op(acc, x)
11
+
12
+ This is just a conceptually useful model to have in mind and we don't actually
13
+ define `op′` here (yet?). But see `Base.BottomRF` for how it might work in
14
+ action.
15
+
16
+ (It is related to that you can always turn a semigroup without an identity into
17
+ a monoid by "adjoining" an element that acts as the identity.)
18
+ """
19
+ struct _InitialValue end
20
+
1
21
@inline _first (a1, as... ) = a1
2
22
3
23
# ###############
86
106
# # mapreduce ##
87
107
# ##############
88
108
89
- @inline function mapreduce (f, op, a:: StaticArray , b:: StaticArray... ; dims= :,kw ... )
90
- _mapreduce (f, op, dims, kw . data , same_size (a, b... ), a, b... )
109
+ @inline function mapreduce (f, op, a:: StaticArray , b:: StaticArray... ; dims= :, init = _InitialValue () )
110
+ _mapreduce (f, op, dims, init , same_size (a, b... ), a, b... )
91
111
end
92
112
93
- @generated function _mapreduce (f, op, dims:: Colon , nt:: NamedTuple{()} ,
94
- :: Size{S} , a:: StaticArray... ) where {S}
113
+ @inline _mapreduce (args:: Vararg{Any,N} ) where N = _mapfoldl (args... )
114
+
115
+ @generated function _mapfoldl (f, op, dims:: Colon , init, :: Size{S} , a:: StaticArray... ) where {S}
95
116
tmp = [:(a[$ j][1 ]) for j ∈ 1 : length (a)]
96
117
expr = :(f ($ (tmp... )))
97
- for i ∈ 2 : prod (S)
98
- tmp = [:(a[$ j][$ i]) for j ∈ 1 : length (a)]
99
- expr = :(op ($ expr, f ($ (tmp... ))))
100
- end
101
- return quote
102
- @_inline_meta
103
- @inbounds return $ expr
118
+ if init === _InitialValue
119
+ expr = :(Base. reduce_first (op, $ expr))
120
+ else
121
+ expr = :(op (init, $ expr))
104
122
end
105
- end
106
-
107
- @generated function _mapreduce (f, op, dims:: Colon , nt:: NamedTuple{(:init,)} ,
108
- :: Size{S} , a:: StaticArray... ) where {S}
109
- expr = :(nt. init)
110
- for i ∈ 1 : prod (S)
123
+ for i ∈ 2 : prod (S)
111
124
tmp = [:(a[$ j][$ i]) for j ∈ 1 : length (a)]
112
125
expr = :(op ($ expr, f ($ (tmp... ))))
113
126
end
@@ -117,24 +130,24 @@ end
117
130
end
118
131
end
119
132
120
- @inline function _mapreduce (f, op, D:: Int , nt :: NamedTuple , sz:: Size{S} , a:: StaticArray ) where {S}
133
+ @inline function _mapreduce (f, op, D:: Int , init , sz:: Size{S} , a:: StaticArray ) where {S}
121
134
# Body of this function is split because constant propagation (at least
122
135
# as of Julia 1.2) can't always correctly propagate here and
123
136
# as a result the function is not type stable and very slow.
124
137
# This makes it at least fast for three dimensions but people should use
125
138
# for example any(a; dims=Val(1)) instead of any(a; dims=1) anyway.
126
139
if D == 1
127
- return _mapreduce (f, op, Val (1 ), nt , sz, a)
140
+ return _mapreduce (f, op, Val (1 ), init , sz, a)
128
141
elseif D == 2
129
- return _mapreduce (f, op, Val (2 ), nt , sz, a)
142
+ return _mapreduce (f, op, Val (2 ), init , sz, a)
130
143
elseif D == 3
131
- return _mapreduce (f, op, Val (3 ), nt , sz, a)
144
+ return _mapreduce (f, op, Val (3 ), init , sz, a)
132
145
else
133
- return _mapreduce (f, op, Val (D), nt , sz, a)
146
+ return _mapreduce (f, op, Val (D), init , sz, a)
134
147
end
135
148
end
136
149
137
- @generated function _mapreduce (f, op, dims:: Val{D} , nt :: NamedTuple{()} ,
150
+ @generated function _mapfoldl (f, op, dims:: Val{D} , init ,
138
151
:: Size{S} , a:: StaticArray ) where {S,D}
139
152
N = length (S)
140
153
Snew = ([n== D ? 1 : S[n] for n = 1 : N]. .. ,)
@@ -143,32 +156,12 @@ end
143
156
itr = [1 : n for n ∈ Snew]
144
157
for i ∈ Base. product (itr... )
145
158
expr = :(f (a[$ (i... )]))
146
- for k = 2 : S[D]
147
- ik = collect (i )
148
- ik[D] = k
149
- expr = :(op ($ expr, f (a[ $ (ik ... )]) ))
159
+ if init === _InitialValue
160
+ expr = :(Base . reduce_first (op, $ expr) )
161
+ else
162
+ expr = :(op (init, $ expr ))
150
163
end
151
-
152
- exprs[i... ] = expr
153
- end
154
-
155
- return quote
156
- @_inline_meta
157
- @inbounds elements = tuple ($ (exprs... ))
158
- @inbounds return similar_type (a, eltype (elements), Size ($ Snew))(elements)
159
- end
160
- end
161
-
162
- @generated function _mapreduce (f, op, dims:: Val{D} , nt:: NamedTuple{(:init,)} ,
163
- :: Size{S} , a:: StaticArray ) where {S,D}
164
- N = length (S)
165
- Snew = ([n== D ? 1 : S[n] for n = 1 : N]. .. ,)
166
-
167
- exprs = Array {Expr} (undef, Snew)
168
- itr = [1 : n for n = Snew]
169
- for i ∈ Base. product (itr... )
170
- expr = :(nt. init)
171
- for k = 1 : S[D]
164
+ for k = 2 : S[D]
172
165
ik = collect (i)
173
166
ik[D] = k
174
167
expr = :(op ($ expr, f (a[$ (ik... )])))
@@ -188,20 +181,37 @@ end
188
181
# # reduce ##
189
182
# ###########
190
183
191
- @inline reduce (op, a:: StaticArray ; dims= :, kw... ) = _reduce (op, a, dims, kw. data)
184
+ @inline reduce (op, a:: StaticArray ; dims = :, init = _InitialValue ()) =
185
+ _reduce (op, a, dims, init)
192
186
193
187
# disambiguation
194
188
reduce (:: typeof (vcat), A:: StaticArray{<:Tuple,<:AbstractVecOrMat} ) =
195
189
Base. _typed_vcat (mapreduce (eltype, promote_type, A), A)
196
190
reduce (:: typeof (vcat), A:: StaticArray{<:Tuple,<:StaticVecOrMatLike} ) =
197
- _reduce (vcat, A, :, NamedTuple ())
191
+ _reduce (vcat, A, :, _InitialValue ())
198
192
199
193
reduce (:: typeof (hcat), A:: StaticArray{<:Tuple,<:AbstractVecOrMat} ) =
200
194
Base. _typed_hcat (mapreduce (eltype, promote_type, A), A)
201
195
reduce (:: typeof (hcat), A:: StaticArray{<:Tuple,<:StaticVecOrMatLike} ) =
202
- _reduce (hcat, A, :, NamedTuple ())
196
+ _reduce (hcat, A, :, _InitialValue ())
197
+
198
+ @inline _reduce (op, a:: StaticArray , dims, init = _InitialValue ()) =
199
+ _mapreduce (identity, op, dims, init, Size (a), a)
203
200
204
- @inline _reduce (op, a:: StaticArray , dims, kw:: NamedTuple = NamedTuple ()) = _mapreduce (identity, op, dims, kw, Size (a), a)
201
+ # ###############
202
+ # # (map)foldl ##
203
+ # ###############
204
+
205
+ # Using `where {R}` to force specialization. See:
206
+ # https://docs.julialang.org/en/v1.5-dev/manual/performance-tips/#Be-aware-of-when-Julia-avoids-specializing-1
207
+ # https://github.com/JuliaLang/julia/pull/33917
208
+
209
+ @inline mapfoldl (f:: F , op:: R , a:: StaticArray ; init = _InitialValue ()) where {F,R} =
210
+ _mapfoldl (f, op, :, init, Size (a), a)
211
+ @inline foldl (op:: R , a:: StaticArray ; init = _InitialValue ()) where {R} =
212
+ _foldl (op, a, :, init)
213
+ @inline _foldl (op:: R , a, dims, init = _InitialValue ()) where {R} =
214
+ _mapfoldl (identity, op, dims, init, Size (a), a)
205
215
206
216
# ######################
207
217
# # related functions ##
@@ -227,37 +237,37 @@ reduce(::typeof(hcat), A::StaticArray{<:Tuple,<:StaticVecOrMatLike}) =
227
237
@inline iszero (a:: StaticArray{<:Tuple,T} ) where {T} = reduce ((x,y) -> x && iszero (y), a, init= true )
228
238
229
239
@inline sum (a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = _reduce (+ , a, dims)
230
- @inline sum (f, a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = _mapreduce (f, + , dims, NamedTuple (), Size (a), a)
231
- @inline sum (f:: Union{Function, Type} , a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = _mapreduce (f, + , dims, NamedTuple (), Size (a), a) # avoid ambiguity
240
+ @inline sum (f, a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = _mapreduce (f, + , dims, _InitialValue (), Size (a), a)
241
+ @inline sum (f:: Union{Function, Type} , a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = _mapreduce (f, + , dims, _InitialValue (), Size (a), a) # avoid ambiguity
232
242
233
243
@inline prod (a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = _reduce (* , a, dims)
234
- @inline prod (f, a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = _mapreduce (f, * , dims, NamedTuple (), Size (a), a)
235
- @inline prod (f:: Union{Function, Type} , a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = _mapreduce (f, * , dims, NamedTuple (), Size (a), a)
244
+ @inline prod (f, a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = _mapreduce (f, * , dims, _InitialValue (), Size (a), a)
245
+ @inline prod (f:: Union{Function, Type} , a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = _mapreduce (f, * , dims, _InitialValue (), Size (a), a)
236
246
237
247
@inline count (a:: StaticArray{<:Tuple,Bool} ; dims= :) = _reduce (+ , a, dims)
238
- @inline count (f, a:: StaticArray ; dims= :) = _mapreduce (x-> f (x):: Bool , + , dims, NamedTuple (), Size (a), a)
248
+ @inline count (f, a:: StaticArray ; dims= :) = _mapreduce (x-> f (x):: Bool , + , dims, _InitialValue (), Size (a), a)
239
249
240
- @inline all (a:: StaticArray{<:Tuple,Bool} ; dims= :) = _reduce (& , a, dims, (init = true ,) ) # non-branching versions
241
- @inline all (f:: Function , a:: StaticArray ; dims= :) = _mapreduce (x-> f (x):: Bool , & , dims, (init = true ,) , Size (a), a)
250
+ @inline all (a:: StaticArray{<:Tuple,Bool} ; dims= :) = _reduce (& , a, dims, true ) # non-branching versions
251
+ @inline all (f:: Function , a:: StaticArray ; dims= :) = _mapreduce (x-> f (x):: Bool , & , dims, true , Size (a), a)
242
252
243
- @inline any (a:: StaticArray{<:Tuple,Bool} ; dims= :) = _reduce (| , a, dims, (init = false ,) ) # (benchmarking needed)
244
- @inline any (f:: Function , a:: StaticArray ; dims= :) = _mapreduce (x-> f (x):: Bool , | , dims, (init = false ,) , Size (a), a) # (benchmarking needed)
253
+ @inline any (a:: StaticArray{<:Tuple,Bool} ; dims= :) = _reduce (| , a, dims, false ) # (benchmarking needed)
254
+ @inline any (f:: Function , a:: StaticArray ; dims= :) = _mapreduce (x-> f (x):: Bool , | , dims, false , Size (a), a) # (benchmarking needed)
245
255
246
- @inline Base. in (x, a:: StaticArray ) = _mapreduce (== (x), | , :, (init = false ,) , Size (a), a)
256
+ @inline Base. in (x, a:: StaticArray ) = _mapreduce (== (x), | , :, false , Size (a), a)
247
257
248
258
_mean_denom (a, dims:: Colon ) = length (a)
249
259
_mean_denom (a, dims:: Int ) = size (a, dims)
250
260
_mean_denom (a, :: Val{D} ) where {D} = size (a, D)
251
261
_mean_denom (a, :: Type{Val{D}} ) where {D} = size (a, D)
252
262
253
263
@inline mean (a:: StaticArray ; dims= :) = _reduce (+ , a, dims) / _mean_denom (a, dims)
254
- @inline mean (f:: Function , a:: StaticArray ; dims= :) = _mapreduce (f, + , dims, NamedTuple (), Size (a), a) / _mean_denom (a, dims)
264
+ @inline mean (f:: Function , a:: StaticArray ; dims= :) = _mapreduce (f, + , dims, _InitialValue (), Size (a), a) / _mean_denom (a, dims)
255
265
256
266
@inline minimum (a:: StaticArray ; dims= :) = _reduce (min, a, dims) # base has mapreduce(idenity, scalarmin, a)
257
- @inline minimum (f:: Function , a:: StaticArray ; dims= :) = _mapreduce (f, min, dims, NamedTuple (), Size (a), a)
267
+ @inline minimum (f:: Function , a:: StaticArray ; dims= :) = _mapreduce (f, min, dims, _InitialValue (), Size (a), a)
258
268
259
269
@inline maximum (a:: StaticArray ; dims= :) = _reduce (max, a, dims) # base has mapreduce(idenity, scalarmax, a)
260
- @inline maximum (f:: Function , a:: StaticArray ; dims= :) = _mapreduce (f, max, dims, NamedTuple (), Size (a), a)
270
+ @inline maximum (f:: Function , a:: StaticArray ; dims= :) = _mapreduce (f, max, dims, _InitialValue (), Size (a), a)
261
271
262
272
# Diff is slightly different
263
273
@inline diff (a:: StaticArray ; dims) = _diff (Size (a), a, dims)
286
296
end
287
297
end
288
298
289
- struct _InitialValue end
290
-
291
299
_maybe_val (dims:: Integer ) = Val (Int (dims))
292
300
_maybe_val (dims) = dims
293
301
_valof (:: Val{D} ) where D = D
@@ -299,19 +307,18 @@ _valof(::Val{D}) where D = D
299
307
_accumulate (op, a, _maybe_val (dims), init)
300
308
301
309
@inline function _accumulate (op:: F , a:: StaticArray , dims:: Union{Val,Colon} , init) where {F}
302
- # Adjoin the initial value to `op`:
310
+ # Adjoin the initial value to `op` (one-line version of `Base.BottomRF`) :
303
311
rf (x, y) = x isa _InitialValue ? Base. reduce_first (op, y) : op (x, y)
304
312
305
313
if isempty (a)
306
314
T = return_type (rf, Tuple{typeof (init), eltype (a)})
307
315
return similar_type (a, T)()
308
316
end
309
317
310
- # StaticArrays' `reduce` is `foldl`:
311
- results = _reduce (
318
+ results = _foldl (
312
319
a,
313
320
dims,
314
- (init = ( similar_type (a, Union{}, Size (0 ))(), init), ),
321
+ (similar_type (a, Union{}, Size (0 ))(), init),
315
322
) do (ys, acc), x
316
323
y = rf (acc, x)
317
324
# Not using `push(ys, y)` here since we need to widen element type as
0 commit comments