Skip to content

Commit 04a92fc

Browse files
authored
Merge pull request #234 from JuliaDiff/ox/responsibleadd
Make add!! decide if mutable or not
2 parents d981e09 + a3bd419 commit 04a92fc

File tree

8 files changed

+225
-60
lines changed

8 files changed

+225
-60
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "0.9.16"
3+
version = "0.9.17"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
77
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
8+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
89

910
[compat]
1011
BenchmarkTools = "0.5"

docs/Manifest.toml

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@ uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
1010
version = "0.5.10"
1111

1212
[[ChainRulesCore]]
13-
deps = ["LinearAlgebra", "MuladdMacro"]
13+
deps = ["LinearAlgebra", "MuladdMacro", "SparseArrays"]
1414
path = ".."
1515
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
16-
version = "0.9.12"
16+
version = "0.9.17"
1717

1818
[[Dates]]
1919
deps = ["Printf"]
@@ -81,10 +81,10 @@ uuid = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
8181
version = "0.2.2"
8282

8383
[[Parsers]]
84-
deps = ["Dates", "Test"]
85-
git-tree-sha1 = "8077624b3c450b15c087944363606a6ba12f925e"
84+
deps = ["Dates"]
85+
git-tree-sha1 = "6fa4202675c05ba0f8268a6ddf07606350eda3ce"
8686
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
87-
version = "1.0.10"
87+
version = "1.0.11"
8888

8989
[[Pkg]]
9090
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
@@ -117,6 +117,10 @@ uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
117117
[[Sockets]]
118118
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
119119

120+
[[SparseArrays]]
121+
deps = ["LinearAlgebra", "Random"]
122+
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
123+
120124
[[Test]]
121125
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
122126
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

docs/src/api.md

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,9 @@ Private = false
2828
```
2929

3030
## Accumulation
31-
```@autodocs
32-
Modules = [ChainRulesCore]
33-
Pages = ["accumulation.jl"]
34-
Private = false
31+
```@docs
32+
add!!
33+
ChainRulesCore.is_inplaceable_destination
3534
```
3635

3736
## Ruleset Loading

