Skip to content

Commit 7327272

Browse files
committed
Fix issue #586
1 parent 5855c10 commit 7327272

File tree

1 file changed

+17
-15
lines changed

1 file changed

+17
-15
lines changed

test/rulesets/Base/broadcast.jl

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,26 +15,28 @@ BT1 = Broadcast.BroadcastStyle(Tuple)
1515
y1, bk1 = rrule(CFG, copybroadcasted, BS1, >, rand(3), rand(3))
1616
@test y1 isa AbstractArray{Bool}
1717
@test all(d -> d isa AbstractZero, bk1(99))
18-
18+
1919
y2, bk2 = rrule(CFG, copybroadcasted, BT1, isinteger, Tuple(rand(3)))
2020
@test y2 isa Tuple{Bool,Bool,Bool}
2121
@test all(d -> d isa AbstractZero, bk2(99))
2222
end
2323

2424
@testset "split 2: derivatives" begin
2525
test_rrule(copybroadcasted, BS1, log, rand(3) .+ 1)
26-
test_rrule(copybroadcasted, BT1, log, Tuple(rand(3) .+ 1))
26+
# `check_inferred` doesn't accept the `Union` returned from ProjectTo as of
27+
# ChainRuleCore 1.15.4 https://github.com/JuliaDiff/ChainRulesCore.jl/issues/586
28+
test_rrule(copybroadcasted, BT1, log, Tuple(rand(3) .+ 1); check_inferred=false)
2729

