Skip to content

Commit 28202e3

Browse files
committed
allow dims::Tuple in sum
1 parent 9a9ddac commit 28202e3

File tree

2 files changed

+9
-0
lines changed

2 files changed

+9
-0
lines changed

src/mapreduce.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,12 @@ end
158158
end
159159
end
160160

161+
@inline function _mapreduce(f, op, D::Tuple, init, sz::Size{S}, a::StaticArray) where {S}
162+
b = _mapreduce(f, op, first(D), init, sz, a)
163+
return _mapreduce(f, op, Base.tail(D), init, Size(b), b)
164+
end
165+
_mapreduce(f, op, D::Tuple{}, init, sz::Size{S}, a::StaticArray) where {S} = a
166+
161167
@generated function _mapfoldl(f, op, dims::Val{D}, init,
162168
::Size{S}, a::StaticArray) where {S,D}
163169
N = length(S)

test/mapreduce.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ using Statistics: mean
102102
RSArray1 = SArray{Tuple{1,J,K}} # reduced in dimension 1
103103
RSArray2 = SArray{Tuple{I,1,K}} # reduced in dimension 2
104104
RSArray3 = SArray{Tuple{I,J,1}} # reduced in dimension 3
105+
RSArray13 = SArray{Tuple{1,J,1}} # reduced in dimension 1 and 3
105106
a = randn(I,J,K); sa = OSArray(a)
106107
b = rand(Bool,I,J,K); sb = OSArray(b)
107108
z = zeros(I,J,K); sz = OSArray(z)
@@ -111,9 +112,11 @@ using Statistics: mean
111112
@test sum(sa) === sum(a)
112113
@test sum(abs2, sa) === sum(abs2, a)
113114
@test sum(sa, dims=2) === RSArray2(sum(a, dims=2))
115+
@test sum(sa, dims=(2,)) === RSArray2(sum(a, dims=2))
114116
@test sum(sa, dims=Val(2)) === RSArray2(sum(a, dims=2))
115117
@test sum(abs2, sa; dims=2) === RSArray2(sum(abs2, a, dims=2))
116118
@test sum(abs2, sa; dims=Val(2)) === RSArray2(sum(abs2, a, dims=2))
119+
@test_broken sum(abs2, sa; dims=(1,3)) === RSArray13(sum(abs2, a, dims=(1,3)))
117120

118121
@test prod(sa) === prod(a)
119122
@test prod(abs2, sa) === prod(abs2, a)

0 commit comments

Comments
 (0)