docs/src/debug_mode.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,8 @@ To enable, redefine the [`ChainRulesCore.debug_mode`](@ref) function to return `
1111
```julia
1212
ChainRulesCore.debug_mode() = true
1313
```
14+
15+
## Features of Debug Mode:
16+
17+
- If you add a `Composite` to a primal value, and it was unable to construct a new primal values, then a better error message will be displayed detailing what overloads need to be written to fix this.
18+
- during [`add!!`](@ref), if an `InplaceThunk` is used, and it runs the code that is supposed to run in place, but the return result is not the input (with updated values), then an error is thrown. Rather than silently using what ever values were returned.

src/ChainRulesCore.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module ChainRulesCore
22
using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, materialize!
33
using LinearAlgebra: LinearAlgebra
4+
using SparseArrays: SparseVector, SparseMatrixCSC
45
using MuladdMacro: @muladd
56

67
export on_new_rule, refresh_rules # generation tools

src/accumulation.jl

Lines changed: 82 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,90 @@
33
44
Returns `x+y`, potentially mutating `x` in-place to hold this value.
55
This avoids allocations when `x` can be mutated in this way.
6-
7-
See also: [`InplaceableThunk`](@ref).
86
"""
97
add!!(x, y) = x + y
108

11-
add!!(x, t::InplaceableThunk) = t.add!(x)
9+
"""
10+
add!!(x, t::ImplacableThunk)
11+
12+
The specialization of `add!!` for [`InplaceableThunk`](@ref) promises to only call
13+
`t.add!` on `x` if `x` is suitably mutable; otherwise it will be out of place.
14+
"""
15+
function add!!(x, t::InplaceableThunk)
16+
return if is_inplaceable_destination(x)
17+
if !debug_mode()
18+
t.add!(x)
19+
else
20+
debug_add!(x, t)
21+
end
22+
else
23+
x + t
24+
end
25+
end
26+
27+
function add!!(x::AbstractArray{<:Any, N}, y::AbstractArray{<:Any, N}) where N
28+
return if is_inplaceable_destination(x)
29+
x .+= y
30+
else
31+
x + y
32+
end
33+
end
34+
35+
36+
"""
37+
is_inplaceable_destination(x) -> Bool
38+
39+
Returns true if `x` is suitable for for storing inplace accumulation of gradients.
40+
For arrays this boils down `x .= y` if will work to mutate `x`, if `y` is an appropriate
41+
differential.
42+
Wrapper array types do not need to overload this if they overload `Base.parent`, and are
43+
`is_inplaceable_destination` if and only if their parent array is.
44+
Other types should overload this, as it defaults to `false`.
45+
"""
46+
is_inplaceable_destination(::Any) = false
47+
is_inplaceable_destination(::Array) = true
48+
is_inplaceable_destination(::SparseVector) = true
49+
is_inplaceable_destination(::SparseMatrixCSC) = true
50+
is_inplaceable_destination(::BitArray) = true
51+
function is_inplaceable_destination(x::AbstractArray)
52+
p = parent(x)
53+
p === x && return false # no parent
54+
# basically all wrapper types delegate `setindex!` to their `parent` after some
55+
# processing and so are mutable if their `parent` is.
56+
return is_inplaceable_destination(p)
57+
end
58+
59+
# Hermitian and Symmetric are too fussy to deal with right now
60+
# https://github.com/JuliaLang/julia/issues/38056
61+
# TODO: https://github.com/JuliaDiff/ChainRulesCore.jl/issues/236
62+
is_inplaceable_destination(::LinearAlgebra.Hermitian) = false
63+
is_inplaceable_destination(::LinearAlgebra.Symmetric) = false
64+
65+
66+
function debug_add!(accumuland, t::InplaceableThunk)
67+
returned_value = t.add!(accumuland)
68+
if returned_value !== accumuland
69+
throw(BadInplaceException(t, accumuland, returned_value))
70+
end
71+
return returned_value
72+
end
73+
74+
struct BadInplaceException <: Exception
75+
ithunk::InplaceableThunk
76+
accumuland
77+
returned_value
78+
end
79+
80+
function Base.showerror(io::IO, err::BadInplaceException)
81+
println(io, "`add!!(accumuland, ithunk))` did not return an updated accumuland.")
82+
println(io, "ithunk = $(err.ithunk)")
83+
println(io, "accumuland = $(err.accumuland)")
84+
println(io, "returned_value = $(err.returned_value)")
1285

13-
function add!!(x::Array{<:Any, N}, y::AbstractArray{<:Any, N}) where N
14-
return x .+= y
86+
if err.accumuland == err.returned_value
87+
println(
88+
io,
89+
"Which in this case happenned to be equal. But they are not the same object."
90+
)
91+
end
1592
end

test/accumulation.jl

Lines changed: 121 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,139 @@
11
@testset "accumulation.jl" begin
2-
@testset "scalar" begin
3-
@test 16 == add!!(12, 4)
4-
end
2+
@testset "is_inplaceable_destination" begin
3+
is_inplaceable_destination = ChainRulesCore.is_inplaceable_destination
4+
5+
@test is_inplaceable_destination([1, 2, 3, 4])
6+
@test !is_inplaceable_destination(1:4)
7+
8+
@test is_inplaceable_destination(Diagonal([1, 2, 3, 4]))
9+
@test !is_inplaceable_destination(Diagonal(1:4))
510

6-
@testset "Differentials" begin
7-
@test 16 == add!!(12, @thunk(2*2))
8-
@test 16 == add!!(16, Zero())
11+
@test is_inplaceable_destination(view([1, 2, 3, 4], :, :))
12+
@test !is_inplaceable_destination(view(1:4, :, :))
913

10-
@test 16 == add!!(16, DoesNotExist()) # Should this be an error?
14+
@test is_inplaceable_destination(falses(4))
15+
@test is_inplaceable_destination(spzeros(4))
16+
@test is_inplaceable_destination(spzeros(2, 2))
17+
18+
@test !is_inplaceable_destination(1.3)
19+
@test !is_inplaceable_destination(@SVector [1, 2, 3])
20+
@test !is_inplaceable_destination(Hermitian([1 2; 2 4]))
21+
@test !is_inplaceable_destination(Symmetric([1 2; 2 4]))
1122
end
1223

13-
@testset "Array" begin
14-
@testset "Happy Path" begin
15-
@testset "RHS Array" begin
16-
A = [1.0 2.0; 3.0 4.0]
17-
result = -1.0*ones(2,2)
18-
ret = add!!(result, A)
19-
@test ret === result # must be same object
20-
@test result == [0.0 1.0; 2.0 3.0]
24+
@testset "add!!" begin
25+
@testset "scalar" begin
26+
@test 16 == add!!(12, 4)
27+
end
28+
29+
@testset "misc AbstractDifferential subtypes" begin
30+
@test 16 == add!!(12, @thunk(2*2))
31+
@test 16 == add!!(16, Zero())
32+
33+
@test 16 == add!!(16, DoesNotExist()) # Should this be an error?
34+
end
35+
36+
@testset "add!!(::AbstractArray, ::AbstractArray)" begin
37+
@testset "LHS Array (inplace)" begin
38+
@testset "RHS Array" begin
39+
A = [1.0 2.0; 3.0 4.0]
40+
accumuland = -1.0*ones(2,2)
41+
ret = add!!(accumuland, A)
42+
@test ret === accumuland # must be same object
43+
@test accumuland == [0.0 1.0; 2.0 3.0]
44+
end
45+
46+
@testset "RHS StaticArray" begin
47+
A = @SMatrix[1.0 2.0; 3.0 4.0]
48+
accumuland = -1.0*ones(2,2)
49+
ret = add!!(accumuland, A)
50+
@test ret === accumuland # must be same object
51+
@test accumuland == [0.0 1.0; 2.0 3.0]
52+
end
53+
54+
@testset "RHS Diagonal" begin
55+
A = Diagonal([1.0, 2.0])
56+
accumuland = -1.0*ones(2,2)
57+
ret = add!!(accumuland, A)
58+
@test ret === accumuland # must be same object
59+
@test accumuland == [0.0 -1.0; -1.0 1.0]
60+
end
2161
end
2262

23-
@testset "RHS StaticArray" begin
24-
A = @SMatrix[1.0 2.0; 3.0 4.0]
25-
result = -1.0*ones(2,2)
26-
ret = add!!(result, A)
27-
@test ret === result # must be same object
28-
@test result == [0.0 1.0; 2.0 3.0]
63+
@testset "add!!(::StaticArray, ::Array) (out of place)" begin
64+
A = [1.0 2.0; 3.0 4.0]
65+
accumuland = @SMatrix [-1.0 -1.0; -1.0 -1.0]
66+
ret = add!!(accumuland, A)
67+
@test ret == [0.0 1.0; 2.0 3.0] # must return right answer
68+
@test ret !== accumuland # must not be same object
69+
@test accumuland == [-1.0 -1.0; -1.0 -1.0] # must not have changed
2970
end
3071

31-
@testset "RHS Diagonal" begin
72+
@testset "add!!(::Diagonal{<:Vector}, ::Diagonal{<:Vector}) (inplace)" begin
3273
A = Diagonal([1.0, 2.0])
33-
result = -1.0*ones(2,2)
34-
ret = add!!(result, A)
35-
@test ret === result # must be same object
36-
@test result == [0.0 -1.0; -1.0 1.0]
74+
accumuland = Diagonal([-2.0, -2.0])
75+
ret = add!!(accumuland, A)
76+
@test ret === accumuland # must be same object
77+
@test accumuland == Diagonal([-1.0, 0.0])
78+
end
79+
80+
@testset "Unhappy Path" begin
81+
# wrong length
82+
@test_throws DimensionMismatch add!!(ones(4,4), ones(2,2))
83+
# wrong shape
84+
@test_throws DimensionMismatch add!!(ones(4,4), ones(16))
85+
# wrong type (adding scalar to array)
86+
@test_throws MethodError add!!(ones(4), 21.0)
87+
end
88+
end
89+
90+
@testset "InplaceableThunk" begin
91+
ithunk = InplaceableThunk(
92+
@thunk(-1.0*ones(2, 2)),
93+
x -> x .-= ones(2, 2)
94+
)
95+
96+
@testset "in place" begin
97+
accumuland = [1.0 2.0; 3.0 4.0]
98+
ret = add!!(accumuland, ithunk)
99+
@test ret == [0.0 1.0; 2.0 3.0] # must return right answer
100+
@test ret === accumuland # must be same object
101+
end
102+
103+
@testset "out of place" begin
104+
accumuland = @SMatrix [1.0 2.0; 3.0 4.0]
105+
106+
ret = add!!(accumuland, ithunk)
107+
@test ret == [0.0 1.0; 2.0 3.0] # must return right answer
108+
@test ret !== accumuland # must not be same object
109+
@test accumuland == [1.0 2.0; 3.0 4.0] # must not have mutated
37110
end
38111
end
39112

40-
@testset "Unhappy Path" begin
41-
# wrong length
42-
@test_throws DimensionMismatch add!!(ones(4,4), ones(2,2))
43-
# wrong shape
44-
@test_throws DimensionMismatch add!!(ones(4,4), ones(16))
45-
# wrong type (adding scalar to array)
46-
@test_throws MethodError add!!(ones(4), 21.0)
113+
@testset "not actually inplace but said it was" begin
114+
ithunk = InplaceableThunk(
115+
@thunk(@assert false), # this should never be used in this test
116+
x -> 77*ones(2, 2) # not actually inplace (also wrong)
117+
)
118+
accumuland = ones(2, 2)
119+
@assert ChainRulesCore.debug_mode() == false
120+
# without debug being enabled should return the result, not error
121+
@test 77*ones(2, 2) == add!!(accumuland, ithunk)
122+
123+
ChainRulesCore.debug_mode() = true # enable debug mode
124+
# with debug being enabled should error
125+
@test_throws ChainRulesCore.BadInplaceException add!!(accumuland, ithunk)
126+
ChainRulesCore.debug_mode() = false # disable it again
47127
end
48128
end
49129

50-
@testset "InplaceableThunk" begin
51-
A=[1.0 2.0; 3.0 4.0]
52-
ithunk = InplaceableThunk(
53-
@thunk(A*B),
54-
x -> x.+=A
55-
)
56-
57-
result = -1.0*ones(2,2)
58-
ret = add!!(result, ithunk)
59-
@test ret === result # must be same object
60-
@test result == [0.0 1.0; 2.0 3.0]
130+
@testset "showerror BadInplaceException" begin
131+
BadInplaceException = ChainRulesCore.BadInplaceException
132+
ithunk = InplaceableThunk(@thunk(@assert false), x̄->nothing)
133+
msg = sprint(showerror, BadInplaceException(ithunk, [22], [23]))
134+
@test occursin("22", msg)
135+
136+
msg_equal = sprint(showerror, BadInplaceException(ithunk, [22], [22]))
137+
@test occursin("equal", msg_equal)
61138
end
62139
end

test/runtests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
using Base.Broadcast: broadcastable
22
using BenchmarkTools
33
using ChainRulesCore
4-
using LinearAlgebra: Diagonal, dot
4+
using LinearAlgebra: Diagonal, dot, Hermitian, Symmetric
55
using StaticArrays
6+
using SparseArrays
67
using Test
78

89
@testset "ChainRulesCore" begin

0 commit comments

Comments
 (0)