2830
# Two args uses StructArrays
2931
test_rrule(copybroadcasted, BS1, atan, rand(3), rand(3))
3032
test_rrule(copybroadcasted, BS2, atan, rand(3), rand(4)')
3133
test_rrule(copybroadcasted, BS1, atan, rand(3), rand())
3234
test_rrule(copybroadcasted, BT1, atan, rand(3), Tuple(rand(1)))
3335
test_rrule(copybroadcasted, BT1, atan, Tuple(rand(3)), Tuple(rand(3)), check_inferred = VERSION > v"1.7")
34-
36+
3537
# test_rrule(copy∘broadcasted, *, BS1, rand(3), Ref(rand())) # don't know what I was testing
3638
end
37-
39+
3840
@testset "split 3: forwards" begin
3941
# In test_helpers.jl, `flog` and `fstar` have only `frule`s defined, nothing else.
4042
test_rrule(copybroadcasted, BS1, flog, rand(3))
@@ -57,14 +59,14 @@ BT1 = Broadcast.BroadcastStyle(Tuple)
5759
test_rrule(copybroadcasted, BS2, Multiplier(rand()), rand(3), rand(4)', check_inferred=false) # Union{ZeroTangent, Tangent{Multiplier{...
5860
@test_skip test_rrule(copybroadcasted, BS1, Multiplier(rand()), rand(3), 5.0im, check_inferred=false) # ProjectTo(f) fails to remove the imaginary part of Multiplier's gradient
5961
test_rrule(copybroadcasted, BS1, make_two_vec, rand(3), check_inferred=false)
60-
62+
6163
# Non-diff components -- note that with BroadcastStyle, Ref is from e.g. Broadcast.broadcastable(nothing)
6264
test_rrule(copybroadcasted, BS2, firsttuple, rand(3), Ref(:sym), rand(4)', check_inferred=false)
6365
test_rrule(copybroadcasted, BS2, lasttuple, rand(3), Ref(nothing), rand(4)', check_inferred=false)
6466
test_rrule(copybroadcasted, BS1, |>, rand(3), Ref(sin), check_inferred=false)
6567
_call(f, x...) = f(x...)
6668
test_rrule(copybroadcasted, BS2, _call, Ref(atan), rand(3), rand(4)', check_inferred=false)
67-
69+
6870
test_rrule(copybroadcasted, BS1, getindex, [rand(3) for _ in 1:2], [3,1], check_inferred=false)
6971
test_rrule(copybroadcasted, BS1, getindex, [rand(3) for _ in 1:2], (3,1), check_inferred=false)
7072
test_rrule(copybroadcasted, BS1, getindex, [rand(3) for _ in 1:2], Ref(CartesianIndex(2)), check_inferred=false)
@@ -86,20 +88,20 @@ BT1 = Broadcast.BroadcastStyle(Tuple)
8688
@gpu test_rrule(copybroadcasted, +, rand(3), 1.0*im)
8789
@gpu test_rrule(copybroadcasted, +, rand(3), true)
8890
@gpu_broken test_rrule(copybroadcasted, +, rand(3), Tuple(rand(3)))
89-
91+
9092
@gpu test_rrule(copybroadcasted, -, rand(3), rand(3))
9193
@gpu test_rrule(copybroadcasted, -, rand(3), rand(4)')
9294
@gpu test_rrule(copybroadcasted, -, rand(3))
9395
test_rrule(copybroadcasted, -, Tuple(rand(3)))
94-
96+
9597
@gpu test_rrule(copybroadcasted, *, rand(3), rand(3))
9698
@gpu test_rrule(copybroadcasted, *, rand(3), rand())
9799
@gpu test_rrule(copybroadcasted, *, rand(), rand(3))
98100

99101
test_rrule(copybroadcasted, *, rand(3) .+ im, rand(3) .+ 2im)
100102
test_rrule(copybroadcasted, *, rand(3) .+ im, rand() + 3im)
101103
test_rrule(copybroadcasted, *, rand() + im, rand(3) .+ 4im)
102-
104+
103105
@test_skip test_rrule(copybroadcasted, *, im, rand(3)) # MethodError: no method matching randn(::Random._GLOBAL_RNG, ::Type{Complex{Bool}})
104106
@test_skip test_rrule(copybroadcasted, *, rand(3), im) # MethodError: no method matching randn(::Random._GLOBAL_RNG, ::Type{Complex{Bool}})
105107
y4, bk4 = rrule(CFG, copybroadcasted, *, im, [1,2,3.0])
@@ -113,16 +115,16 @@ BT1 = Broadcast.BroadcastStyle(Tuple)
113115

114116
@gpu test_rrule(copybroadcasted, Base.literal_pow, ^, rand(3), Val(2))
115117
@gpu test_rrule(copybroadcasted, Base.literal_pow, ^, rand(3) .+ im, Val(2))
116-
118+
117119
@gpu test_rrule(copybroadcasted, /, rand(3), rand())
118120
@gpu test_rrule(copybroadcasted, /, rand(3) .+ im, rand() + 3im)
119121
end
120122
@testset "identity etc" begin
121123
test_rrule(copybroadcasted, identity, rand(3))
122-
124+
123125
test_rrule(copybroadcasted, Float32, rand(3), rtol=1e-4)
124126
test_rrule(copybroadcasted, ComplexF32, rand(3), rtol=1e-4)
125-
127+
126128
test_rrule(copybroadcasted, float, rand(3))
127129
end
128130
@testset "complex" begin
@@ -136,7 +138,7 @@ BT1 = Broadcast.BroadcastStyle(Tuple)
136138

137139
test_rrule(copybroadcasted, imag, rand(3))
138140
test_rrule(copybroadcasted, imag, rand(3) .+ im .* rand.())
139-
141+
140142
test_rrule(copybroadcasted, complex, rand(3))
141143
end
142144
end
@@ -173,9 +175,9 @@ BT1 = Broadcast.BroadcastStyle(Tuple)
173175
test_rrule(copybroadcasted, complex, rand())
174176
end
175177
end
176-
178+
177179
@testset "bugs" begin
178180
@test ChainRules.unbroadcast((1, 2, [3]), [4, 5, [6]]) isa Tangent # earlier, NTuple demanded same type
179181
@test ChainRules.unbroadcast(broadcasted(-, (1, 2), 3), (4, 5)) == (4, 5) # earlier, called ndims(::Tuple)
180182
end
181-
end
183+
end

0 commit comments

Comments
 (0)