Skip to content

Fix tests now that add!! is too smart for them #70

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions test/check_result.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,7 @@
check(11.0, Zero())
check([10.0, 20.0], @thunk([2.0, 0.0]))

# These `InplaceableThunk`s aren't actually inplace, but that's ok.
check(12.0, InplaceableThunk(@thunk(2.0), X̄ -> X̄ + 2.0))

@test fails(()->check(12.0, InplaceableThunk(@thunk(2.0), X̄ -> X̄ + 3.0)))
check(12.0, InplaceableThunk(@thunk(2.0), X̄ -> error("Should not have in-placed")))

check([10.0, 20.0], InplaceableThunk(
@thunk([2.0, 0.0]),
Expand Down
116 changes: 51 additions & 65 deletions test/testers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ primalapprox(x) = x
end

@testset "unary: identity(x)" begin
function ChainRulesCore.frule((_, ), ::typeof(identity), x)
return x,
function ChainRulesCore.frule((_, ), ::typeof(identity), x)
return x,
end
function ChainRulesCore.rrule(::typeof(identity), x)
function identity_pullback()
return (NO_FIELDS, )
function identity_pullback(ȳ)
return (NO_FIELDS, ȳ)
end
return x, identity_pullback
end
Expand All @@ -42,62 +42,48 @@ primalapprox(x) = x
end
end

@testset "Inplace accumumulation: first on Array" begin
@testset "Inplace accumulation: identity on Array" begin
@testset "Correct definitions" begin
function ChainRulesCore.frule((_, ẋ), ::typeof(first), x::Array)
ẏ = InplaceableThunk(
@thunk(first(ẋ)),
ȧ -> ȧ + first(ẋ), # This won't actually happen inplace
)
return first(x), ẏ
local inplace_used
function ChainRulesCore.frule((_, ẋ), ::typeof(identity), x::Array)
ẏ = InplaceableThunk(@thunk(ẋ), ȧ -> (inplace_used=true; ȧ .+= ẋ))
return identity(x), ẏ
end
function ChainRulesCore.rrule(::typeof(first), x::Array{T}) where T
x_dims = size(x)
function first_pullback(ȳ)
x̄_ret = InplaceableThunk(
Thunk() do
x̄ = zeros(T, x_dims)
x̄[1]=ȳ
end,
ā -> (ā[1] += ȳ; ā)
)
function ChainRulesCore.rrule(::typeof(identity), x::Array)
function identity_pullback(ȳ)
x̄_ret = InplaceableThunk(@thunk(ȳ), ā -> (inplace_used=true; ā .+= ȳ))
return (NO_FIELDS, x̄_ret)
end
return first(x), first_pullback
return identity(x), identity_pullback
end

frule_test(first, (randn(4), randn(4)))
rrule_test(first, randn(), (randn(4), randn(4)))
inplace_used = false
frule_test(identity, (randn(4), randn(4)))
@test inplace_used # make sure we are using, and thus testing the add!

inplace_used = false
rrule_test(identity, randn(4), (randn(4), randn(4)))
@test inplace_used # make sure we are using, and thus testing the add!
end

@testset "Incorrect inplace definitions" begin
my_first(value) = first(value) # we are going to define bad rules on this
function ChainRulesCore.frule((_, ẋ), ::typeof(my_first), x::Array)
ẏ = InplaceableThunk(
@thunk(first(ẋ)), # correct
ȧ -> ȧ + 1000*first(ẋ), # incorrect (also not actually inplace)
)
return first(x), ẏ
@testset "Incorrect in-place definitions" begin
my_identity(value) = value # we will define bad rules on this
function ChainRulesCore.frule((_, ẋ), ::typeof(my_identity), x::Array)
# only the in-place part is incorrect
ẏ = InplaceableThunk(@thunk(ẋ), ȧ -> ȧ .+= 200 .* ẋ)
return my_identity(x), ẏ
end
function ChainRulesCore.rrule(::typeof(my_first), x::Array{T}) where T
function ChainRulesCore.rrule(::typeof(my_identity), x::Array)
x_dims = size(x)
function my_first_pullback(ȳ)
x̄_ret = InplaceableThunk(
Thunk() do # correct
x̄ = zeros(T, x_dims)
x̄[1]=ȳ
end,
ā -> (ā[1] += 1000*ȳ; ā) # incorrect
)
function my_identity_pullback(ȳ)
# only the in-place part is incorrect
x̄_ret = InplaceableThunk(@thunk(ȳ), ā -> ā .+= 200 .* ȳ)
return (NO_FIELDS, x̄_ret)
end
return first(x), my_first_pullback
return my_identity(x), my_identity_pullback
end

@test fails(()->frule_test(my_first, (randn(4), randn(4))))
@test fails(()->rrule_test(my_first, randn(), (randn(4), randn(4))))
@test fails(()->frule_test(my_identity, (randn(4), randn(4))))
@test fails(()->rrule_test(my_identity, randn(4), (randn(4), randn(4))))
end
end

Expand Down Expand Up @@ -141,9 +127,9 @@ primalapprox(x) = x
simo_pullback((a, b)) = (NO_FIELDS, a .+ 2 .* b)
return simo(x), simo_pullback
end
function ChainRulesCore.frule((_, ), simo, x)
function ChainRulesCore.frule((_, ), simo, x)
y = simo(x)
return y, Composite{typeof(y)}(ẋ, 2ẋ)
return y, Composite{typeof(y)}(ẋ, 2ẋ)
end

@testset "frule_test" begin
Expand Down Expand Up @@ -198,8 +184,8 @@ primalapprox(x) = x
end

@testset "unary with kwargs: futestkws(x; err)" begin
function ChainRulesCore.frule((_, ), ::typeof(futestkws), x; err = true)
return futestkws(x; err = err),
function ChainRulesCore.frule((_, ), ::typeof(futestkws), x; err = true)
return futestkws(x; err = err),
end
function ChainRulesCore.rrule(::typeof(futestkws), x; err = true)
function futestkws_pullback(Δx)
Expand Down Expand Up @@ -232,8 +218,8 @@ primalapprox(x) = x
end

@testset "binary with kwargs: fbtestkws(x, y; err)" begin
function ChainRulesCore.frule((_, , _), ::typeof(fbtestkws), x, y; err = true)
return fbtestkws(x, y; err = err),
function ChainRulesCore.frule((_, , _), ::typeof(fbtestkws), x, y; err = true)
return fbtestkws(x, y; err = err),
end
function ChainRulesCore.rrule(::typeof(fbtestkws), x, y; err = true)
function fbtestkws_pullback(Δx)
Expand Down Expand Up @@ -323,26 +309,26 @@ primalapprox(x) = x
return iterfun(iter), iterfun_pullback
end

# This needs to be in a seperate testet to stop the `x` being shared with `iterfun`
# This needs to be in a separate testet to stop the `x` being shared with `iterfun`
@testset "Testing iterator function" begin
x = TestIterator(randn(2, 3), Base.SizeUnknown(), Base.EltypeUnknown())
= TestIterator(randn(2, 3), Base.SizeUnknown(), Base.EltypeUnknown())
= TestIterator(randn(2, 3), Base.SizeUnknown(), Base.EltypeUnknown())
x̄ = TestIterator(randn(2, 3), Base.SizeUnknown(), Base.EltypeUnknown())

frule_test(iterfun, (x, ))
frule_test(iterfun, (x, ))
rrule_test(iterfun, randn(), (x, x̄))
end
end

@testset "unhappy path" begin
@testset "primal wrong" begin
my_identity1(x) = x
function ChainRulesCore.frule((_, ), ::typeof(my_identity1), x)
return 2.5 * x,
function ChainRulesCore.frule((_, ), ::typeof(my_identity1), x)
return 2.5 * x,
end
function ChainRulesCore.rrule(::typeof(my_identity1), x)
function identity_pullback()
return (NO_FIELDS, )
function identity_pullback(ȳ)
return (NO_FIELDS, ȳ)
end
return 2.5 * x, identity_pullback
end
Expand All @@ -351,14 +337,14 @@ primalapprox(x) = x
@test fails(()->rrule_test(my_identity1, 4.1, (2.2, 3.3)))
end

@testset "deriviative wrong" begin
@testset "derivative wrong" begin
my_identity2(x) = x
function ChainRulesCore.frule((_, ), ::typeof(my_identity2), x)
return x, 2.7 *
function ChainRulesCore.frule((_, ), ::typeof(my_identity2), x)
return x, 2.7 *
end
function ChainRulesCore.rrule(::typeof(my_identity2), x)
function identity_pullback()
return (NO_FIELDS, 31.8 * )
function identity_pullback(ȳ)
return (NO_FIELDS, 31.8 * ȳ)
end
return x, identity_pullback
end
Expand Down