@@ -170,27 +170,26 @@ end
170
170
# # Base interface
171
171
172
172
Base. _accumulate! (op, output:: WrappedMtlArray , input:: WrappedMtlVector , dims:: Nothing , init:: Nothing ) =
173
- scan ! (op, output, input; dims= 1 )
173
+ @inline AK . accumulate ! (op, output, input; dims, init = AK . neutral_element (op, eltype (output)), alg = AK . ScanPrefixes () )
174
174
175
175
Base. _accumulate! (op, output:: WrappedMtlArray , input:: WrappedMtlArray , dims:: Integer , init:: Nothing ) =
176
- scan! (op, output, input; dims= dims)
177
-
176
+ @inline AK. accumulate! (op, output, input; dims, init= AK. neutral_element (op, eltype (output)), alg= AK. ScanPrefixes ())
178
177
Base. _accumulate! (op, output:: WrappedMtlArray , input:: MtlVector , dims:: Nothing , init:: Some ) =
179
- scan ! (op, output, input; dims= 1 , init= init)
178
+ @inline AK . accumulate ! (op, output, input; dims, init= something ( init), alg = AK . ScanPrefixes () )
180
179
181
180
Base. _accumulate! (op, output:: WrappedMtlArray , input:: WrappedMtlArray , dims:: Integer , init:: Some ) =
182
- scan ! (op, output, input; dims= dims , init= init)
181
+ @inline AK . accumulate ! (op, output, input; dims, init= something ( init), alg = AK . ScanPrefixes () )
183
182
184
- Base. accumulate_pairwise! (op, result:: WrappedMtlVector , v:: WrappedMtlVector ) = accumulate! (op, result, v)
183
+ Base. accumulate_pairwise! (op, result:: WrappedMtlVector , v:: WrappedMtlVector ) = @inline AK . accumulate! (op, result, v; init = AK . neutral_element (op, eltype (result)), alg = AK . ScanPrefixes () )
185
184
186
185
# default behavior unless dims are specified by the user
187
186
function Base. accumulate (op, A:: WrappedMtlArray ;
188
187
dims:: Union{Nothing,Integer} = nothing , kw... )
188
+ nt = values (kw)
189
189
if dims === nothing && ! (A isa AbstractVector)
190
190
# This branch takes care of the cases not handled by `_accumulate!`.
191
- return reshape (accumulate (op, A[:]; kw ... ), size (A))
191
+ return reshape (AK . accumulate (op, A[:]; init = ( :init in keys (kw) ? nt . init : AK . neutral_element (op, eltype (A))), alg = AK . ScanPrefixes () ), size (A))
192
192
end
193
- nt = values (kw)
194
193
if isempty (kw)
195
194
out = similar (A, Base. promote_op (op, eltype (A), eltype (A)))
196
195
elseif keys (nt) === (:init ,)
0 commit comments