@@ -500,14 +500,45 @@ end
500
500
@test 150_000_000 > @allocated gradient (loss, ones (1000 ,1000 ))
501
501
end
502
502
503
- @testset " tuples & broadcasting" begin
504
- @test gradient (x -> sum (x .+ ones (2 ,2 )), (1 ,2 )) == ((2 ,2 ),)
505
- @test gradient (x -> sum (x .+ ones (2 ,2 )), (1 ,)) == ((4 ,),)
506
- @test gradient (x -> sum (x .+ ones (2 ,1 )), (1 ,2 )) == ((1 ,1 ),)
507
-
508
- # https://github.com/FluxML/Zygote.jl/issues/975
509
- gt = gradient ((x,p) -> prod (x .^ p), [3 ,4 ], (1 ,2 ))
510
- gv = gradient ((x,p) -> prod (x .^ p), [3 ,4 ], [1 ,2 ])
511
- @test gt[1 ] == gv[1 ]
512
- @test collect (gt[2 ]) ≈ gv[2 ]
503
+ @testset " tricky broadcasting" begin
504
+ @test gradient (x -> sum (x .+ ones (2 ,2 )), (1 ,2 )) == ((2 ,2 ),)
505
+ @test gradient (x -> sum (x .+ ones (2 ,2 )), (1 ,)) == ((4 ,),)
506
+ @test gradient (x -> sum (x .+ ones (2 ,1 )), (1 ,2 )) == ((1 ,1 ),)
507
+
508
+ # https://github.com/FluxML/Zygote.jl/issues/975
509
+ gt = gradient ((x,p) -> prod (x .^ p), [3 ,4 ], (1 ,2 ))
510
+ gv = gradient ((x,p) -> prod (x .^ p), [3 ,4 ], [1 ,2 ])
511
+ @test gt[1 ] == gv[1 ]
512
+ @test collect (gt[2 ]) ≈ gv[2 ]
513
+
514
+ # closure captures y -- can't use ForwardDiff
515
+ @test gradient ((x,y) -> sum ((z-> z^ 2 + y[1 ]). (x)), [1 ,2 ,3 ], [4 ,5 ]) == ([2 , 4 , 6 ], [3 , 0 ])
516
+ @test gradient ((x,y) -> sum ((z-> z^ 2 + y[1 ]), x), [1 ,2 ,3 ], [4 ,5 ]) == ([2 , 4 , 6 ], [3 , 0 ])
517
+ @test gradient ((x,y) -> sum (map ((z-> z^ 2 + y[1 ]), x)), [1 ,2 ,3 ], [4 ,5 ]) == ([2 , 4 , 6 ], [3 , 0 ])
518
+ @test gradient ((x,y) -> mapreduce ((z-> z^ 2 + y[1 ]), + , x), [1 ,2 ,3 ], [4 ,5 ]) == ([2 , 4 , 6 ], [3 , 0 ])
519
+
520
+ # type unstable
521
+ @test gradient (xs -> sum ((x -> x< 2 ? false : x^ 2 ). (xs)), [1 ,2 ,3 ])[1 ][2 : 3 ] == [4 , 6 ]
522
+ @test gradient (xs -> sum ((x -> x< 2 ? false : x^ 2 ), xs), [1 ,2 ,3 ])[1 ][2 : 3 ] == [4 , 6 ]
523
+ @test gradient (xs -> sum (map ((x -> x< 2 ? false : x^ 2 ), xs)), [1 ,2 ,3 ])[1 ][2 : 3 ] == [4 , 6 ]
524
+ @test gradient (xs -> mapreduce ((x -> x< 2 ? false : x^ 2 ), + , xs), [1 ,2 ,3 ])[1 ][2 : 3 ] == [4 , 6 ]
525
+
526
+ # with Ref, Val, Symbol
527
+ @test gradient (x -> sum (x .+ Ref (x[1 ])), [1 ,2 ,3 ]) == ([4 ,1 ,1 ],)
528
+ @test gradient (x -> sum (x .+ (x[1 ],)), [1 ,2 ,3 ]) == ([4 ,1 ,1 ],)
529
+ @test gradient (x -> sum ((first∘ tuple). (x, :ignore )), [1 ,2 ,3 ]) == ([1 ,1 ,1 ],)
530
+ @test gradient (x -> sum ((first∘ tuple). (x, Symbol)), [1 ,2 ,3 ]) == ([1 ,1 ,1 ],)
531
+ _f (x,:: Val{y} = Val (2 )) where {y} = x/ y
532
+ @test gradient (x -> sum (_f .(x, Val (2 ))), [1 ,2 ,3 ]) == ([0.5 , 0.5 , 0.5 ],)
533
+ @test gradient (x -> sum (_f .(x)), [1 ,2 ,3 ]) == ([0.5 , 0.5 , 0.5 ],)
534
+ @test gradient (x -> sum (map (_f, x)), [1 ,2 ,3 ]) == ([0.5 , 0.5 , 0.5 ],)
535
+
536
+ @test gradient (x -> sum (x ./ [1 ,2 ,4 ]), [1 ,2 ,pi ]) == ([1.0 , 0.5 , 0.25 ],)
537
+ @test gradient (x -> sum (map (/ , x, [1 ,2 ,4 ])), [1 ,2 ,pi ]) == ([1.0 , 0.5 , 0.25 ],)
538
+
539
+ # negative powers
540
+ @test gradient ((x,p) -> sum (x .^ p), [1.0 ,2.0 ,4.0 ], [1 ,- 1 ,2 ])[1 ] ≈ [1.0 , - 0.25 , 8.0 ]
541
+ @test gradient ((x,p) -> sum (x .^ p), [1.0 ,2.0 ,4.0 ], - 1 )[1 ] ≈ [- 1.0 , - 0.25 , - 0.0625 ]
542
+ @test gradient ((x,p) -> sum (z -> z^ p, x), [1.0 ,2.0 ,4.0 ], - 1 )[1 ] ≈ [- 1.0 , - 0.25 , - 0.0625 ]
543
+ @test gradient ((x,p) -> mapreduce (z -> z^ p, + , x), [1.0 ,2.0 ,4.0 ], - 1 )[1 ] ≈ [- 1.0 , - 0.25 , - 0.0625 ]
513
544
end
0 commit comments