Skip to content

Commit 6f5cd3f

Browse files
authored
Merge pull request #70 from JuliaDiff/ox/smartadd
Fix tests now that add!! is too smart for them
2 parents d369e3f + 0b2b2b1 commit 6f5cd3f

File tree

2 files changed

+52
-69
lines changed

2 files changed

+52
-69
lines changed

test/check_result.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,7 @@
66
check(11.0, Zero())
77
check([10.0, 20.0], @thunk([2.0, 0.0]))
88

9-
# These `InplaceableThunk`s aren't actually inplace, but that's ok.
10-
check(12.0, InplaceableThunk(@thunk(2.0), X̄ ->+ 2.0))
11-
12-
@test fails(()->check(12.0, InplaceableThunk(@thunk(2.0), X̄ ->+ 3.0)))
9+
check(12.0, InplaceableThunk(@thunk(2.0), X̄ -> error("Should not have in-placed")))
1310

1411
check([10.0, 20.0], InplaceableThunk(
1512
@thunk([2.0, 0.0]),

test/testers.jl

Lines changed: 51 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@ primalapprox(x) = x
2323
end
2424

2525
@testset "unary: identity(x)" begin
26-
function ChainRulesCore.frule((_, ), ::typeof(identity), x)
27-
return x,
26+
function ChainRulesCore.frule((_, ), ::typeof(identity), x)
27+
return x,
2828
end
2929
function ChainRulesCore.rrule(::typeof(identity), x)
30-
function identity_pullback()
31-
return (NO_FIELDS, )
30+
function identity_pullback(ȳ)
31+
return (NO_FIELDS, ȳ)
3232
end
3333
return x, identity_pullback
3434
end
@@ -42,62 +42,48 @@ primalapprox(x) = x
4242
end
4343
end
4444

45-
@testset "Inplace accumumulation: first on Array" begin
45+
@testset "Inplace accumulation: identity on Array" begin
4646
@testset "Correct definitions" begin
47-
function ChainRulesCore.frule((_, ẋ), ::typeof(first), x::Array)
48-
= InplaceableThunk(
49-
@thunk(first(ẋ)),
50-
->+ first(ẋ), # This won't actually happen inplace
51-
)
52-
return first(x), ẏ
47+
local inplace_used
48+
function ChainRulesCore.frule((_, ẋ), ::typeof(identity), x::Array)
49+
= InplaceableThunk(@thunk(ẋ), ȧ -> (inplace_used=true; ȧ .+= ẋ))
50+
return identity(x), ẏ
5351
end
54-
function ChainRulesCore.rrule(::typeof(first), x::Array{T}) where T
55-
x_dims = size(x)
56-
function first_pullback(ȳ)
57-
x̄_ret = InplaceableThunk(
58-
Thunk() do
59-
= zeros(T, x_dims)
60-
x̄[1]=
61-
62-
end,
63-
-> (ā[1] += ȳ; ā)
64-
)
52+
function ChainRulesCore.rrule(::typeof(identity), x::Array)
53+
function identity_pullback(ȳ)
54+
x̄_ret = InplaceableThunk(@thunk(ȳ), ā -> (inplace_used=true; ā .+= ȳ))
6555
return (NO_FIELDS, x̄_ret)
6656
end
67-
return first(x), first_pullback
57+
return identity(x), identity_pullback
6858
end
6959

70-
frule_test(first, (randn(4), randn(4)))
71-
rrule_test(first, randn(), (randn(4), randn(4)))
60+
inplace_used = false
61+
frule_test(identity, (randn(4), randn(4)))
62+
@test inplace_used # make sure we are using, and thus testing the add!
63+
64+
inplace_used = false
65+
rrule_test(identity, randn(4), (randn(4), randn(4)))
66+
@test inplace_used # make sure we are using, and thus testing the add!
7267
end
7368

74-
@testset "Incorrect inplace definitions" begin
75-
my_first(value) = first(value) # we are going to define bad rules on this
76-
function ChainRulesCore.frule((_, ẋ), ::typeof(my_first), x::Array)
77-
= InplaceableThunk(
78-
@thunk(first(ẋ)), # correct
79-
->+ 1000*first(ẋ), # incorrect (also not actually inplace)
80-
)
81-
return first(x), ẏ
69+
@testset "Incorrect in-place definitions" begin
70+
my_identity(value) = value # we will define bad rules on this
71+
function ChainRulesCore.frule((_, ẋ), ::typeof(my_identity), x::Array)
72+
# only the in-place part is incorrect
73+
= InplaceableThunk(@thunk(ẋ), ȧ -> ȧ .+= 200 .* ẋ)
74+
return my_identity(x), ẏ
8275
end
83-
function ChainRulesCore.rrule(::typeof(my_first), x::Array{T}) where T
76+
function ChainRulesCore.rrule(::typeof(my_identity), x::Array)
8477
x_dims = size(x)
85-
function my_first_pullback(ȳ)
86-
x̄_ret = InplaceableThunk(
87-
Thunk() do # correct
88-
= zeros(T, x_dims)
89-
x̄[1]=
90-
91-
end,
92-
-> (ā[1] += 1000*ȳ; ā) # incorrect
93-
)
78+
function my_identity_pullback(ȳ)
79+
# only the in-place part is incorrect
80+
x̄_ret = InplaceableThunk(@thunk(ȳ), ā -> ā .+= 200 .* ȳ)
9481
return (NO_FIELDS, x̄_ret)
9582
end
96-
return first(x), my_first_pullback
83+
return my_identity(x), my_identity_pullback
9784
end
98-
99-
@test fails(()->frule_test(my_first, (randn(4), randn(4))))
100-
@test fails(()->rrule_test(my_first, randn(), (randn(4), randn(4))))
85+
@test fails(()->frule_test(my_identity, (randn(4), randn(4))))
86+
@test fails(()->rrule_test(my_identity, randn(4), (randn(4), randn(4))))
10187
end
10288
end
10389

@@ -141,9 +127,9 @@ primalapprox(x) = x
141127
simo_pullback((a, b)) = (NO_FIELDS, a .+ 2 .* b)
142128
return simo(x), simo_pullback
143129
end
144-
function ChainRulesCore.frule((_, ), simo, x)
130+
function ChainRulesCore.frule((_, ), simo, x)
145131
y = simo(x)
146-
return y, Composite{typeof(y)}(ẋ, 2)
132+
return y, Composite{typeof(y)}(ẋ, 2)
147133
end
148134

149135
@testset "frule_test" begin
@@ -198,8 +184,8 @@ primalapprox(x) = x
198184
end
199185

200186
@testset "unary with kwargs: futestkws(x; err)" begin
201-
function ChainRulesCore.frule((_, ), ::typeof(futestkws), x; err = true)
202-
return futestkws(x; err = err),
187+
function ChainRulesCore.frule((_, ), ::typeof(futestkws), x; err = true)
188+
return futestkws(x; err = err),
203189
end
204190
function ChainRulesCore.rrule(::typeof(futestkws), x; err = true)
205191
function futestkws_pullback(Δx)
@@ -232,8 +218,8 @@ primalapprox(x) = x
232218
end
233219

234220
@testset "binary with kwargs: fbtestkws(x, y; err)" begin
235-
function ChainRulesCore.frule((_, , _), ::typeof(fbtestkws), x, y; err = true)
236-
return fbtestkws(x, y; err = err),
221+
function ChainRulesCore.frule((_, , _), ::typeof(fbtestkws), x, y; err = true)
222+
return fbtestkws(x, y; err = err),
237223
end
238224
function ChainRulesCore.rrule(::typeof(fbtestkws), x, y; err = true)
239225
function fbtestkws_pullback(Δx)
@@ -323,26 +309,26 @@ primalapprox(x) = x
323309
return iterfun(iter), iterfun_pullback
324310
end
325311

326-
# This needs to be in a seperate testet to stop the `x` being shared with `iterfun`
312+
# This needs to be in a separate testet to stop the `x` being shared with `iterfun`
327313
@testset "Testing iterator function" begin
328314
x = TestIterator(randn(2, 3), Base.SizeUnknown(), Base.EltypeUnknown())
329-
= TestIterator(randn(2, 3), Base.SizeUnknown(), Base.EltypeUnknown())
315+
= TestIterator(randn(2, 3), Base.SizeUnknown(), Base.EltypeUnknown())
330316
= TestIterator(randn(2, 3), Base.SizeUnknown(), Base.EltypeUnknown())
331317

332-
frule_test(iterfun, (x, ))
318+
frule_test(iterfun, (x, ))
333319
rrule_test(iterfun, randn(), (x, x̄))
334320
end
335321
end
336322

337323
@testset "unhappy path" begin
338324
@testset "primal wrong" begin
339325
my_identity1(x) = x
340-
function ChainRulesCore.frule((_, ), ::typeof(my_identity1), x)
341-
return 2.5 * x,
326+
function ChainRulesCore.frule((_, ), ::typeof(my_identity1), x)
327+
return 2.5 * x,
342328
end
343329
function ChainRulesCore.rrule(::typeof(my_identity1), x)
344-
function identity_pullback()
345-
return (NO_FIELDS, )
330+
function identity_pullback(ȳ)
331+
return (NO_FIELDS, ȳ)
346332
end
347333
return 2.5 * x, identity_pullback
348334
end
@@ -351,14 +337,14 @@ primalapprox(x) = x
351337
@test fails(()->rrule_test(my_identity1, 4.1, (2.2, 3.3)))
352338
end
353339

354-
@testset "deriviative wrong" begin
340+
@testset "derivative wrong" begin
355341
my_identity2(x) = x
356-
function ChainRulesCore.frule((_, ), ::typeof(my_identity2), x)
357-
return x, 2.7 *
342+
function ChainRulesCore.frule((_, ), ::typeof(my_identity2), x)
343+
return x, 2.7 *
358344
end
359345
function ChainRulesCore.rrule(::typeof(my_identity2), x)
360-
function identity_pullback()
361-
return (NO_FIELDS, 31.8 * )
346+
function identity_pullback(ȳ)
347+
return (NO_FIELDS, 31.8 * ȳ)
362348
end
363349
return x, identity_pullback
364350
end

0 commit comments

Comments
 (0)