Skip to content

Make add!! decide if mutable or not #234

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Oct 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
name = "ChainRulesCore"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.9.16"
version = "0.9.17"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[compat]
BenchmarkTools = "0.5"
Expand Down
14 changes: 9 additions & 5 deletions docs/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
version = "0.5.10"

[[ChainRulesCore]]
deps = ["LinearAlgebra", "MuladdMacro"]
deps = ["LinearAlgebra", "MuladdMacro", "SparseArrays"]
path = ".."
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.9.12"
version = "0.9.17"

[[Dates]]
deps = ["Printf"]
Expand Down Expand Up @@ -81,10 +81,10 @@ uuid = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
version = "0.2.2"

[[Parsers]]
deps = ["Dates", "Test"]
git-tree-sha1 = "8077624b3c450b15c087944363606a6ba12f925e"
deps = ["Dates"]
git-tree-sha1 = "6fa4202675c05ba0f8268a6ddf07606350eda3ce"
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
version = "1.0.10"
version = "1.0.11"

[[Pkg]]
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
Expand Down Expand Up @@ -117,6 +117,10 @@ uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
[[Sockets]]
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"

[[SparseArrays]]
deps = ["LinearAlgebra", "Random"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[[Test]]
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
7 changes: 3 additions & 4 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,9 @@ Private = false
```

## Accumulation
```@autodocs
Modules = [ChainRulesCore]
Pages = ["accumulation.jl"]
Private = false
```@docs
add!!
ChainRulesCore.is_inplaceable_destination
```

## Ruleset Loading
Expand Down
5 changes: 5 additions & 0 deletions docs/src/debug_mode.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,8 @@ To enable, redefine the [`ChainRulesCore.debug_mode`](@ref) function to return `
```julia
ChainRulesCore.debug_mode() = true
```

## Features of Debug Mode:

- If you add a `Composite` to a primal value, and it was unable to construct a new primal values, then a better error message will be displayed detailing what overloads need to be written to fix this.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

detailing what overloads need to be written to fix this

Is this true? this suggest to me it'll print something like

"""
If `Foo` can be updated in place, define `ChainRulesCore.is_inplaceable_destination(::Foo) = true
"""

But as far as i can see that's not in the error message, or did i miss it (or misread this comment)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is actually documenting something older.
Its unrelated to the main content of this PR.
Not inplace add but simply +
Its this message

function Base.showerror(io::IO, err::PrimalAdditionFailedException{P}) where {P}
println(io, "Could not construct $P after addition.")
println(io, "This probably means no default constructor is defined.")
println(io, "Either define a default constructor")
printstyled(io, "$P(", join(propertynames(err.differential), ", "), ")", color=:blue)
println(io, "\nor overload")
printstyled(io,
"ChainRulesCore.construct(::Type{$P}, ::$(typeof(err.differential)))";
color=:blue
)
println(io, "\nor overload")
printstyled(io, "Base.:+(::$P, ::$(typeof(err.differential)))"; color=:blue)
println(io, "\nOriginal Exception:")
printstyled(io, err.original; color=:yellow)
println(io)
end

I just realized when writing this this PR that I there was no list of what Debug Mode did,
so I couldn't add the Inplace behavour to it.

- during [`add!!`](@ref), if an `InplaceThunk` is used, and it runs the code that is supposed to run in place, but the return result is not the input (with updated values), then an error is thrown. Rather than silently using what ever values were returned.
1 change: 1 addition & 0 deletions src/ChainRulesCore.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module ChainRulesCore
using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, materialize!
using LinearAlgebra: LinearAlgebra
using SparseArrays: SparseVector, SparseMatrixCSC
using MuladdMacro: @muladd

export on_new_rule, refresh_rules # generation tools
Expand Down
87 changes: 82 additions & 5 deletions src/accumulation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,90 @@

Returns `x+y`, potentially mutating `x` in-place to hold this value.
This avoids allocations when `x` can be mutated in this way.

See also: [`InplaceableThunk`](@ref).
"""
add!!(x, y) = x + y

add!!(x, t::InplaceableThunk) = t.add!(x)
"""
add!!(x, t::ImplacableThunk)

The specialization of `add!!` for [`InplaceableThunk`](@ref) promises to only call
`t.add!` on `x` if `x` is suitably mutable; otherwise it will be out of place.
"""
function add!!(x, t::InplaceableThunk)
return if is_inplaceable_destination(x)
if !debug_mode()
t.add!(x)
else
debug_add!(x, t)
end
else
x + t
end
end

function add!!(x::AbstractArray{<:Any, N}, y::AbstractArray{<:Any, N}) where N
return if is_inplaceable_destination(x)
x .+= y
else
x + y
end
end


"""
is_inplaceable_destination(x) -> Bool

Returns true if `x` is suitable for for storing inplace accumulation of gradients.
For arrays this boils down `x .= y` if will work to mutate `x`, if `y` is an appropriate
differential.
Wrapper array types do not need to overload this if they overload `Base.parent`, and are
`is_inplaceable_destination` if and only if their parent array is.
Other types should overload this, as it defaults to `false`.
"""
is_inplaceable_destination(::Any) = false
is_inplaceable_destination(::Array) = true
is_inplaceable_destination(::SparseVector) = true
is_inplaceable_destination(::SparseMatrixCSC) = true
is_inplaceable_destination(::BitArray) = true
function is_inplaceable_destination(x::AbstractArray)
p = parent(x)
p === x && return false # no parent
# basically all wrapper types delegate `setindex!` to their `parent` after some
# processing and so are mutable if their `parent` is.
return is_inplaceable_destination(p)
end

# Hermitian and Symmetric are too fussy to deal with right now
# https://github.com/JuliaLang/julia/issues/38056
# TODO: https://github.com/JuliaDiff/ChainRulesCore.jl/issues/236
is_inplaceable_destination(::LinearAlgebra.Hermitian) = false
is_inplaceable_destination(::LinearAlgebra.Symmetric) = false


function debug_add!(accumuland, t::InplaceableThunk)
returned_value = t.add!(accumuland)
if returned_value !== accumuland
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps this is cheap enough to do it outside of debug_mode?
Maybe if we find outselves having to tell people to turn on debug_mode to find errors we should look into that?
Or maybe we should look into it now.
What do you think?

throw(BadInplaceException(t, accumuland, returned_value))
end
return returned_value
end

struct BadInplaceException <: Exception
ithunk::InplaceableThunk
accumuland
returned_value
end

function Base.showerror(io::IO, err::BadInplaceException)
println(io, "`add!!(accumuland, ithunk))` did not return an updated accumuland.")
println(io, "ithunk = $(err.ithunk)")
println(io, "accumuland = $(err.accumuland)")
println(io, "returned_value = $(err.returned_value)")

function add!!(x::Array{<:Any, N}, y::AbstractArray{<:Any, N}) where N
return x .+= y
if err.accumuland == err.returned_value
println(
io,
"Which in this case happenned to be equal. But they are not the same object."
)
end
end
165 changes: 121 additions & 44 deletions test/accumulation.jl
Original file line number Diff line number Diff line change
@@ -1,62 +1,139 @@
@testset "accumulation.jl" begin
@testset "scalar" begin
@test 16 == add!!(12, 4)
end
@testset "is_inplaceable_destination" begin
is_inplaceable_destination = ChainRulesCore.is_inplaceable_destination

@test is_inplaceable_destination([1, 2, 3, 4])
@test !is_inplaceable_destination(1:4)

@test is_inplaceable_destination(Diagonal([1, 2, 3, 4]))
@test !is_inplaceable_destination(Diagonal(1:4))

@testset "Differentials" begin
@test 16 == add!!(12, @thunk(2*2))
@test 16 == add!!(16, Zero())
@test is_inplaceable_destination(view([1, 2, 3, 4], :, :))
@test !is_inplaceable_destination(view(1:4, :, :))

@test 16 == add!!(16, DoesNotExist()) # Should this be an error?
@test is_inplaceable_destination(falses(4))
@test is_inplaceable_destination(spzeros(4))
@test is_inplaceable_destination(spzeros(2, 2))

@test !is_inplaceable_destination(1.3)
@test !is_inplaceable_destination(@SVector [1, 2, 3])
@test !is_inplaceable_destination(Hermitian([1 2; 2 4]))
@test !is_inplaceable_destination(Symmetric([1 2; 2 4]))
end

@testset "Array" begin
@testset "Happy Path" begin
@testset "RHS Array" begin
A = [1.0 2.0; 3.0 4.0]
result = -1.0*ones(2,2)
ret = add!!(result, A)
@test ret === result # must be same object
@test result == [0.0 1.0; 2.0 3.0]
@testset "add!!" begin
@testset "scalar" begin
@test 16 == add!!(12, 4)
end

@testset "misc AbstractDifferential subtypes" begin
@test 16 == add!!(12, @thunk(2*2))
@test 16 == add!!(16, Zero())

@test 16 == add!!(16, DoesNotExist()) # Should this be an error?
end

@testset "add!!(::AbstractArray, ::AbstractArray)" begin
@testset "LHS Array (inplace)" begin
@testset "RHS Array" begin
A = [1.0 2.0; 3.0 4.0]
accumuland = -1.0*ones(2,2)
ret = add!!(accumuland, A)
@test ret === accumuland # must be same object
@test accumuland == [0.0 1.0; 2.0 3.0]
end

@testset "RHS StaticArray" begin
A = @SMatrix[1.0 2.0; 3.0 4.0]
accumuland = -1.0*ones(2,2)
ret = add!!(accumuland, A)
@test ret === accumuland # must be same object
@test accumuland == [0.0 1.0; 2.0 3.0]
end

@testset "RHS Diagonal" begin
A = Diagonal([1.0, 2.0])
accumuland = -1.0*ones(2,2)
ret = add!!(accumuland, A)
@test ret === accumuland # must be same object
@test accumuland == [0.0 -1.0; -1.0 1.0]
end
end

@testset "RHS StaticArray" begin
A = @SMatrix[1.0 2.0; 3.0 4.0]
result = -1.0*ones(2,2)
ret = add!!(result, A)
@test ret === result # must be same object
@test result == [0.0 1.0; 2.0 3.0]
@testset "add!!(::StaticArray, ::Array) (out of place)" begin
A = [1.0 2.0; 3.0 4.0]
accumuland = @SMatrix [-1.0 -1.0; -1.0 -1.0]
ret = add!!(accumuland, A)
@test ret == [0.0 1.0; 2.0 3.0] # must return right answer
@test ret !== accumuland # must not be same object
@test accumuland == [-1.0 -1.0; -1.0 -1.0] # must not have changed
end

@testset "RHS Diagonal" begin
@testset "add!!(::Diagonal{<:Vector}, ::Diagonal{<:Vector}) (inplace)" begin
A = Diagonal([1.0, 2.0])
result = -1.0*ones(2,2)
ret = add!!(result, A)
@test ret === result # must be same object
@test result == [0.0 -1.0; -1.0 1.0]
accumuland = Diagonal([-2.0, -2.0])
ret = add!!(accumuland, A)
@test ret === accumuland # must be same object
@test accumuland == Diagonal([-1.0, 0.0])
end

@testset "Unhappy Path" begin
# wrong length
@test_throws DimensionMismatch add!!(ones(4,4), ones(2,2))
# wrong shape
@test_throws DimensionMismatch add!!(ones(4,4), ones(16))
# wrong type (adding scalar to array)
@test_throws MethodError add!!(ones(4), 21.0)
end
end

@testset "InplaceableThunk" begin
ithunk = InplaceableThunk(
@thunk(-1.0*ones(2, 2)),
x -> x .-= ones(2, 2)
)

@testset "in place" begin
accumuland = [1.0 2.0; 3.0 4.0]
ret = add!!(accumuland, ithunk)
@test ret == [0.0 1.0; 2.0 3.0] # must return right answer
@test ret === accumuland # must be same object
end

@testset "out of place" begin
accumuland = @SMatrix [1.0 2.0; 3.0 4.0]

ret = add!!(accumuland, ithunk)
@test ret == [0.0 1.0; 2.0 3.0] # must return right answer
@test ret !== accumuland # must not be same object
@test accumuland == [1.0 2.0; 3.0 4.0] # must not have mutated
end
end

@testset "Unhappy Path" begin
# wrong length
@test_throws DimensionMismatch add!!(ones(4,4), ones(2,2))
# wrong shape
@test_throws DimensionMismatch add!!(ones(4,4), ones(16))
# wrong type (adding scalar to array)
@test_throws MethodError add!!(ones(4), 21.0)
@testset "not actually inplace but said it was" begin
ithunk = InplaceableThunk(
@thunk(@assert false), # this should never be used in this test
x -> 77*ones(2, 2) # not actually inplace (also wrong)
)
accumuland = ones(2, 2)
@assert ChainRulesCore.debug_mode() == false
# without debug being enabled should return the result, not error
@test 77*ones(2, 2) == add!!(accumuland, ithunk)

ChainRulesCore.debug_mode() = true # enable debug mode
# with debug being enabled should error
@test_throws ChainRulesCore.BadInplaceException add!!(accumuland, ithunk)
ChainRulesCore.debug_mode() = false # disable it again
end
end

@testset "InplaceableThunk" begin
A=[1.0 2.0; 3.0 4.0]
ithunk = InplaceableThunk(
@thunk(A*B),
x -> x.+=A
)

result = -1.0*ones(2,2)
ret = add!!(result, ithunk)
@test ret === result # must be same object
@test result == [0.0 1.0; 2.0 3.0]
@testset "showerror BadInplaceException" begin
BadInplaceException = ChainRulesCore.BadInplaceException
ithunk = InplaceableThunk(@thunk(@assert false), x̄->nothing)
msg = sprint(showerror, BadInplaceException(ithunk, [22], [23]))
@test occursin("22", msg)

msg_equal = sprint(showerror, BadInplaceException(ithunk, [22], [22]))
@test occursin("equal", msg_equal)
end
end
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
using Base.Broadcast: broadcastable
using BenchmarkTools
using ChainRulesCore
using LinearAlgebra: Diagonal, dot
using LinearAlgebra: Diagonal, dot, Hermitian, Symmetric
using StaticArrays
using SparseArrays
using Test

@testset "ChainRulesCore" begin
Expand Down