@@ -23,12 +23,12 @@ primalapprox(x) = x
23
23
end
24
24
25
25
@testset " unary: identity(x)" begin
26
- function ChainRulesCore. frule ((_, ẏ ), :: typeof (identity), x)
27
- return x, ẏ
26
+ function ChainRulesCore. frule ((_, ẏ ), :: typeof (identity), x)
27
+ return x, ẏ
28
28
end
29
29
function ChainRulesCore. rrule (:: typeof (identity), x)
30
- function identity_pullback (ȳ )
31
- return (NO_FIELDS, ȳ )
30
+ function identity_pullback (ȳ )
31
+ return (NO_FIELDS, ȳ )
32
32
end
33
33
return x, identity_pullback
34
34
end
@@ -42,62 +42,48 @@ primalapprox(x) = x
42
42
end
43
43
end
44
44
45
- @testset " Inplace accumumulation: first on Array" begin
45
+ @testset " Inplace accumulation: identity on Array" begin
46
46
@testset " Correct definitions" begin
47
- function ChainRulesCore. frule ((_, ẋ), :: typeof (first), x:: Array )
48
- ẏ = InplaceableThunk (
49
- @thunk (first (ẋ)),
50
- ȧ -> ȧ + first (ẋ), # This won't actually happen inplace
51
- )
52
- return first (x), ẏ
47
+ local inplace_used
48
+ function ChainRulesCore. frule ((_, ẋ), :: typeof (identity), x:: Array )
49
+ ẏ = InplaceableThunk (@thunk (ẋ), ȧ -> (inplace_used= true ; ȧ .+ = ẋ))
50
+ return identity (x), ẏ
53
51
end
54
- function ChainRulesCore. rrule (:: typeof (first), x:: Array{T} ) where T
55
- x_dims = size (x)
56
- function first_pullback (ȳ)
57
- x̄_ret = InplaceableThunk (
58
- Thunk () do
59
- x̄ = zeros (T, x_dims)
60
- x̄[1 ]= ȳ
61
- x̄
62
- end ,
63
- ā -> (ā[1 ] += ȳ; ā)
64
- )
52
+ function ChainRulesCore. rrule (:: typeof (identity), x:: Array )
53
+ function identity_pullback (ȳ)
54
+ x̄_ret = InplaceableThunk (@thunk (ȳ), ā -> (inplace_used= true ; ā .+ = ȳ))
65
55
return (NO_FIELDS, x̄_ret)
66
56
end
67
- return first (x), first_pullback
57
+ return identity (x), identity_pullback
68
58
end
69
59
70
- frule_test (first, (randn (4 ), randn (4 )))
71
- rrule_test (first, randn (), (randn (4 ), randn (4 )))
60
+ inplace_used = false
61
+ frule_test (identity, (randn (4 ), randn (4 )))
62
+ @test inplace_used # make sure we are using, and thus testing the add!
63
+
64
+ inplace_used = false
65
+ rrule_test (identity, randn (4 ), (randn (4 ), randn (4 )))
66
+ @test inplace_used # make sure we are using, and thus testing the add!
72
67
end
73
68
74
- @testset " Incorrect inplace definitions" begin
75
- my_first (value) = first (value) # we are going to define bad rules on this
76
- function ChainRulesCore. frule ((_, ẋ), :: typeof (my_first), x:: Array )
77
- ẏ = InplaceableThunk (
78
- @thunk (first (ẋ)), # correct
79
- ȧ -> ȧ + 1000 * first (ẋ), # incorrect (also not actually inplace)
80
- )
81
- return first (x), ẏ
69
+ @testset " Incorrect in-place definitions" begin
70
+ my_identity (value) = value # we will define bad rules on this
71
+ function ChainRulesCore. frule ((_, ẋ), :: typeof (my_identity), x:: Array )
72
+ # only the in-place part is incorrect
73
+ ẏ = InplaceableThunk (@thunk (ẋ), ȧ -> ȧ .+ = 200 .* ẋ)
74
+ return my_identity (x), ẏ
82
75
end
83
- function ChainRulesCore. rrule (:: typeof (my_first ), x:: Array{T} ) where T
76
+ function ChainRulesCore. rrule (:: typeof (my_identity ), x:: Array )
84
77
x_dims = size (x)
85
- function my_first_pullback (ȳ)
86
- x̄_ret = InplaceableThunk (
87
- Thunk () do # correct
88
- x̄ = zeros (T, x_dims)
89
- x̄[1 ]= ȳ
90
- x̄
91
- end ,
92
- ā -> (ā[1 ] += 1000 * ȳ; ā) # incorrect
93
- )
78
+ function my_identity_pullback (ȳ)
79
+ # only the in-place part is incorrect
80
+ x̄_ret = InplaceableThunk (@thunk (ȳ), ā -> ā .+ = 200 .* ȳ)
94
81
return (NO_FIELDS, x̄_ret)
95
82
end
96
- return first (x), my_first_pullback
83
+ return my_identity (x), my_identity_pullback
97
84
end
98
-
99
- @test fails (()-> frule_test (my_first, (randn (4 ), randn (4 ))))
100
- @test fails (()-> rrule_test (my_first, randn (), (randn (4 ), randn (4 ))))
85
+ @test fails (()-> frule_test (my_identity, (randn (4 ), randn (4 ))))
86
+ @test fails (()-> rrule_test (my_identity, randn (4 ), (randn (4 ), randn (4 ))))
101
87
end
102
88
end
103
89
@@ -141,9 +127,9 @@ primalapprox(x) = x
141
127
simo_pullback ((a, b)) = (NO_FIELDS, a .+ 2 .* b)
142
128
return simo (x), simo_pullback
143
129
end
144
- function ChainRulesCore. frule ((_, ẋ ), simo, x)
130
+ function ChainRulesCore. frule ((_, ẋ ), simo, x)
145
131
y = simo (x)
146
- return y, Composite {typeof(y)} (ẋ, 2 ẋ )
132
+ return y, Composite {typeof(y)} (ẋ, 2 ẋ )
147
133
end
148
134
149
135
@testset " frule_test" begin
@@ -198,8 +184,8 @@ primalapprox(x) = x
198
184
end
199
185
200
186
@testset " unary with kwargs: futestkws(x; err)" begin
201
- function ChainRulesCore. frule ((_, ẋ ), :: typeof (futestkws), x; err = true )
202
- return futestkws (x; err = err), ẋ
187
+ function ChainRulesCore. frule ((_, ẋ ), :: typeof (futestkws), x; err = true )
188
+ return futestkws (x; err = err), ẋ
203
189
end
204
190
function ChainRulesCore. rrule (:: typeof (futestkws), x; err = true )
205
191
function futestkws_pullback (Δx)
@@ -232,8 +218,8 @@ primalapprox(x) = x
232
218
end
233
219
234
220
@testset " binary with kwargs: fbtestkws(x, y; err)" begin
235
- function ChainRulesCore. frule ((_, ẋ , _), :: typeof (fbtestkws), x, y; err = true )
236
- return fbtestkws (x, y; err = err), ẋ
221
+ function ChainRulesCore. frule ((_, ẋ , _), :: typeof (fbtestkws), x, y; err = true )
222
+ return fbtestkws (x, y; err = err), ẋ
237
223
end
238
224
function ChainRulesCore. rrule (:: typeof (fbtestkws), x, y; err = true )
239
225
function fbtestkws_pullback (Δx)
@@ -323,26 +309,26 @@ primalapprox(x) = x
323
309
return iterfun (iter), iterfun_pullback
324
310
end
325
311
326
- # This needs to be in a seperate testet to stop the `x` being shared with `iterfun`
312
+ # This needs to be in a separate testet to stop the `x` being shared with `iterfun`
327
313
@testset " Testing iterator function" begin
328
314
x = TestIterator (randn (2 , 3 ), Base. SizeUnknown (), Base. EltypeUnknown ())
329
- ẋ = TestIterator (randn (2 , 3 ), Base. SizeUnknown (), Base. EltypeUnknown ())
315
+ ẋ = TestIterator (randn (2 , 3 ), Base. SizeUnknown (), Base. EltypeUnknown ())
330
316
x̄ = TestIterator (randn (2 , 3 ), Base. SizeUnknown (), Base. EltypeUnknown ())
331
317
332
- frule_test (iterfun, (x, ẋ ))
318
+ frule_test (iterfun, (x, ẋ ))
333
319
rrule_test (iterfun, randn (), (x, x̄))
334
320
end
335
321
end
336
322
337
323
@testset " unhappy path" begin
338
324
@testset " primal wrong" begin
339
325
my_identity1 (x) = x
340
- function ChainRulesCore. frule ((_, ẏ ), :: typeof (my_identity1), x)
341
- return 2.5 * x, ẏ
326
+ function ChainRulesCore. frule ((_, ẏ ), :: typeof (my_identity1), x)
327
+ return 2.5 * x, ẏ
342
328
end
343
329
function ChainRulesCore. rrule (:: typeof (my_identity1), x)
344
- function identity_pullback (ȳ )
345
- return (NO_FIELDS, ȳ )
330
+ function identity_pullback (ȳ )
331
+ return (NO_FIELDS, ȳ )
346
332
end
347
333
return 2.5 * x, identity_pullback
348
334
end
@@ -351,14 +337,14 @@ primalapprox(x) = x
351
337
@test fails (()-> rrule_test (my_identity1, 4.1 , (2.2 , 3.3 )))
352
338
end
353
339
354
- @testset " deriviative wrong" begin
340
+ @testset " derivative wrong" begin
355
341
my_identity2 (x) = x
356
- function ChainRulesCore. frule ((_, ẏ ), :: typeof (my_identity2), x)
357
- return x, 2.7 * ẏ
342
+ function ChainRulesCore. frule ((_, ẏ ), :: typeof (my_identity2), x)
343
+ return x, 2.7 * ẏ
358
344
end
359
345
function ChainRulesCore. rrule (:: typeof (my_identity2), x)
360
- function identity_pullback (ȳ )
361
- return (NO_FIELDS, 31.8 * ȳ )
346
+ function identity_pullback (ȳ )
347
+ return (NO_FIELDS, 31.8 * ȳ )
362
348
end
363
349
return x, identity_pullback
364
350
end
0 commit comments