@@ -382,57 +382,112 @@ end
382
382
# have fantastic support for this stuff at the minute.
383
383
# also we might be missing some overloads for different tangent-types in the rules
384
384
@testset " cholesky" begin
385
- @testset " Real" begin
386
- test_rrule (cholesky, 0.8 )
385
+ @testset " Number" begin
386
+ @testset " uplo=$uplo " for uplo in (:U , :L )
387
+ test_rrule (cholesky, 0.8 , uplo)
388
+ test_rrule (cholesky, - 0.3 , uplo)
389
+ test_rrule (cholesky, 0.23 + 0im , uplo)
390
+ test_rrule (cholesky, 0.78 + 0.5im , uplo)
391
+ test_rrule (cholesky, - 0.34 + 0.1im , uplo)
392
+ end
387
393
end
388
- @testset " Diagonal{<:Real}" begin
389
- D = Diagonal (rand (5 ) .+ 0.1 )
390
- C = cholesky (D)
391
- test_rrule (
392
- cholesky, D ⊢ Diagonal (randn (5 )), Val (false );
393
- output_tangent= Tangent {typeof(C)} (factors= Diagonal (randn (5 )))
394
- )
394
+
395
+ @testset " Diagonal" begin
396
+ @testset " Diagonal{<:Real}" begin
397
+ test_rrule (cholesky, Diagonal ([0.3 , 0.2 , 0.5 , 0.6 , 0.9 ]), Val (false ))
398
+ end
399
+ @testset " Diagonal{<:Complex}" begin
400
+ # finite differences in general will produce matrices with non-real
401
+ # diagonals, which cause factorization to fail. If we turn off the check and
402
+ # ensure the cotangent is real, then test_rrule still works.
403
+ D = Diagonal ([0.3 + 0im , 0.2 , 0.5 , 0.6 , 0.9 ])
404
+ C = cholesky (D)
405
+ test_rrule (
406
+ cholesky, D, Val (false );
407
+ output_tangent= Tangent {typeof(C)} (factors= complex (randn (5 , 5 ))),
408
+ fkwargs= (; check= false ),
409
+ )
410
+ end
411
+ @testset " check has correct default and passed to primal" begin
412
+ @test_throws Exception rrule (cholesky, Diagonal (- rand (5 )), Val (false ))
413
+ rrule (cholesky, Diagonal (- rand (5 )), Val (false ); check= false )
414
+ end
415
+ @testset " failed factorization" begin
416
+ A = Diagonal (vcat (rand (4 ), - rand (4 ), rand (4 )))
417
+ test_rrule (cholesky, A, Val (false ); fkwargs= (; check= false ))
418
+ end
395
419
end
396
420
397
- X = generate_well_conditioned_matrix (10 )
398
- V = generate_well_conditioned_matrix (10 )
399
- F, dX_pullback = rrule (cholesky, X, Val (false ))
400
- F_1arg, dX_pullback_1arg = rrule (cholesky, X) # to test not passing the Val(false)
401
- @test F == F_1arg
402
- @testset " uplo=$p " for p in [:U , :L ]
403
- Y, dF_pullback = rrule (getproperty, F, p)
404
- Ȳ = (p === :U ? UpperTriangular : LowerTriangular)(randn (size (Y)))
405
- (dself, dF, dp) = dF_pullback (Ȳ)
406
- @test dself === NoTangent ()
407
- @test dp === NoTangent ()
408
-
409
- # NOTE: We're doing Nabla-style testing here and avoiding using the `j′vp`
410
- # machinery from FiniteDifferences because that isn't set up to respect
411
- # necessary special properties of the input. In the case of the Cholesky
412
- # factorization, we need the input to be Hermitian.
413
- ΔF = unthunk (dF)
414
- _, dX, darg2 = dX_pullback (ΔF)
415
- _, dX_1arg = dX_pullback_1arg (ΔF)
416
- @test dX == dX_1arg
417
- @test darg2 === NoTangent ()
418
- X̄_ad = dot (unthunk (dX), V)
419
- X̄_fd = central_fdm (5 , 1 )(0.000_001 ) do ε
420
- dot (Ȳ, getproperty (cholesky (X .+ ε .* V), p))
421
+ @testset " StridedMatrix" begin
422
+ @testset " Matrix{$T }" for T in (Float64, ComplexF64)
423
+ X = generate_well_conditioned_matrix (T, 10 )
424
+ V = generate_well_conditioned_matrix (T, 10 )
425
+ F, dX_pullback = rrule (cholesky, X, Val (false ))
426
+ # Label each case with `S`, the cotangent eltype actually being looped over; using the
+ # primal eltype `T` here gave the real and complex cotangent cases the same name.
+ @testset " uplo=$p , cotangent eltype=$S " for p in [:U, :L], S in unique([T, complex(T)])
427
+ Y, dF_pullback = rrule (getproperty, F, p)
428
+ Ȳ = randn (S, size (Y))
429
+ (dself, dF, dp) = dF_pullback (Ȳ)
430
+ @test dself === NoTangent ()
431
+ @test dp === NoTangent ()
432
+
433
+ # NOTE: We're doing Nabla-style testing here and avoiding using the `j′vp`
434
+ # machinery from FiniteDifferences because that isn't set up to respect
435
+ # necessary special properties of the input. In the case of the Cholesky
436
+ # factorization, we need the input to be Hermitian.
437
+ ΔF = unthunk (dF)
438
+ _, dX, darg2 = dX_pullback (ΔF)
439
+ @test darg2 === NoTangent ()
440
+ X̄_ad = real (dot (unthunk (dX), V))
441
+ X̄_fd = central_fdm (5 , 1 )(0.000_0001 ) do ε
442
+ real (dot (Ȳ, getproperty (cholesky (X .+ ε .* V), p)))
443
+ end
444
+ @test X̄_ad ≈ X̄_fd rtol= 1e-4
445
+ end
446
+ end
447
+ @testset " check has correct default and passed to primal" begin
448
+ # this will almost certainly be a non-PD matrix
449
+ X = Matrix (Symmetric (randn (10 , 10 )))
450
+ @test_throws Exception rrule (cholesky, X, Val (false ))
451
+ rrule (cholesky, X, Val (false ); check= false ) # just check it doesn't throw
421
452
end
422
- @test X̄_ad ≈ X̄_fd rtol= 1e-4
423
453
end
424
454
425
455
# Ensure that cotangents of cholesky(::StridedMatrix) and
426
456
# (cholesky ∘ Symmetric)(::StridedMatrix) are equal.
427
457
@testset " Symmetric" begin
458
+ X = generate_well_conditioned_matrix (10 )
459
+ F, dX_pullback = rrule (cholesky, X, Val (false ))
460
+
428
461
X_symmetric, sym_back = rrule (Symmetric, X, :U )
429
462
C, chol_back_sym = rrule (cholesky, X_symmetric, Val (false ))
430
463
431
- Δ = Tangent {typeof(C)} ((U = UpperTriangular ( randn (size (X) ))))
464
+ # Pass `factors` as a keyword. `((factors = ...))` parses as a parenthesised assignment,
+ # so the matrix would reach `Tangent` positionally instead of as the `factors` field
+ # (presumably leaving the comparison below without a real cotangent to exercise).
+ Δ = Tangent{typeof(C)}(; factors=randn(size(X)))
432
465
ΔX_symmetric = chol_back_sym (Δ)[2 ]
433
466
@test sym_back (ΔX_symmetric)[2 ] ≈ dX_pullback (Δ)[2 ]
434
467
end
435
468
469
+ # Ensure that cotangents of cholesky(::StridedMatrix) and
470
+ # (cholesky ∘ Hermitian)(::StridedMatrix) are equal.
471
+ @testset " Hermitian" begin
472
+ @testset " Hermitian{$T }" for T in (Float64, ComplexF64)
473
+ X = generate_well_conditioned_matrix (T, 10 )
474
+ F, dX_pullback = rrule (cholesky, X, Val (false ))
475
+
476
+ X_hermitian, herm_back = rrule (Hermitian, X, :U )
477
+ C, chol_back_herm = rrule (cholesky, X_hermitian, Val (false ))
478
+
479
+ # Pass `factors` as a keyword. `((factors = ...))` parses as a parenthesised assignment,
+ # so the matrix would reach `Tangent` positionally instead of as the `factors` field
+ # (presumably leaving the comparison below without a real cotangent to exercise).
+ Δ = Tangent{typeof(C)}(; factors=randn(T, size(X)))
480
+ ΔX_hermitian = chol_back_herm (Δ)[2 ]
481
+ @test herm_back (ΔX_hermitian)[2 ] ≈ dX_pullback (Δ)[2 ]
482
+ end
483
+ @testset " check has correct default and passed to primal" begin
484
+ # this will almost certainly be a non-PD matrix
485
+ X = Hermitian (randn (10 , 10 ))
486
+ @test_throws Exception rrule (cholesky, X, Val (false ))
487
+ rrule (cholesky, X, Val (false ); check= false )
488
+ end
489
+ end
490
+
436
491
@testset " det and logdet (uplo=$p )" for p in (:U , :L )
437
492
@testset " $op " for op in (det, logdet)
438
493
@testset " $T " for T in (Float64, ComplexF64)
0 commit comments