diff --git a/Project.toml b/Project.toml index c70258691..7700e37cf 100644 --- a/Project.toml +++ b/Project.toml @@ -1,10 +1,11 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.9.16" +version = "0.9.17" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [compat] BenchmarkTools = "0.5" diff --git a/docs/Manifest.toml b/docs/Manifest.toml index c9cb646e3..686c556fc 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -10,10 +10,10 @@ uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232" version = "0.5.10" [[ChainRulesCore]] -deps = ["LinearAlgebra", "MuladdMacro"] +deps = ["LinearAlgebra", "MuladdMacro", "SparseArrays"] path = ".." uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.9.12" +version = "0.9.17" [[Dates]] deps = ["Printf"] @@ -81,10 +81,10 @@ uuid = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221" version = "0.2.2" [[Parsers]] -deps = ["Dates", "Test"] -git-tree-sha1 = "8077624b3c450b15c087944363606a6ba12f925e" +deps = ["Dates"] +git-tree-sha1 = "6fa4202675c05ba0f8268a6ddf07606350eda3ce" uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "1.0.10" +version = "1.0.11" [[Pkg]] deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] @@ -117,6 +117,10 @@ uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" [[Sockets]] uuid = "6462fe0b-24de-5631-8697-dd941f90decc" +[[SparseArrays]] +deps = ["LinearAlgebra", "Random"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + [[Test]] deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/docs/src/api.md b/docs/src/api.md index 3a4d03a64..3a698efa1 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -28,10 +28,9 @@ Private = false ``` ## Accumulation -```@autodocs -Modules = [ChainRulesCore] -Pages = ["accumulation.jl"] -Private = false +```@docs +add!! 
+ChainRulesCore.is_inplaceable_destination ``` ## Ruleset Loading diff --git a/docs/src/debug_mode.md b/docs/src/debug_mode.md index a6486c9d3..64a50ad43 100644 --- a/docs/src/debug_mode.md +++ b/docs/src/debug_mode.md @@ -11,3 +11,8 @@ To enable, redefine the [`ChainRulesCore.debug_mode`](@ref) function to return ` ```julia ChainRulesCore.debug_mode() = true ``` + +## Features of Debug Mode: + + - If you add a `Composite` to a primal value, and it was unable to construct a new primal value, then a better error message will be displayed detailing what overloads need to be written to fix this. + - During [`add!!`](@ref), if an `InplaceableThunk` is used, and it runs the code that is supposed to run in place, but the returned result is not the input (with updated values), then an error is thrown, rather than silently using whatever values were returned. diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 2eeb6739a..400a34540 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -1,6 +1,7 @@ module ChainRulesCore using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, materialize! using LinearAlgebra: LinearAlgebra +using SparseArrays: SparseVector, SparseMatrixCSC using MuladdMacro: @muladd export on_new_rule, refresh_rules # generation tools diff --git a/src/accumulation.jl b/src/accumulation.jl index da5865706..291b43885 100644 --- a/src/accumulation.jl +++ b/src/accumulation.jl @@ -3,13 +3,90 @@ Returns `x+y`, potentially mutating `x` in-place to hold this value. This avoids allocations when `x` can be mutated in this way. - -See also: [`InplaceableThunk`](@ref). """ add!!(x, y) = x + y -add!!(x, t::InplaceableThunk) = t.add!(x) +""" + add!!(x, t::InplaceableThunk) + +The specialization of `add!!` for [`InplaceableThunk`](@ref) promises to only call +`t.add!` on `x` if `x` is suitably mutable; otherwise the addition is done out of place.
+""" +function add!!(x, t::InplaceableThunk) + return if is_inplaceable_destination(x) + if !debug_mode() + t.add!(x) + else + debug_add!(x, t) + end + else + x + t + end +end + +function add!!(x::AbstractArray{<:Any, N}, y::AbstractArray{<:Any, N}) where N + return if is_inplaceable_destination(x) + x .+= y + else + x + y + end +end + + +""" + is_inplaceable_destination(x) -> Bool + +Returns true if `x` is suitable for storing inplace accumulation of gradients. +For arrays this boils down to whether `x .= y` will work to mutate `x`, if `y` is an +appropriate differential. +Wrapper array types do not need to overload this if they overload `Base.parent`, and are +`is_inplaceable_destination` if and only if their parent array is. +Other types should overload this, as it defaults to `false`. +""" +is_inplaceable_destination(::Any) = false +is_inplaceable_destination(::Array) = true +is_inplaceable_destination(::SparseVector) = true +is_inplaceable_destination(::SparseMatrixCSC) = true +is_inplaceable_destination(::BitArray) = true +function is_inplaceable_destination(x::AbstractArray) + p = parent(x) + p === x && return false # no parent + # basically all wrapper types delegate `setindex!` to their `parent` after some + # processing and so are mutable if their `parent` is.
+ return is_inplaceable_destination(p) +end + +# Hermitian and Symmetric are too fussy to deal with right now +# https://github.com/JuliaLang/julia/issues/38056 +# TODO: https://github.com/JuliaDiff/ChainRulesCore.jl/issues/236 +is_inplaceable_destination(::LinearAlgebra.Hermitian) = false +is_inplaceable_destination(::LinearAlgebra.Symmetric) = false + + +function debug_add!(accumuland, t::InplaceableThunk) + returned_value = t.add!(accumuland) + if returned_value !== accumuland + throw(BadInplaceException(t, accumuland, returned_value)) + end + return returned_value +end + +struct BadInplaceException <: Exception + ithunk::InplaceableThunk + accumuland + returned_value +end + +function Base.showerror(io::IO, err::BadInplaceException) + println(io, "`add!!(accumuland, ithunk))` did not return an updated accumuland.") + println(io, "ithunk = $(err.ithunk)") + println(io, "accumuland = $(err.accumuland)") + println(io, "returned_value = $(err.returned_value)") -function add!!(x::Array{<:Any, N}, y::AbstractArray{<:Any, N}) where N - return x .+= y + if err.accumuland == err.returned_value + println( + io, + "Which in this case happenned to be equal. But they are not the same object." 
+ ) + end end diff --git a/test/accumulation.jl b/test/accumulation.jl index e605b2f9d..8b711d523 100644 --- a/test/accumulation.jl +++ b/test/accumulation.jl @@ -1,62 +1,139 @@ @testset "accumulation.jl" begin - @testset "scalar" begin - @test 16 == add!!(12, 4) - end + @testset "is_inplaceable_destination" begin + is_inplaceable_destination = ChainRulesCore.is_inplaceable_destination + + @test is_inplaceable_destination([1, 2, 3, 4]) + @test !is_inplaceable_destination(1:4) + + @test is_inplaceable_destination(Diagonal([1, 2, 3, 4])) + @test !is_inplaceable_destination(Diagonal(1:4)) - @testset "Differentials" begin - @test 16 == add!!(12, @thunk(2*2)) - @test 16 == add!!(16, Zero()) + @test is_inplaceable_destination(view([1, 2, 3, 4], :, :)) + @test !is_inplaceable_destination(view(1:4, :, :)) - @test 16 == add!!(16, DoesNotExist()) # Should this be an error? + @test is_inplaceable_destination(falses(4)) + @test is_inplaceable_destination(spzeros(4)) + @test is_inplaceable_destination(spzeros(2, 2)) + + @test !is_inplaceable_destination(1.3) + @test !is_inplaceable_destination(@SVector [1, 2, 3]) + @test !is_inplaceable_destination(Hermitian([1 2; 2 4])) + @test !is_inplaceable_destination(Symmetric([1 2; 2 4])) end - @testset "Array" begin - @testset "Happy Path" begin - @testset "RHS Array" begin - A = [1.0 2.0; 3.0 4.0] - result = -1.0*ones(2,2) - ret = add!!(result, A) - @test ret === result # must be same object - @test result == [0.0 1.0; 2.0 3.0] + @testset "add!!" begin + @testset "scalar" begin + @test 16 == add!!(12, 4) + end + + @testset "misc AbstractDifferential subtypes" begin + @test 16 == add!!(12, @thunk(2*2)) + @test 16 == add!!(16, Zero()) + + @test 16 == add!!(16, DoesNotExist()) # Should this be an error? 
+ end + + @testset "add!!(::AbstractArray, ::AbstractArray)" begin + @testset "LHS Array (inplace)" begin + @testset "RHS Array" begin + A = [1.0 2.0; 3.0 4.0] + accumuland = -1.0*ones(2,2) + ret = add!!(accumuland, A) + @test ret === accumuland # must be same object + @test accumuland == [0.0 1.0; 2.0 3.0] + end + + @testset "RHS StaticArray" begin + A = @SMatrix[1.0 2.0; 3.0 4.0] + accumuland = -1.0*ones(2,2) + ret = add!!(accumuland, A) + @test ret === accumuland # must be same object + @test accumuland == [0.0 1.0; 2.0 3.0] + end + + @testset "RHS Diagonal" begin + A = Diagonal([1.0, 2.0]) + accumuland = -1.0*ones(2,2) + ret = add!!(accumuland, A) + @test ret === accumuland # must be same object + @test accumuland == [0.0 -1.0; -1.0 1.0] + end end - @testset "RHS StaticArray" begin - A = @SMatrix[1.0 2.0; 3.0 4.0] - result = -1.0*ones(2,2) - ret = add!!(result, A) - @test ret === result # must be same object - @test result == [0.0 1.0; 2.0 3.0] + @testset "add!!(::StaticArray, ::Array) (out of place)" begin + A = [1.0 2.0; 3.0 4.0] + accumuland = @SMatrix [-1.0 -1.0; -1.0 -1.0] + ret = add!!(accumuland, A) + @test ret == [0.0 1.0; 2.0 3.0] # must return right answer + @test ret !== accumuland # must not be same object + @test accumuland == [-1.0 -1.0; -1.0 -1.0] # must not have changed end - @testset "RHS Diagonal" begin + @testset "add!!(::Diagonal{<:Vector}, ::Diagonal{<:Vector}) (inplace)" begin A = Diagonal([1.0, 2.0]) - result = -1.0*ones(2,2) - ret = add!!(result, A) - @test ret === result # must be same object - @test result == [0.0 -1.0; -1.0 1.0] + accumuland = Diagonal([-2.0, -2.0]) + ret = add!!(accumuland, A) + @test ret === accumuland # must be same object + @test accumuland == Diagonal([-1.0, 0.0]) + end + + @testset "Unhappy Path" begin + # wrong length + @test_throws DimensionMismatch add!!(ones(4,4), ones(2,2)) + # wrong shape + @test_throws DimensionMismatch add!!(ones(4,4), ones(16)) + # wrong type (adding scalar to array) + @test_throws 
MethodError add!!(ones(4), 21.0) + end + end + + @testset "InplaceableThunk" begin + ithunk = InplaceableThunk( + @thunk(-1.0*ones(2, 2)), + x -> x .-= ones(2, 2) + ) + + @testset "in place" begin + accumuland = [1.0 2.0; 3.0 4.0] + ret = add!!(accumuland, ithunk) + @test ret == [0.0 1.0; 2.0 3.0] # must return right answer + @test ret === accumuland # must be same object + end + + @testset "out of place" begin + accumuland = @SMatrix [1.0 2.0; 3.0 4.0] + + ret = add!!(accumuland, ithunk) + @test ret == [0.0 1.0; 2.0 3.0] # must return right answer + @test ret !== accumuland # must not be same object + @test accumuland == [1.0 2.0; 3.0 4.0] # must not have mutated end end - @testset "Unhappy Path" begin - # wrong length - @test_throws DimensionMismatch add!!(ones(4,4), ones(2,2)) - # wrong shape - @test_throws DimensionMismatch add!!(ones(4,4), ones(16)) - # wrong type (adding scalar to array) - @test_throws MethodError add!!(ones(4), 21.0) + @testset "not actually inplace but said it was" begin + ithunk = InplaceableThunk( + @thunk(@assert false), # this should never be used in this test + x -> 77*ones(2, 2) # not actually inplace (also wrong) + ) + accumuland = ones(2, 2) + @assert ChainRulesCore.debug_mode() == false + # without debug being enabled should return the result, not error + @test 77*ones(2, 2) == add!!(accumuland, ithunk) + + ChainRulesCore.debug_mode() = true # enable debug mode + # with debug being enabled should error + @test_throws ChainRulesCore.BadInplaceException add!!(accumuland, ithunk) + ChainRulesCore.debug_mode() = false # disable it again end end - @testset "InplaceableThunk" begin - A=[1.0 2.0; 3.0 4.0] - ithunk = InplaceableThunk( - @thunk(A*B), - x -> x.+=A - ) - - result = -1.0*ones(2,2) - ret = add!!(result, ithunk) - @test ret === result # must be same object - @test result == [0.0 1.0; 2.0 3.0] + @testset "showerror BadInplaceException" begin + BadInplaceException = ChainRulesCore.BadInplaceException + ithunk = 
InplaceableThunk(@thunk(@assert false), x̄->nothing) + msg = sprint(showerror, BadInplaceException(ithunk, [22], [23])) + @test occursin("22", msg) + + msg_equal = sprint(showerror, BadInplaceException(ithunk, [22], [22])) + @test occursin("equal", msg_equal) end end diff --git a/test/runtests.jl b/test/runtests.jl index 876188823..eab854b10 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,8 +1,9 @@ using Base.Broadcast: broadcastable using BenchmarkTools using ChainRulesCore -using LinearAlgebra: Diagonal, dot +using LinearAlgebra: Diagonal, dot, Hermitian, Symmetric using StaticArrays +using SparseArrays using Test @testset "ChainRulesCore" begin