Skip to content

Commit b957872

Browse files
Differentiation Dirichlet (#1534)
* constructor frule * frule tested * rrule tests * logpdf test * signature for conflict * TestUtils out of Project * ChainRules itself not needed (yet?) * remove checkarg * Update src/multivariate/dirichlet.jl Co-authored-by: David Widmann <[email protected]> * Update test/dirichlet.jl Co-authored-by: David Widmann <[email protected]> * Update test/dirichlet.jl Co-authored-by: David Widmann <[email protected]> * Update test/dirichlet.jl Co-authored-by: David Widmann <[email protected]> * Update src/multivariate/dirichlet.jl Co-authored-by: David Widmann <[email protected]> * conflict * eltype instability * single loop * fix tests * forward finite diff * switch to broadcast * fix broadcast * switch off-support value to NaN * Update src/multivariate/dirichlet.jl Co-authored-by: David Widmann <[email protected]> * Update src/multivariate/dirichlet.jl Co-authored-by: David Widmann <[email protected]> * do not assume inplace * fixed temp * Simplify implementation and tests in #1534 (#1555) * Simplify implementation and tests * Precompute `digamma(alpha0)` * Relax type signature Co-authored-by: David Widmann <[email protected]>
1 parent 7c3af32 commit b957872

File tree

2 files changed

+92
-2
lines changed

2 files changed

+92
-2
lines changed

src/multivariate/dirichlet.jl

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ Base.show(io::IO, d::Dirichlet) = show(io, d, (:alpha,))
7272
length(d::Dirichlet) = length(d.alpha)
7373
mean(d::Dirichlet) = d.alpha .* inv(d.alpha0)
7474
params(d::Dirichlet) = (d.alpha,)
75-
@inline partype(d::Dirichlet{T}) where {T<:Real} = T
75+
@inline partype(::Dirichlet{T}) where {T<:Real} = T
7676

7777
function var(d::Dirichlet)
7878
α0 = d.alpha0
@@ -375,3 +375,62 @@ function fit_mle(::Type{<:Dirichlet}, P::AbstractMatrix{Float64},
375375
elogp = mean_logp(suffstats(Dirichlet, P, w))
376376
fit_dirichlet!(elogp, α; maxiter=maxiter, tol=tol, debug=debug)
377377
end
378+
379+
## Differentiation
380+
function ChainRulesCore.frule((_, Δalpha)::Tuple{Any,Any}, ::Type{DT}, alpha::AbstractVector{T}; check_args::Bool = true) where {T <: Real, DT <: Union{Dirichlet{T}, Dirichlet}}
381+
d = DT(alpha; check_args=check_args)
382+
∂alpha0 = sum(Δalpha)
383+
digamma_alpha0 = SpecialFunctions.digamma(d.alpha0)
384+
∂lmnB = sum(Broadcast.instantiate(Broadcast.broadcasted(Δalpha, alpha) do Δalphai, alphai
385+
Δalphai * (SpecialFunctions.digamma(alphai) - digamma_alpha0)
386+
end))
387+
Δd = ChainRulesCore.Tangent{typeof(d)}(; alpha=Δalpha, alpha0=∂alpha0, lmnB=∂lmnB)
388+
return d, Δd
389+
end
390+
391+
function ChainRulesCore.rrule(::Type{DT}, alpha::AbstractVector{T}; check_args::Bool = true) where {T <: Real, DT <: Union{Dirichlet{T}, Dirichlet}}
392+
d = DT(alpha; check_args=check_args)
393+
digamma_alpha0 = SpecialFunctions.digamma(d.alpha0)
394+
function Dirichlet_pullback(_Δd)
395+
Δd = ChainRulesCore.unthunk(_Δd)
396+
Δalpha = Δd.alpha .+ Δd.alpha0 .+ Δd.lmnB .* (SpecialFunctions.digamma.(alpha) .- digamma_alpha0)
397+
return ChainRulesCore.NoTangent(), Δalpha
398+
end
399+
return d, Dirichlet_pullback
400+
end
401+
402+
function ChainRulesCore.frule((_, Δd, Δx)::Tuple{Any,Any,Any}, ::typeof(_logpdf), d::Dirichlet, x::AbstractVector{<:Real})
403+
Ω = _logpdf(d, x)
404+
∂alpha = sum(Broadcast.instantiate(Broadcast.broadcasted(Δd.alpha, Δx, d.alpha, x) do Δalphai, Δxi, alphai, xi
405+
xlogy(Δalphai, xi) + (alphai - 1) * Δxi / xi
406+
end))
407+
∂lmnB = -Δd.lmnB
408+
ΔΩ = ∂alpha + ∂lmnB
409+
if !isfinite(Ω)
410+
ΔΩ = oftype(ΔΩ, NaN)
411+
end
412+
return Ω, ΔΩ
413+
end
414+
415+
function ChainRulesCore.rrule(::typeof(_logpdf), d::T, x::AbstractVector{<:Real}) where {T<:Dirichlet}
416+
Ω = _logpdf(d, x)
417+
isfinite_Ω = isfinite(Ω)
418+
alpha = d.alpha
419+
function _logpdf_Dirichlet_pullback(_ΔΩ)
420+
ΔΩ = ChainRulesCore.unthunk(_ΔΩ)
421+
∂alpha = _logpdf_Dirichlet_∂alphai.(x, ΔΩ, isfinite_Ω)
422+
∂lmnB = isfinite_Ω ? -float(ΔΩ) : oftype(float(ΔΩ), NaN)
423+
Δd = ChainRulesCore.Tangent{T}(; alpha=∂alpha, lmnB=∂lmnB)
424+
Δx = _logpdf_Dirichlet_Δxi.(ΔΩ, alpha, x, isfinite_Ω)
425+
return ChainRulesCore.NoTangent(), Δd, Δx
426+
end
427+
return Ω, _logpdf_Dirichlet_pullback
428+
end
429+
function _logpdf_Dirichlet_∂alphai(xi, ΔΩi, isfinite::Bool)
430+
∂alphai = xlogy.(ΔΩi, xi)
431+
return isfinite ? ∂alphai : oftype(∂alphai, NaN)
432+
end
433+
function _logpdf_Dirichlet_Δxi(ΔΩi, alphai, xi, isfinite::Bool)
434+
Δxi = ΔΩi * (alphai - 1) / xi
435+
return isfinite ? Δxi : oftype(Δxi, NaN)
436+
end

test/dirichlet.jl

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22

33
using Distributions
44
using Test, Random, LinearAlgebra
5-
5+
using ChainRulesCore
6+
using ChainRulesTestUtils
7+
using FiniteDifferences
68

79
Random.seed!(34567)
810

@@ -127,3 +129,32 @@ end
127129
@test entropy(Dirichlet(N, 1)) -loggamma(N)
128130
@test entropy(Dirichlet(ones(N))) -loggamma(N)
129131
end
132+
133+
@testset "Dirichlet: ChainRules (length=$n)" for n in (2, 10)
134+
alpha = rand(n)
135+
d = Dirichlet(alpha)
136+
137+
@testset "constructor $T" for T in (Dirichlet, Dirichlet{Float64})
138+
# Avoid issues with finite differencing if values in `alpha` become negative or zero
139+
# by using forward differencing
140+
test_frule(T, alpha; fdm=forward_fdm(5, 1))
141+
test_rrule(T, alpha; fdm=forward_fdm(5, 1))
142+
end
143+
144+
@testset "_logpdf" begin
145+
# `x1` is in the support, `x2` isn't
146+
x1 = rand(n)
147+
x1 ./= sum(x1)
148+
x2 = x1 .+ 1
149+
150+
# Use special finite differencing method that tries to avoid moving outside of the
151+
# support by limiting the range of the points around the input that are evaluated
152+
fdm = central_fdm(5, 1; max_range=1e-9)
153+
154+
for x in (x1, x2)
155+
# We have to adjust the tolerance since the finite differencing method is rough
156+
test_frule(Distributions._logpdf, d, x; fdm=fdm, rtol=1e-5, nans=true)
157+
test_rrule(Distributions._logpdf, d, x; fdm=fdm, rtol=1e-5, nans=true)
158+
end
159+
end
160+
end

0 commit comments

Comments
 (0)