-
Notifications
You must be signed in to change notification settings - Fork 63
Make add!! decide if mutable or not #234
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
0c46af9
move add!! test to on testset
oxinabox db9d5a7
some renaming
oxinabox 45e0013
make add!! responsible for deciding if to use InplaceThunk.add!
oxinabox 3e7cf88
document is_inplaceable_destination
oxinabox 353361e
Add debug mode for bad inplace
oxinabox 791f25d
small Improvements based on code review
oxinabox 187eb95
Disable inplace accumulation for Symmertic/Hermitian because JuliaLan…
oxinabox a3bd419
Import Hermitian, Symmetric in tests
oxinabox File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,13 +3,90 @@ | |
|
||
Returns `x+y`, potentially mutating `x` in-place to hold this value. | ||
This avoids allocations when `x` can be mutated in this way. | ||
|
||
See also: [`InplaceableThunk`](@ref). | ||
""" | ||
add!!(x, y) = x + y | ||
|
||
add!!(x, t::InplaceableThunk) = t.add!(x) | ||
""" | ||
add!!(x, t::ImplacableThunk) | ||
|
||
The specialization of `add!!` for [`InplaceableThunk`](@ref) promises to only call | ||
`t.add!` on `x` if `x` is suitably mutable; otherwise it will be out of place. | ||
""" | ||
function add!!(x, t::InplaceableThunk) | ||
return if is_inplaceable_destination(x) | ||
if !debug_mode() | ||
t.add!(x) | ||
else | ||
debug_add!(x, t) | ||
end | ||
else | ||
x + t | ||
end | ||
end | ||
|
||
function add!!(x::AbstractArray{<:Any, N}, y::AbstractArray{<:Any, N}) where N | ||
return if is_inplaceable_destination(x) | ||
x .+= y | ||
else | ||
x + y | ||
end | ||
end | ||
|
||
|
||
""" | ||
is_inplaceable_destination(x) -> Bool | ||
|
||
Returns true if `x` is suitable for for storing inplace accumulation of gradients. | ||
For arrays this boils down `x .= y` if will work to mutate `x`, if `y` is an appropriate | ||
differential. | ||
Wrapper array types do not need to overload this if they overload `Base.parent`, and are | ||
`is_inplaceable_destination` if and only if their parent array is. | ||
Other types should overload this, as it defaults to `false`. | ||
""" | ||
is_inplaceable_destination(::Any) = false | ||
is_inplaceable_destination(::Array) = true | ||
is_inplaceable_destination(::SparseVector) = true | ||
is_inplaceable_destination(::SparseMatrixCSC) = true | ||
is_inplaceable_destination(::BitArray) = true | ||
function is_inplaceable_destination(x::AbstractArray) | ||
p = parent(x) | ||
p === x && return false # no parent | ||
# basically all wrapper types delegate `setindex!` to their `parent` after some | ||
# processing and so are mutable if their `parent` is. | ||
return is_inplaceable_destination(p) | ||
end | ||
|
||
# Hermitian and Symmetric are too fussy to deal with right now | ||
# https://github.com/JuliaLang/julia/issues/38056 | ||
# TODO: https://github.com/JuliaDiff/ChainRulesCore.jl/issues/236 | ||
is_inplaceable_destination(::LinearAlgebra.Hermitian) = false | ||
is_inplaceable_destination(::LinearAlgebra.Symmetric) = false | ||
|
||
|
||
function debug_add!(accumuland, t::InplaceableThunk) | ||
returned_value = t.add!(accumuland) | ||
if returned_value !== accumuland | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Perhaps this is cheap enough to do it outside of |
||
throw(BadInplaceException(t, accumuland, returned_value)) | ||
end | ||
return returned_value | ||
end | ||
oxinabox marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
struct BadInplaceException <: Exception | ||
ithunk::InplaceableThunk | ||
accumuland | ||
returned_value | ||
end | ||
oxinabox marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
function Base.showerror(io::IO, err::BadInplaceException) | ||
println(io, "`add!!(accumuland, ithunk))` did not return an updated accumuland.") | ||
println(io, "ithunk = $(err.ithunk)") | ||
println(io, "accumuland = $(err.accumuland)") | ||
println(io, "returned_value = $(err.returned_value)") | ||
|
||
function add!!(x::Array{<:Any, N}, y::AbstractArray{<:Any, N}) where N | ||
return x .+= y | ||
if err.accumuland == err.returned_value | ||
println( | ||
io, | ||
"Which in this case happenned to be equal. But they are not the same object." | ||
) | ||
end | ||
oxinabox marked this conversation as resolved.
Show resolved
Hide resolved
|
||
end |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,62 +1,139 @@ | ||
@testset "accumulation.jl" begin | ||
@testset "scalar" begin | ||
@test 16 == add!!(12, 4) | ||
end | ||
@testset "is_inplaceable_destination" begin | ||
is_inplaceable_destination = ChainRulesCore.is_inplaceable_destination | ||
|
||
@test is_inplaceable_destination([1, 2, 3, 4]) | ||
@test !is_inplaceable_destination(1:4) | ||
|
||
@test is_inplaceable_destination(Diagonal([1, 2, 3, 4])) | ||
@test !is_inplaceable_destination(Diagonal(1:4)) | ||
|
||
@testset "Differentials" begin | ||
@test 16 == add!!(12, @thunk(2*2)) | ||
@test 16 == add!!(16, Zero()) | ||
@test is_inplaceable_destination(view([1, 2, 3, 4], :, :)) | ||
@test !is_inplaceable_destination(view(1:4, :, :)) | ||
|
||
@test 16 == add!!(16, DoesNotExist()) # Should this be an error? | ||
@test is_inplaceable_destination(falses(4)) | ||
@test is_inplaceable_destination(spzeros(4)) | ||
@test is_inplaceable_destination(spzeros(2, 2)) | ||
|
||
@test !is_inplaceable_destination(1.3) | ||
@test !is_inplaceable_destination(@SVector [1, 2, 3]) | ||
@test !is_inplaceable_destination(Hermitian([1 2; 2 4])) | ||
@test !is_inplaceable_destination(Symmetric([1 2; 2 4])) | ||
end | ||
|
||
@testset "Array" begin | ||
@testset "Happy Path" begin | ||
@testset "RHS Array" begin | ||
A = [1.0 2.0; 3.0 4.0] | ||
result = -1.0*ones(2,2) | ||
ret = add!!(result, A) | ||
@test ret === result # must be same object | ||
@test result == [0.0 1.0; 2.0 3.0] | ||
@testset "add!!" begin | ||
@testset "scalar" begin | ||
@test 16 == add!!(12, 4) | ||
end | ||
|
||
@testset "misc AbstractDifferential subtypes" begin | ||
@test 16 == add!!(12, @thunk(2*2)) | ||
@test 16 == add!!(16, Zero()) | ||
|
||
@test 16 == add!!(16, DoesNotExist()) # Should this be an error? | ||
end | ||
|
||
@testset "add!!(::AbstractArray, ::AbstractArray)" begin | ||
@testset "LHS Array (inplace)" begin | ||
@testset "RHS Array" begin | ||
A = [1.0 2.0; 3.0 4.0] | ||
accumuland = -1.0*ones(2,2) | ||
ret = add!!(accumuland, A) | ||
@test ret === accumuland # must be same object | ||
@test accumuland == [0.0 1.0; 2.0 3.0] | ||
end | ||
|
||
@testset "RHS StaticArray" begin | ||
A = @SMatrix[1.0 2.0; 3.0 4.0] | ||
accumuland = -1.0*ones(2,2) | ||
ret = add!!(accumuland, A) | ||
@test ret === accumuland # must be same object | ||
@test accumuland == [0.0 1.0; 2.0 3.0] | ||
end | ||
|
||
@testset "RHS Diagonal" begin | ||
A = Diagonal([1.0, 2.0]) | ||
accumuland = -1.0*ones(2,2) | ||
ret = add!!(accumuland, A) | ||
@test ret === accumuland # must be same object | ||
@test accumuland == [0.0 -1.0; -1.0 1.0] | ||
end | ||
end | ||
|
||
@testset "RHS StaticArray" begin | ||
A = @SMatrix[1.0 2.0; 3.0 4.0] | ||
result = -1.0*ones(2,2) | ||
ret = add!!(result, A) | ||
@test ret === result # must be same object | ||
@test result == [0.0 1.0; 2.0 3.0] | ||
@testset "add!!(::StaticArray, ::Array) (out of place)" begin | ||
A = [1.0 2.0; 3.0 4.0] | ||
accumuland = @SMatrix [-1.0 -1.0; -1.0 -1.0] | ||
ret = add!!(accumuland, A) | ||
@test ret == [0.0 1.0; 2.0 3.0] # must return right answer | ||
@test ret !== accumuland # must not be same object | ||
@test accumuland == [-1.0 -1.0; -1.0 -1.0] # must not have changed | ||
end | ||
|
||
@testset "RHS Diagonal" begin | ||
@testset "add!!(::Diagonal{<:Vector}, ::Diagonal{<:Vector}) (inplace)" begin | ||
A = Diagonal([1.0, 2.0]) | ||
result = -1.0*ones(2,2) | ||
ret = add!!(result, A) | ||
@test ret === result # must be same object | ||
@test result == [0.0 -1.0; -1.0 1.0] | ||
accumuland = Diagonal([-2.0, -2.0]) | ||
ret = add!!(accumuland, A) | ||
@test ret === accumuland # must be same object | ||
@test accumuland == Diagonal([-1.0, 0.0]) | ||
end | ||
|
||
@testset "Unhappy Path" begin | ||
# wrong length | ||
@test_throws DimensionMismatch add!!(ones(4,4), ones(2,2)) | ||
# wrong shape | ||
@test_throws DimensionMismatch add!!(ones(4,4), ones(16)) | ||
# wrong type (adding scalar to array) | ||
@test_throws MethodError add!!(ones(4), 21.0) | ||
end | ||
end | ||
|
||
@testset "InplaceableThunk" begin | ||
ithunk = InplaceableThunk( | ||
@thunk(-1.0*ones(2, 2)), | ||
x -> x .-= ones(2, 2) | ||
) | ||
|
||
@testset "in place" begin | ||
accumuland = [1.0 2.0; 3.0 4.0] | ||
ret = add!!(accumuland, ithunk) | ||
@test ret == [0.0 1.0; 2.0 3.0] # must return right answer | ||
@test ret === accumuland # must be same object | ||
end | ||
|
||
@testset "out of place" begin | ||
accumuland = @SMatrix [1.0 2.0; 3.0 4.0] | ||
|
||
ret = add!!(accumuland, ithunk) | ||
@test ret == [0.0 1.0; 2.0 3.0] # must return right answer | ||
@test ret !== accumuland # must not be same object | ||
@test accumuland == [1.0 2.0; 3.0 4.0] # must not have mutated | ||
end | ||
end | ||
|
||
@testset "Unhappy Path" begin | ||
# wrong length | ||
@test_throws DimensionMismatch add!!(ones(4,4), ones(2,2)) | ||
# wrong shape | ||
@test_throws DimensionMismatch add!!(ones(4,4), ones(16)) | ||
# wrong type (adding scalar to array) | ||
@test_throws MethodError add!!(ones(4), 21.0) | ||
@testset "not actually inplace but said it was" begin | ||
ithunk = InplaceableThunk( | ||
@thunk(@assert false), # this should never be used in this test | ||
x -> 77*ones(2, 2) # not actually inplace (also wrong) | ||
) | ||
accumuland = ones(2, 2) | ||
@assert ChainRulesCore.debug_mode() == false | ||
# without debug being enabled should return the result, not error | ||
@test 77*ones(2, 2) == add!!(accumuland, ithunk) | ||
|
||
ChainRulesCore.debug_mode() = true # enable debug mode | ||
# with debug being enabled should error | ||
@test_throws ChainRulesCore.BadInplaceException add!!(accumuland, ithunk) | ||
ChainRulesCore.debug_mode() = false # disable it again | ||
end | ||
end | ||
|
||
@testset "InplaceableThunk" begin | ||
A=[1.0 2.0; 3.0 4.0] | ||
ithunk = InplaceableThunk( | ||
@thunk(A*B), | ||
x -> x.+=A | ||
) | ||
|
||
result = -1.0*ones(2,2) | ||
ret = add!!(result, ithunk) | ||
@test ret === result # must be same object | ||
@test result == [0.0 1.0; 2.0 3.0] | ||
@testset "showerror BadInplaceException" begin | ||
BadInplaceException = ChainRulesCore.BadInplaceException | ||
ithunk = InplaceableThunk(@thunk(@assert false), x̄->nothing) | ||
msg = sprint(showerror, BadInplaceException(ithunk, [22], [23])) | ||
@test occursin("22", msg) | ||
|
||
msg_equal = sprint(showerror, BadInplaceException(ithunk, [22], [22])) | ||
@test occursin("equal", msg_equal) | ||
end | ||
end |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this true? this suggest to me it'll print something like
But as far as i can see that's not in the error message, or did i miss it (or misread this comment)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is actually documenting something older.
Its unrelated to the main content of this PR.
Not inplace add but simply
+
Its this message
ChainRulesCore.jl/src/differentials/composite.jl
Lines 269 to 284 in a3e76b1
I just realized when writing this this PR that I there was no list of what Debug Mode did,
so I couldn't add the Inplace behavour to it.