@@ -15,26 +15,28 @@ BT1 = Broadcast.BroadcastStyle(Tuple)
15
15
y1, bk1 = rrule (CFG, copy∘ broadcasted, BS1, > , rand (3 ), rand (3 ))
16
16
@test y1 isa AbstractArray{Bool}
17
17
@test all (d -> d isa AbstractZero, bk1 (99 ))
18
-
18
+
19
19
y2, bk2 = rrule (CFG, copy∘ broadcasted, BT1, isinteger, Tuple (rand (3 )))
20
20
@test y2 isa Tuple{Bool,Bool,Bool}
21
21
@test all (d -> d isa AbstractZero, bk2 (99 ))
22
22
end
23
23
24
24
@testset " split 2: derivatives" begin
25
25
test_rrule (copy∘ broadcasted, BS1, log, rand (3 ) .+ 1 )
26
- test_rrule (copy∘ broadcasted, BT1, log, Tuple (rand (3 ) .+ 1 ))
26
+ # `check_inferred` doesn't accept the `Union` returned from ProjectTo as of
27
+ # ChainRuleCore 1.15.4 https://github.com/JuliaDiff/ChainRulesCore.jl/issues/586
28
+ test_rrule (copy∘ broadcasted, BT1, log, Tuple (rand (3 ) .+ 1 ); check_inferred= false )
27
29
28
30
# Two args uses StructArrays
29
31
test_rrule (copy∘ broadcasted, BS1, atan, rand (3 ), rand (3 ))
30
32
test_rrule (copy∘ broadcasted, BS2, atan, rand (3 ), rand (4 )' )
31
33
test_rrule (copy∘ broadcasted, BS1, atan, rand (3 ), rand ())
32
34
test_rrule (copy∘ broadcasted, BT1, atan, rand (3 ), Tuple (rand (1 )))
33
35
test_rrule (copy∘ broadcasted, BT1, atan, Tuple (rand (3 )), Tuple (rand (3 )), check_inferred = VERSION > v " 1.7" )
34
-
36
+
35
37
# test_rrule(copy∘broadcasted, *, BS1, rand(3), Ref(rand())) # don't know what I was testing
36
38
end
37
-
39
+
38
40
@testset " split 3: forwards" begin
39
41
# In test_helpers.jl, `flog` and `fstar` have only `frule`s defined, nothing else.
40
42
test_rrule (copy∘ broadcasted, BS1, flog, rand (3 ))
@@ -57,14 +59,14 @@ BT1 = Broadcast.BroadcastStyle(Tuple)
57
59
test_rrule (copy∘ broadcasted, BS2, Multiplier (rand ()), rand (3 ), rand (4 )' , check_inferred= false ) # Union{ZeroTangent, Tangent{Multiplier{...
58
60
@test_skip test_rrule (copy∘ broadcasted, BS1, Multiplier (rand ()), rand (3 ), 5.0im , check_inferred= false ) # ProjectTo(f) fails to remove the imaginary part of Multiplier's gradient
59
61
test_rrule (copy∘ broadcasted, BS1, make_two_vec, rand (3 ), check_inferred= false )
60
-
62
+
61
63
# Non-diff components -- note that with BroadcastStyle, Ref is from e.g. Broadcast.broadcastable(nothing)
62
64
test_rrule (copy∘ broadcasted, BS2, first∘ tuple, rand (3 ), Ref (:sym ), rand (4 )' , check_inferred= false )
63
65
test_rrule (copy∘ broadcasted, BS2, last∘ tuple, rand (3 ), Ref (nothing ), rand (4 )' , check_inferred= false )
64
66
test_rrule (copy∘ broadcasted, BS1, |> , rand (3 ), Ref (sin), check_inferred= false )
65
67
_call (f, x... ) = f (x... )
66
68
test_rrule (copy∘ broadcasted, BS2, _call, Ref (atan), rand (3 ), rand (4 )' , check_inferred= false )
67
-
69
+
68
70
test_rrule (copy∘ broadcasted, BS1, getindex, [rand (3 ) for _ in 1 : 2 ], [3 ,1 ], check_inferred= false )
69
71
test_rrule (copy∘ broadcasted, BS1, getindex, [rand (3 ) for _ in 1 : 2 ], (3 ,1 ), check_inferred= false )
70
72
test_rrule (copy∘ broadcasted, BS1, getindex, [rand (3 ) for _ in 1 : 2 ], Ref (CartesianIndex (2 )), check_inferred= false )
@@ -86,20 +88,20 @@ BT1 = Broadcast.BroadcastStyle(Tuple)
86
88
@gpu test_rrule (copy∘ broadcasted, + , rand (3 ), 1.0 * im)
87
89
@gpu test_rrule (copy∘ broadcasted, + , rand (3 ), true )
88
90
@gpu_broken test_rrule (copy∘ broadcasted, + , rand (3 ), Tuple (rand (3 )))
89
-
91
+
90
92
@gpu test_rrule (copy∘ broadcasted, - , rand (3 ), rand (3 ))
91
93
@gpu test_rrule (copy∘ broadcasted, - , rand (3 ), rand (4 )' )
92
94
@gpu test_rrule (copy∘ broadcasted, - , rand (3 ))
93
95
test_rrule (copy∘ broadcasted, - , Tuple (rand (3 )))
94
-
96
+
95
97
@gpu test_rrule (copy∘ broadcasted, * , rand (3 ), rand (3 ))
96
98
@gpu test_rrule (copy∘ broadcasted, * , rand (3 ), rand ())
97
99
@gpu test_rrule (copy∘ broadcasted, * , rand (), rand (3 ))
98
100
99
101
test_rrule (copy∘ broadcasted, * , rand (3 ) .+ im, rand (3 ) .+ 2im )
100
102
test_rrule (copy∘ broadcasted, * , rand (3 ) .+ im, rand () + 3im )
101
103
test_rrule (copy∘ broadcasted, * , rand () + im, rand (3 ) .+ 4im )
102
-
104
+
103
105
@test_skip test_rrule (copy∘ broadcasted, * , im, rand (3 )) # MethodError: no method matching randn(::Random._GLOBAL_RNG, ::Type{Complex{Bool}})
104
106
@test_skip test_rrule (copy∘ broadcasted, * , rand (3 ), im) # MethodError: no method matching randn(::Random._GLOBAL_RNG, ::Type{Complex{Bool}})
105
107
y4, bk4 = rrule (CFG, copy∘ broadcasted, * , im, [1 ,2 ,3.0 ])
@@ -113,16 +115,16 @@ BT1 = Broadcast.BroadcastStyle(Tuple)
113
115
114
116
@gpu test_rrule (copy∘ broadcasted, Base. literal_pow, ^ , rand (3 ), Val (2 ))
115
117
@gpu test_rrule (copy∘ broadcasted, Base. literal_pow, ^ , rand (3 ) .+ im, Val (2 ))
116
-
118
+
117
119
@gpu test_rrule (copy∘ broadcasted, / , rand (3 ), rand ())
118
120
@gpu test_rrule (copy∘ broadcasted, / , rand (3 ) .+ im, rand () + 3im )
119
121
end
120
122
@testset " identity etc" begin
121
123
test_rrule (copy∘ broadcasted, identity, rand (3 ))
122
-
124
+
123
125
test_rrule (copy∘ broadcasted, Float32, rand (3 ), rtol= 1e-4 )
124
126
test_rrule (copy∘ broadcasted, ComplexF32, rand (3 ), rtol= 1e-4 )
125
-
127
+
126
128
test_rrule (copy∘ broadcasted, float, rand (3 ))
127
129
end
128
130
@testset " complex" begin
@@ -136,7 +138,7 @@ BT1 = Broadcast.BroadcastStyle(Tuple)
136
138
137
139
test_rrule (copy∘ broadcasted, imag, rand (3 ))
138
140
test_rrule (copy∘ broadcasted, imag, rand (3 ) .+ im .* rand .())
139
-
141
+
140
142
test_rrule (copy∘ broadcasted, complex, rand (3 ))
141
143
end
142
144
end
@@ -173,9 +175,9 @@ BT1 = Broadcast.BroadcastStyle(Tuple)
173
175
test_rrule (copy∘ broadcasted, complex, rand ())
174
176
end
175
177
end
176
-
178
+
177
179
@testset " bugs" begin
178
180
@test ChainRules. unbroadcast ((1 , 2 , [3 ]), [4 , 5 , [6 ]]) isa Tangent # earlier, NTuple demanded same type
179
181
@test ChainRules. unbroadcast (broadcasted (- , (1 , 2 ), 3 ), (4 , 5 )) == (4 , 5 ) # earlier, called ndims(::Tuple)
180
182
end
181
- end
183
+ end
0 commit comments