Skip to content

Commit 6ff4c31

Browse files
sethaxendevmotionoxinabox
authored
Improvements to cholesky rrules (#630)
* Rewrite getproperty rule to store factors * Work with factors directly * Create tangent with factors * Simplify and generalize cholesky number rule * Use default tangent * Generalize diagonal cholesky to Hermitian * Simplify cholesky(::Diagonal) tests * Generalize and simplify cholesky(::StridedMatrix) * Fixes for Hermitian matrices * Generalize to complex Hermitian matrices * Remove unnecessary single-arg rule * Reformat * Check that check kwarg correctly passed * Support failed factorizations * Remove specializations for Thunks * Release unnecessary constraints on factors * Decrease step size * Check complex cotangent for real primal works * Fix diagonal rule for failed factorization * Release type constraint of Diagonal * Refer to real instead off complex * Increment patch number * Avoid unnecessary copies * Update src/rulesets/LinearAlgebra/factorization.jl Co-authored-by: David Widmann <[email protected]> * Apply suggestions from code review Co-authored-by: Frames Catherine White <[email protected]> * Complexify with concrete types Co-authored-by: David Widmann <[email protected]> Co-authored-by: Frames Catherine White <[email protected]>
1 parent a0d86fe commit 6ff4c31

File tree

3 files changed

+148
-94
lines changed

3 files changed

+148
-94
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.35.2"
3+
version = "1.35.3"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/LinearAlgebra/factorization.jl

+57-58
Original file line numberDiff line numberDiff line change
@@ -441,46 +441,30 @@ end
441441
##### `cholesky`
442442
#####
443443

444-
# these functions are defined outside the rrule because otherwise type inference breaks
445-
# see https://github.com/JuliaLang/julia/issues/40990
446-
_cholesky_real_pullback(ΔC::Tangent, full_pb) = return full_pb(ΔC)[1:2]
447-
function _cholesky_real_pullback(Ȳ::AbstractThunk, full_pb)
448-
return _cholesky_real_pullback(unthunk(Ȳ), full_pb)
449-
end
450-
function rrule(::typeof(cholesky),
451-
A::Union{
452-
Real,
453-
Diagonal{<:Real},
454-
LinearAlgebra.HermOrSym{<:LinearAlgebra.BlasReal,<:StridedMatrix},
455-
StridedMatrix{<:LinearAlgebra.BlasReal}
456-
}
457-
# Handle not passing in the uplo
458-
)
459-
arg2 = A isa Real ? :U : Val(false)
460-
C, full_pb = rrule(cholesky, A, arg2)
461-
462-
cholesky_pullback(ȳ) = return _cholesky_real_pullback(ȳ, full_pb)
463-
return C, cholesky_pullback
464-
end
465-
466-
function _cholesky_realuplo_pullback(ΔC::Tangent, C)
467-
return NoTangent(), ΔC.factors[1, 1] / (2 * C.U[1, 1]), NoTangent()
468-
end
469-
_cholesky_realuplo_pullback(Ȳ::AbstractThunk, C) = _cholesky_realuplo_pullback(unthunk(Ȳ), C)
470-
function rrule(::typeof(cholesky), A::Real, uplo::Symbol)
471-
C = cholesky(A, uplo)
472-
cholesky_pullback(ȳ) = _cholesky_realuplo_pullback(ȳ, C)
444+
function rrule(::typeof(cholesky), x::Number, uplo::Symbol)
445+
C = cholesky(x, uplo)
446+
function cholesky_pullback(ΔC)
447+
Ā = real(only(unthunk(ΔC).factors)) / (2 * sign(real(x)) * only(C.factors))
448+
return NoTangent(), Ā, NoTangent()
449+
end
473450
return C, cholesky_pullback
474451
end
475452

476-
function _cholesky_Diagonal_pullback(ΔC::Tangent, C)
477-
Ā = Diagonal(diag(ΔC.factors) .* inv.(2 .* C.factors.diag))
478-
return NoTangent(), Ā, NoTangent()
453+
function _cholesky_Diagonal_pullback(ΔC, C)
454+
Udiag = C.factors.diag
455+
ΔUdiag = diag(ΔC.factors)
456+
Ādiag = real.(ΔUdiag) ./ (2 .* Udiag)
457+
if !issuccess(C)
458+
# cholesky computes the factor diagonal from the beginning until it encounters the
459+
# first failure. The remainder of the diagonal is then copied from the input.
460+
i = findfirst(x -> !isreal(x) || !(real(x) > 0), Udiag)
461+
Ādiag[i:end] .= ΔUdiag[i:end]
462+
end
463+
return NoTangent(), Diagonal(Ādiag), NoTangent()
479464
end
480-
_cholesky_Diagonal_pullback(Ȳ::AbstractThunk, C) = _cholesky_Diagonal_pullback(unthunk(Ȳ), C)
481-
function rrule(::typeof(cholesky), A::Diagonal{<:Real}, ::Val{false}; check::Bool=true)
465+
function rrule(::typeof(cholesky), A::Diagonal{<:Number}, ::Val{false}; check::Bool=true)
482466
C = cholesky(A, Val(false); check=check)
483-
cholesky_pullback(ȳ) = _cholesky_Diagonal_pullback(ȳ, C)
467+
cholesky_pullback(ȳ) = _cholesky_Diagonal_pullback(unthunk(ȳ), C)
484468
return C, cholesky_pullback
485469
end
486470

@@ -489,69 +473,84 @@ end
489473
# Implementation due to Seeger, Matthias, et al. "Auto-differentiating linear algebra."
490474
function rrule(
491475
::typeof(cholesky),
492-
A::LinearAlgebra.HermOrSym{<:LinearAlgebra.BlasReal, <:StridedMatrix},
476+
A::LinearAlgebra.RealHermSymComplexHerm{<:Real, <:StridedMatrix},
493477
::Val{false};
494478
check::Bool=true,
495479
)
496480
C = cholesky(A, Val(false); check=check)
497-
function _cholesky_HermOrSym_pullback(ΔC::Tangent)
498-
Ā, U = _cholesky_pullback_shared_code(C, ΔC)
499-
Ā = BLAS.trsm!('R', 'U', 'C', 'N', one(eltype(Ā)) / 2, U.data, Ā)
481+
function cholesky_HermOrSym_pullback(ΔC)
482+
Ā = _cholesky_pullback_shared_code(C, unthunk(ΔC))
483+
rmul!(Ā, one(eltype(Ā)) / 2)
500484
return NoTangent(), _symhermtype(A)(Ā), NoTangent()
501485
end
502-
_cholesky_HermOrSym_pullback(Ȳ::AbstractThunk) = _cholesky_HermOrSym_pullback(unthunk(Ȳ))
503-
return C, _cholesky_HermOrSym_pullback
486+
return C, cholesky_HermOrSym_pullback
504487
end
505488

506489
function rrule(
507490
::typeof(cholesky),
508-
A::StridedMatrix{<:LinearAlgebra.BlasReal},
491+
A::StridedMatrix{<:Union{Real,Complex}},
509492
::Val{false};
510493
check::Bool=true,
511494
)
512495
C = cholesky(A, Val(false); check=check)
513-
function _cholesky_Strided_pullback(ΔC::Tangent)
514-
Ā, U = _cholesky_pullback_shared_code(C, ΔC)
515-
Ā = BLAS.trsm!('R', 'U', 'C', 'N', one(eltype(Ā)), U.data, Ā)
496+
function cholesky_Strided_pullback(ΔC)
497+
Ā = _cholesky_pullback_shared_code(C, unthunk(ΔC))
516498
idx = diagind(Ā)
517499
@views Ā[idx] .= real.(Ā[idx]) ./ 2
518500
return (NoTangent(), UpperTriangular(Ā), NoTangent())
519501
end
520-
_cholesky_Strided_pullback(Ȳ::AbstractThunk) = _cholesky_Strided_pullback(unthunk(Ȳ))
521-
return C, _cholesky_Strided_pullback
502+
return C, cholesky_Strided_pullback
522503
end
523504

524505
function _cholesky_pullback_shared_code(C, ΔC)
525-
U = C.U
526-
Ū = ΔC.U
527-
Ā = similar(U.data)
528-
Ā = mul!(Ā, Ū, U')
529-
Ā = LinearAlgebra.copytri!(Ā, 'U', true)
530-
Ā = ldiv!(U, Ā)
531-
return Ā, U
506+
Δfactors = ΔC.factors
507+
Ā = similar(C.factors)
508+
if C.uplo === 'U'
509+
U = C.U
510+
Ū = eltype(U) <: Real ? real(_maybeUpperTri(Δfactors)) : _maybeUpperTri(Δfactors)
511+
mul!(Ā, Ū, U')
512+
LinearAlgebra.copytri!(Ā, 'U', true)
513+
eltype(Ā) <: Real || _realifydiag!(Ā)
514+
ldiv!(U, Ā)
515+
rdiv!(Ā, U')
516+
else # C.uplo === 'L'
517+
L = C.L
518+
L̄ = eltype(L) <: Real ? real(_maybeLowerTri(Δfactors)) : _maybeLowerTri(Δfactors)
519+
mul!(Ā, L', L̄)
520+
LinearAlgebra.copytri!(Ā, 'L', true)
521+
eltype(Ā) <: Real || _realifydiag!(Ā)
522+
rdiv!(Ā, L)
523+
ldiv!(L', Ā)
524+
end
525+
return Ā
532526
end
533527

534528
function rrule(::typeof(getproperty), F::T, x::Symbol) where {T <: Cholesky}
535529
function getproperty_cholesky_pullback(Ȳ)
536530
C = Tangent{T}
537531
∂F = if x === :U
538532
if F.uplo === 'U'
539-
C(U=UpperTriangular(Ȳ),)
533+
C(factors=_maybeUpperTri(Ȳ),)
540534
else
541-
C(L=LowerTriangular(Ȳ'),)
535+
C(factors=_maybeLowerTri(Ȳ'),)
542536
end
543537
elseif x === :L
544538
if F.uplo === 'L'
545-
C(L=LowerTriangular(Ȳ),)
539+
C(factors=_maybeLowerTri(Ȳ),)
546540
else
547-
C(U=UpperTriangular(Ȳ'),)
541+
C(factors=_maybeUpperTri(Ȳ'),)
548542
end
549543
end
550544
return NoTangent(), ∂F, NoTangent()
551545
end
552546
return getproperty(F, x), getproperty_cholesky_pullback
553547
end
554548

549+
_maybeUpperTri(A) = UpperTriangular(A)
550+
_maybeUpperTri(A::Diagonal) = A
551+
_maybeLowerTri(A) = LowerTriangular(A)
552+
_maybeLowerTri(A::Diagonal) = A
553+
555554
# `det` and `logdet` for `Cholesky`
556555
function rrule(::typeof(det), C::Cholesky)
557556
y = det(C)

test/rulesets/LinearAlgebra/factorization.jl

+90-35
Original file line numberDiff line numberDiff line change
@@ -382,57 +382,112 @@ end
382382
# have fantastic support for this stuff at the minute.
383383
# also we might be missing some overloads for different tangent-types in the rules
384384
@testset "cholesky" begin
385-
@testset "Real" begin
386-
test_rrule(cholesky, 0.8)
385+
@testset "Number" begin
386+
@testset "uplo=$uplo" for uplo in (:U, :L)
387+
test_rrule(cholesky, 0.8, uplo)
388+
test_rrule(cholesky, -0.3, uplo)
389+
test_rrule(cholesky, 0.23 + 0im, uplo)
390+
test_rrule(cholesky, 0.78 + 0.5im, uplo)
391+
test_rrule(cholesky, -0.34 + 0.1im, uplo)
392+
end
387393
end
388-
@testset "Diagonal{<:Real}" begin
389-
D = Diagonal(rand(5) .+ 0.1)
390-
C = cholesky(D)
391-
test_rrule(
392-
cholesky, D ⊢ Diagonal(randn(5)), Val(false);
393-
output_tangent=Tangent{typeof(C)}(factors=Diagonal(randn(5)))
394-
)
394+
395+
@testset "Diagonal" begin
396+
@testset "Diagonal{<:Real}" begin
397+
test_rrule(cholesky, Diagonal([0.3, 0.2, 0.5, 0.6, 0.9]), Val(false))
398+
end
399+
@testset "Diagonal{<:Complex}" begin
400+
# finite differences in general will produce matrices with non-real
401+
# diagonals, which cause factorization to fail. If we turn off the check and
402+
# ensure the cotangent is real, then test_rrule still works.
403+
D = Diagonal([0.3 + 0im, 0.2, 0.5, 0.6, 0.9])
404+
C = cholesky(D)
405+
test_rrule(
406+
cholesky, D, Val(false);
407+
output_tangent=Tangent{typeof(C)}(factors=complex(randn(5, 5))),
408+
fkwargs=(; check=false),
409+
)
410+
end
411+
@testset "check has correct default and passed to primal" begin
412+
@test_throws Exception rrule(cholesky, Diagonal(-rand(5)), Val(false))
413+
rrule(cholesky, Diagonal(-rand(5)), Val(false); check=false)
414+
end
415+
@testset "failed factorization" begin
416+
A = Diagonal(vcat(rand(4), -rand(4), rand(4)))
417+
test_rrule(cholesky, A, Val(false); fkwargs=(; check=false))
418+
end
395419
end
396420

397-
X = generate_well_conditioned_matrix(10)
398-
V = generate_well_conditioned_matrix(10)
399-
F, dX_pullback = rrule(cholesky, X, Val(false))
400-
F_1arg, dX_pullback_1arg = rrule(cholesky, X) # to test not passing the Val(false)
401-
@test F == F_1arg
402-
@testset "uplo=$p" for p in [:U, :L]
403-
Y, dF_pullback = rrule(getproperty, F, p)
404-
Ȳ = (p === :U ? UpperTriangular : LowerTriangular)(randn(size(Y)))
405-
(dself, dF, dp) = dF_pullback(Ȳ)
406-
@test dself === NoTangent()
407-
@test dp === NoTangent()
408-
409-
# NOTE: We're doing Nabla-style testing here and avoiding using the `j′vp`
410-
# machinery from FiniteDifferences because that isn't set up to respect
411-
# necessary special properties of the input. In the case of the Cholesky
412-
# factorization, we need the input to be Hermitian.
413-
ΔF = unthunk(dF)
414-
_, dX, darg2 = dX_pullback(ΔF)
415-
_, dX_1arg = dX_pullback_1arg(ΔF)
416-
@test dX == dX_1arg
417-
@test darg2 === NoTangent()
418-
X̄_ad = dot(unthunk(dX), V)
419-
X̄_fd = central_fdm(5, 1)(0.000_001) do ε
420-
dot(Ȳ, getproperty(cholesky(X .+ ε .* V), p))
421+
@testset "StridedMatrix" begin
422+
@testset "Matrix{$T}" for T in (Float64, ComplexF64)
423+
X = generate_well_conditioned_matrix(T, 10)
424+
V = generate_well_conditioned_matrix(T, 10)
425+
F, dX_pullback = rrule(cholesky, X, Val(false))
426+
@testset "uplo=$p, cotangent eltype=$T" for p in [:U, :L], S in unique([T, complex(T)])
427+
Y, dF_pullback = rrule(getproperty, F, p)
428+
Ȳ = randn(S, size(Y))
429+
(dself, dF, dp) = dF_pullback(Ȳ)
430+
@test dself === NoTangent()
431+
@test dp === NoTangent()
432+
433+
# NOTE: We're doing Nabla-style testing here and avoiding using the `j′vp`
434+
# machinery from FiniteDifferences because that isn't set up to respect
435+
# necessary special properties of the input. In the case of the Cholesky
436+
# factorization, we need the input to be Hermitian.
437+
ΔF = unthunk(dF)
438+
_, dX, darg2 = dX_pullback(ΔF)
439+
@test darg2 === NoTangent()
440+
X̄_ad = real(dot(unthunk(dX), V))
441+
X̄_fd = central_fdm(5, 1)(0.000_0001) do ε
442+
real(dot(Ȳ, getproperty(cholesky(X .+ ε .* V), p)))
443+
end
444+
@test X̄_ad ≈ X̄_fd rtol=1e-4
445+
end
446+
end
447+
@testset "check has correct default and passed to primal" begin
448+
# this will almost certainly be a non-PD matrix
449+
X = Matrix(Symmetric(randn(10, 10)))
450+
@test_throws Exception rrule(cholesky, X, Val(false))
451+
rrule(cholesky, X, Val(false); check=false) # just check it doesn't throw
421452
end
422-
@test X̄_ad ≈ X̄_fd rtol=1e-4
423453
end
424454

425455
# Ensure that cotangents of cholesky(::StridedMatrix) and
426456
# (cholesky ∘ Symmetric)(::StridedMatrix) are equal.
427457
@testset "Symmetric" begin
458+
X = generate_well_conditioned_matrix(10)
459+
F, dX_pullback = rrule(cholesky, X, Val(false))
460+
428461
X_symmetric, sym_back = rrule(Symmetric, X, :U)
429462
C, chol_back_sym = rrule(cholesky, X_symmetric, Val(false))
430463

431-
Δ = Tangent{typeof(C)}((U=UpperTriangular(randn(size(X)))))
464+
Δ = Tangent{typeof(C)}((factors=randn(size(X))))
432465
ΔX_symmetric = chol_back_sym(Δ)[2]
433466
@test sym_back(ΔX_symmetric)[2] ≈ dX_pullback(Δ)[2]
434467
end
435468

469+
# Ensure that cotangents of cholesky(::StridedMatrix) and
470+
# (cholesky ∘ Hermitian)(::StridedMatrix) are equal.
471+
@testset "Hermitian" begin
472+
@testset "Hermitian{$T}" for T in (Float64, ComplexF64)
473+
X = generate_well_conditioned_matrix(T, 10)
474+
F, dX_pullback = rrule(cholesky, X, Val(false))
475+
476+
X_hermitian, herm_back = rrule(Hermitian, X, :U)
477+
C, chol_back_herm = rrule(cholesky, X_hermitian, Val(false))
478+
479+
Δ = Tangent{typeof(C)}((factors=randn(T, size(X))))
480+
ΔX_hermitian = chol_back_herm(Δ)[2]
481+
@test herm_back(ΔX_hermitian)[2] ≈ dX_pullback(Δ)[2]
482+
end
483+
@testset "check has correct default and passed to primal" begin
484+
# this will almost certainly be a non-PD matrix
485+
X = Hermitian(randn(10, 10))
486+
@test_throws Exception rrule(cholesky, X, Val(false))
487+
rrule(cholesky, X, Val(false); check=false)
488+
end
489+
end
490+
436491
@testset "det and logdet (uplo=$p)" for p in (:U, :L)
437492
@testset "$op" for op in (det, logdet)
438493
@testset "$T" for T in (Float64, ComplexF64)

0 commit comments

Comments
 (0)