Skip to content

Commit 7a66328

Browse files
committed
multiple dimensions for reduce only, not mapreduce
1 parent 28202e3 commit 7a66328

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

src/mapreduce.jl

+10-5
Original file line numberDiff line numberDiff line change
@@ -158,11 +158,8 @@ end
158158
end
159159
end
160160

161-
@inline function _mapreduce(f, op, D::Tuple, init, sz::Size{S}, a::StaticArray) where {S}
162-
b = _mapreduce(f, op, first(D), init, sz, a)
163-
return _mapreduce(f, op, Base.tail(D), init, Size(b), b)
164-
end
165-
_mapreduce(f, op, D::Tuple{}, init, sz::Size{S}, a::StaticArray) where {S} = a
161+
@inline _mapreduce(f, op, D::Tuple{<:Any}, init, sz::Size{S}, a::StaticArray) where {S} =
162+
_mapreduce(f, op, first(D), init, sz, a)
166163

167164
@generated function _mapfoldl(f, op, dims::Val{D}, init,
168165
::Size{S}, a::StaticArray) where {S,D}
@@ -215,6 +212,14 @@ reduce(::typeof(hcat), A::StaticArray{<:Tuple,<:StaticVecOrMatLike}) =
215212
@inline _reduce(op, a::StaticArray, dims, init = _InitialValue()) =
216213
_mapreduce(identity, op, dims, init, Size(a), a)
217214

215+
@inline function _reduce(op, a::StaticArray, dims::Tuple, init = _InitialValue())
216+
b = _reduce(op, a, first(dims), init)
217+
return _reduce(op, b, Base.tail(dims))
218+
end
219+
_reduce(op, a::StaticArray, dims::Tuple{}, ::_InitialValue) = a
220+
_reduce(op, a::StaticArray, dims::Tuple{}, init) = op.(init, a)
221+
222+
218223
################
219224
## (map)foldl ##
220225
################

test/mapreduce.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,10 @@ using Statistics: mean
114114
@test sum(sa, dims=2) === RSArray2(sum(a, dims=2))
115115
@test sum(sa, dims=(2,)) === RSArray2(sum(a, dims=2))
116116
@test sum(sa, dims=Val(2)) === RSArray2(sum(a, dims=2))
117+
@test sum(sa, dims=(1,3)) === RSArray13(sum(a, dims=(1,3)))
117118
@test sum(abs2, sa; dims=2) === RSArray2(sum(abs2, a, dims=2))
119+
@test sum(abs2, sa; dims=(2,)) === RSArray2(sum(abs2, a, dims=2))
118120
@test sum(abs2, sa; dims=Val(2)) === RSArray2(sum(abs2, a, dims=2))
119-
@test_broken sum(abs2, sa; dims=(1,3)) === RSArray13(sum(abs2, a, dims=(1,3)))
120121

121122
@test prod(sa) === prod(a)
122123
@test prod(abs2, sa) === prod(abs2, a)

0 commit comments

Comments
 (0)