Skip to content

Commit d246d12

Browse files
authored
Merge pull request #619 from JuliaDiff/revert-615-mean_f_x
Revert "Rule for `mean(f,x)`"
2 parents 7904019 + ac2dd28 commit d246d12

File tree

4 files changed

+12
-72
lines changed

4 files changed

+12
-72
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.32.0"
3+
version = "1.32.1"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/Statistics/statistics.jl

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,31 +6,19 @@ _denom(x, dims::Colon) = length(x)
66
_denom(x, dims::Integer) = size(x, dims)
77
_denom(x, dims) = mapreduce(i->size(x, i), Base.mul_prod, unique(dims), init=1)
88

9-
function rrule(::typeof(mean), x::AbstractArray{<:Union{Real,Complex,AbstractArray}}; dims=:)
10-
y_sum, sum_pullback = rrule(sum, x; dims)
9+
# TODO: We have `mean(f, x; dims)` as of 1.3.0-DEV.36
10+
# https://github.com/JuliaDiff/ChainRules.jl/issues/85
11+
function rrule(::typeof(mean), x::AbstractArray{<:Real}; dims=:)
12+
y_sum, sum_pullback = rrule(sum, x; dims=dims)
1113
n = _denom(x, dims)
1214
function mean_pullback(ȳ)
13-
_, ∂x = sum_pullback(unthunk(ȳ) / n)
15+
_, ∂sum_x = sum_pullback(ȳ)
16+
∂x = unthunk(∂sum_x) / n
1417
return (NoTangent(), ∂x)
1518
end
1619
return y_sum / n, mean_pullback
1720
end
1821

19-
function rrule(
20-
config::RuleConfig{>:HasReverseMode},
21-
::typeof(mean),
22-
f::F,
23-
x::AbstractArray{T};
24-
dims=:,
25-
) where {F, T<:Union{Real,Complex,AbstractArray}}
26-
y_sum, sum_pullback = rrule(config, sum, f, x; dims)
27-
n = _denom(x, dims)
28-
function mean_pullback_f(ȳ)
29-
return sum_pullback(unthunk(ȳ) / n)
30-
end
31-
return y_sum / n, mean_pullback_f
32-
end
33-
3422
#####
3523
##### variance
3624
#####

test/rulesets/Statistics/statistics.jl

Lines changed: 5 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,11 @@
11
@testset "mean" begin
2-
@testset "mean(x)" begin
3-
test_rrule(mean, randn(9))
4-
test_rrule(mean, randn(ComplexF64,2,4))
5-
test_rrule(mean, transpose(rand(3)))
6-
test_rrule(mean, [rand(3) for _ in 1:4]; check_inferred=false)
2+
n = 9
3+
@testset "Basic" begin
4+
test_rrule(mean, randn(n))
75
end
86
@testset "with dims kwargs" begin
9-
test_rrule(mean, randn(9); fkwargs=(;dims=1))
10-
test_rrule(mean, randn(9,4); fkwargs=(;dims=2))
11-
test_rrule(mean, [rand(2) for _ in 1:3, _ in 1:4]; fkwargs=(;dims=2), check_inferred=false)
12-
end
13-
@testset "mean(f, x)" begin
14-
# This shares its implementation with sum(f, x). Similar tests should cover all cases:
15-
test_rrule(mean, abs, [-4.0, 2.0, 2.0])
16-
test_rrule(mean, log, rand(3, 4) .+ 1)
17-
test_rrule(mean, cbrt, randn(5))
18-
test_rrule(mean, Multiplier(2.0), [2.0, 4.0, 8.0]) # defined in test_helpers.jl
19-
test_rrule(mean, Divider(1 + rand()), randn(5))
20-
21-
test_rrule(mean, sum, [[2.0, 4.0], [4.0,1.9]]; check_inferred=false)
22-
23-
test_rrule(mean, log, rand(ComplexF64, 5))
24-
test_rrule(mean, sqrt, rand(ComplexF64, 5))
25-
test_rrule(mean, abs, rand(ComplexF64, 3, 4))
26-
27-
test_rrule(mean, abs, [-2.0 4.0; 5.0 1.9]; fkwargs=(;dims=1))
28-
test_rrule(mean, abs, [-2.0 4.0; 5.0 1.9]; fkwargs=(;dims=2))
29-
test_rrule(mean, sqrt, rand(ComplexF64, 3, 4); fkwargs=(;dims=(1,)))
7+
test_rrule(mean, randn(n); fkwargs=(;dims=1))
8+
test_rrule(mean, randn(n,4); fkwargs=(;dims=2))
309
end
3110
end
3211

test/test_helpers.jl

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -35,28 +35,6 @@ function ChainRulesCore.rrule(m::Multiplier, y, z)
3535
return m(y, z), Multiplier_pullback_3
3636
end
3737

38-
"""
39-
Divider(x)
40-
41-
Stores a fixed `x` and divides by it, then squares the result.
42-
43-
Especially for testing the gradient of higher order functions with respect to `x`.
44-
```
45-
julia> map(Divider(2), [1 2 3 4 10])
46-
1×5 Matrix{Float64}:
47-
0.25 1.0 2.25 4.0 25.0
48-
```
49-
"""
50-
struct Divider{T<:Real}
51-
x::T
52-
end
53-
(d::Divider)(y::Real) = (y / d.x)^2
54-
55-
function ChainRulesCore.rrule(d::Divider, y::Real)
56-
Divider_pullback(dΩ) = (Tangent{typeof(d)}(; x = -2 ** y^2 / d.x^3), 2 ** y / d.x^2)
57-
return d(y), Divider_pullback
58-
end
59-
6038
"""
6139
Counter()
6240
@@ -110,11 +88,6 @@ end
11088
test_rrule(Multiplier(1.0 + 2im), 3.0 + 4im, 5.0 - 6im)
11189
test_rrule(Multiplier(rand(2,3)), rand(3,4), rand(4,5))
11290
end
113-
114-
@testset "Divider" begin
115-
test_rrule(Divider(2.3), 4.5)
116-
test_rrule(Divider(0.2), -3.4)
117-
end
11891

11992
@testset "Counter" begin
12093
c = Counter()

0 commit comments

Comments
 (0)