Skip to content

Commit 7904019

Browse files
authored
Delete special case for sum(f, ::Adjoint) (#618)
* fix 530 * bump version
1 parent cc8b9ea commit 7904019

File tree

3 files changed

+7
-22
lines changed

3 files changed

+7
-22
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.31.0"
3+
version = "1.32.0"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/Base/mapreduce.jl

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -51,26 +51,6 @@ end
5151
##### `sum(f, x)`
5252
#####
5353

54-
# Can't map over Adjoint/Transpose Vector
55-
function rrule(
56-
config::RuleConfig{>:HasReverseMode},
57-
::typeof(sum),
58-
f,
59-
xs::Union{Adjoint{<:Number,<:AbstractVector},Transpose{<:Number,<:AbstractVector}};
60-
kwargs...
61-
)
62-
op = xs isa Adjoint ? adjoint : transpose
63-
# since summing a vector we don't need to worry about dims which simplifies adjointing
64-
vector = parent(xs)
65-
y, vector_sum_pb = rrule(config, sum, f, vector; kwargs...)
66-
function covector_sum_pb(ȳ)
67-
s̄um, f̄, v̄ = vector_sum_pb(ȳ)
68-
return s̄um, f̄, op(v̄)
69-
end
70-
71-
return y, covector_sum_pb
72-
end
73-
7454
function rrule(
7555
config::RuleConfig{>:HasReverseMode},
7656
::typeof(sum),
@@ -96,7 +76,8 @@ function rrule(
9676
# see `f.(xs)` but we don't need the pullbacks. Not implemented at present.)
9777

9878
# In the general case, we need to save all the pullbacks:
99-
fx_and_pullbacks = map(xᵢ -> rrule_via_ad(config, f, xᵢ), xs)
79+
# (Here `map` or `broadcast` would fail for adjoint vectors.)
80+
fx_and_pullbacks = [rrule_via_ad(config, f, xᵢ) for xᵢ in xs]
10081
y = sum(first, fx_and_pullbacks; dims)
10182

10283
function sum_pullback_f2(dy)

test/rulesets/Base/mapreduce.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,10 @@ const CFG = ChainRulesTestUtils.ADviaRuleConfig()
9595
test_rrule(sum, inv, x[1, :]')
9696
test_rrule(sum, inv, x[1:1, :]')
9797
test_rrule(sum, inv, transpose(view(x, 1, :)))
98+
# Cases from https://github.com/JuliaDiff/ChainRules.jl/issues/530
99+
test_rrule(sum, log, [1, 2, 3.0]'; fkwargs=(;dims=1))
100+
test_rrule(sum, log, [1, 2, 3.0]'; fkwargs=(;dims=2))
101+
test_rrule(sum, imag, [1+2im, 3+4.0im]')
98102

99103
# Make sure we preserve type for StaticArrays
100104
_, pb = rrule(CFG, sum, abs, @SVector[1.0, -3.0])

0 commit comments

Comments
 (0)