Skip to content

Commit 54995d6

Browse files
committed
Add rules for det and logdet of Cholesky
1 parent c5dbe03 commit 54995d6

File tree

4 files changed

+43
-3
lines changed

4 files changed

+43
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.29.0"
3+
version = "1.30.0"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/rulesets/LinearAlgebra/dense.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,15 +118,15 @@ end
118118
##### `det`
119119
#####
120120

121-
function frule((_, Δx), ::typeof(det), x::AbstractMatrix)
121+
function frule((_, Δx), ::typeof(det), x::StridedMatrix{<:Number})
122122
Ω = det(x)
123123
# TODO Performance optimization: probably there is an efficent
124124
# way to compute this trace without during the full compution within
125125
return Ω, Ω * tr(x \ Δx)
126126
end
127127
frule((_, Δx), ::typeof(det), x::Number) = (det(x), Δx)
128128

129-
function rrule(::typeof(det), x::Union{Number, AbstractMatrix})
129+
function rrule(::typeof(det), x::Union{Number, StridedMatrix{<:Number}})
130130
Ω = det(x)
131131
function det_pullback(ΔΩ)
132132
∂x = x isa Number ? ΔΩ : inv(x)' * dot(Ω, ΔΩ)

src/rulesets/LinearAlgebra/factorization.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -551,3 +551,24 @@ function rrule(::typeof(getproperty), F::T, x::Symbol) where {T <: Cholesky}
551551
end
552552
return getproperty(F, x), getproperty_cholesky_pullback
553553
end
554+
555+
# `det` and `logdet` for `Cholesky`
556+
function rrule(::typeof(det), C::Cholesky)
557+
y = det(C)
558+
s = conj!((2 * y) ./ _diag_view(C.factors))
559+
function det_Cholesky_pullback(ȳ)
560+
ΔC = Tangent{typeof(C)}(; factors=Diagonal(ȳ .* s))
561+
return NoTangent(), ΔC
562+
end
563+
return y, det_Cholesky_pullback
564+
end
565+
566+
function rrule(::typeof(logdet), C::Cholesky)
567+
y = logdet(C)
568+
s = conj!((2 * one(eltype(C))) ./ _diag_view(C.factors))
569+
function logdet_Cholesky_pullback(ȳ)
570+
ΔC = Tangent{typeof(C)}(; factors=Diagonal(ȳ .* s))
571+
return NoTangent(), ΔC
572+
end
573+
return y, logdet_Cholesky_pullback
574+
end

test/rulesets/LinearAlgebra/factorization.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,5 +432,24 @@ end
432432
ΔX_symmetric = chol_back_sym(Δ)[2]
433433
@test sym_back(ΔX_symmetric)[2] dX_pullback(Δ)[2]
434434
end
435+
436+
@testset "det and logdet (uplo=$p)" for p in ['U', 'L']
437+
@testset "$op" for op in (det, logdet)
438+
@testset "$T" for T in (Float64, ComplexF64)
439+
n = 5
440+
# rand (not randn) so det will be postive, so logdet will be defined
441+
A = 3 * rand(T, (n, n))
442+
X = Cholesky((p === 'U' ? UpperTriangular : LowerTriangular)(A * A' + I))
443+
X̄_acc = Tangent{typeof(X)}(; factors=Diagonal(randn(T, n))) # sensitivity is always a diagonal
444+
test_rrule(op, X X̄_acc)
445+
446+
# return type
447+
_, op_pullback = rrule(op, X)
448+
= op_pullback(2.7)[2]
449+
@testisa Tangent{<:Cholesky}
450+
@test.factors isa Diagonal
451+
end
452+
end
453+
end
435454
end
436455
end

0 commit comments

Comments
 (0)