Skip to content

Differentiation Dirichlet #1534

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 30 commits into from
Jul 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
ae0a0ad
constructor frule
matbesancon Apr 16, 2022
d407bf2
frule tested
matbesancon Apr 16, 2022
d5a293a
rrule tests
matbesancon Apr 16, 2022
f9de7b3
logpdf test
matbesancon Apr 17, 2022
3723789
signature for conflict
matbesancon Apr 17, 2022
5e32f04
TestUtils out of Project
matbesancon Apr 24, 2022
0dde72f
ChainRules itself not needed (yet?)
matbesancon Apr 24, 2022
1348792
remove checkarg
matbesancon Apr 24, 2022
90455c8
Update src/multivariate/dirichlet.jl
matbesancon Apr 25, 2022
25a41f3
Update test/dirichlet.jl
matbesancon Apr 25, 2022
89a9346
Update test/dirichlet.jl
matbesancon Apr 25, 2022
96883e8
Update test/dirichlet.jl
matbesancon Apr 25, 2022
ab01122
Update src/multivariate/dirichlet.jl
matbesancon Apr 25, 2022
1d79fec
conflict
matbesancon Apr 25, 2022
bc29c40
Merge branch 'cr-dirichlet' of github.com:JuliaStats/Distributions.jl…
matbesancon Apr 25, 2022
4cc7509
eltype instability
matbesancon Apr 25, 2022
0500772
single loop
matbesancon Apr 25, 2022
d2f832b
fix tests
matbesancon Apr 25, 2022
77ccee6
forward finite diff
matbesancon Apr 25, 2022
feafacd
switch to broadcast
matbesancon Apr 25, 2022
1f06aa6
fix broadcast
matbesancon Apr 25, 2022
e702017
switch off-support value to NaN
matbesancon Apr 25, 2022
475a934
Update src/multivariate/dirichlet.jl
matbesancon Apr 29, 2022
7515e86
Update src/multivariate/dirichlet.jl
matbesancon Apr 29, 2022
1a3fdd9
do not assume inplace
matbesancon May 1, 2022
76cc96a
conflict
matbesancon May 1, 2022
cb4f07e
fixed temp
matbesancon May 23, 2022
9234155
Simplify implementation and tests in #1534 (#1555)
devmotion May 24, 2022
2c7100e
conflict
matbesancon May 24, 2022
7297260
Merge branch 'master' of github.com:JuliaStats/Distributions.jl into …
matbesancon Jul 31, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 60 additions & 1 deletion src/multivariate/dirichlet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ Base.show(io::IO, d::Dirichlet) = show(io, d, (:alpha,))
# Number of components: the length of the stored `alpha` vector.
length(d::Dirichlet) = length(d.alpha)
# Mean vector: every entry of `alpha` multiplied by the reciprocal of `d.alpha0`.
mean(d::Dirichlet) = d.alpha .* inv(d.alpha0)
# Parameters of the distribution, returned as a one-element tuple holding `alpha`.
params(d::Dirichlet) = (d.alpha,)
# Parameter element type, read off the type parameter `T` of `Dirichlet{T}`.
@inline partype(d::Dirichlet{T}) where {T<:Real} = T
# Same method signature as the line above, without binding the unused argument name.
@inline partype(::Dirichlet{T}) where {T<:Real} = T

function var(d::Dirichlet)
α0 = d.alpha0
Expand Down Expand Up @@ -375,3 +375,62 @@ function fit_mle(::Type{<:Dirichlet}, P::AbstractMatrix{Float64},
elogp = mean_logp(suffstats(Dirichlet, P, w))
fit_dirichlet!(elogp, α; maxiter=maxiter, tol=tol, debug=debug)
end

## Differentiation
# Forward-mode rule for the `Dirichlet` constructor.
#
# Builds the distribution from `alpha` and pushes the input tangent `Δalpha` through to
# the three fields named in the returned tangent:
#   * `alpha`  receives `Δalpha` unchanged,
#   * `alpha0` receives `sum(Δalpha)`,
#   * `lmnB`   receives `Σᵢ Δalphaᵢ * (digamma(alphaᵢ) - digamma(alpha0))`.
#
# Returns the primal distribution together with its structural `Tangent`.
function ChainRulesCore.frule(
    (_, Δalpha)::Tuple{Any,Any},
    ::Type{DT},
    alpha::AbstractVector{T};
    check_args::Bool = true
) where {T <: Real, DT <: Union{Dirichlet{T}, Dirichlet}}
    dist = DT(alpha; check_args=check_args)
    tangent_alpha0 = sum(Δalpha)
    ψ0 = SpecialFunctions.digamma(dist.alpha0)

    # Per-component contribution to the tangent of `lmnB`.
    lmnB_term(δ, a) = δ * (SpecialFunctions.digamma(a) - ψ0)

    # Keep the broadcast lazy so the reduction runs without materialising an array.
    lazy_terms = Broadcast.instantiate(Broadcast.broadcasted(lmnB_term, Δalpha, alpha))
    tangent_lmnB = sum(lazy_terms)

    tangent = ChainRulesCore.Tangent{typeof(dist)}(;
        alpha=Δalpha, alpha0=tangent_alpha0, lmnB=tangent_lmnB
    )
    return dist, tangent
end

# Reverse-mode rule for the `Dirichlet` constructor.
#
# Returns the constructed distribution and a pullback. Given a cotangent with fields
# `alpha`, `alpha0` and `lmnB`, the pullback yields, for every component `i`,
#   Δd.alpha[i] + Δd.alpha0 + Δd.lmnB * (digamma(alpha[i]) - digamma(alpha0)),
# preceded by `NoTangent()` for the type argument.
function ChainRulesCore.rrule(
    ::Type{DT},
    alpha::AbstractVector{T};
    check_args::Bool = true
) where {T <: Real, DT <: Union{Dirichlet{T}, Dirichlet}}
    dist = DT(alpha; check_args=check_args)
    # Evaluated once here so the pullback does not recompute it on every call.
    ψ0 = SpecialFunctions.digamma(dist.alpha0)

    function Dirichlet_pullback(_Δd)
        cot = ChainRulesCore.unthunk(_Δd)
        # Single fused pass over the components; scalar cotangents broadcast along.
        grad_alpha = broadcast(cot.alpha, cot.alpha0, cot.lmnB, alpha) do c_alpha, c_alpha0, c_lmnB, a
            c_alpha + c_alpha0 + c_lmnB * (SpecialFunctions.digamma(a) - ψ0)
        end
        return ChainRulesCore.NoTangent(), grad_alpha
    end

    return dist, Dirichlet_pullback
end

# Forward-mode rule for `_logpdf(d::Dirichlet, x)`.
#
# The output tangent is
#   Σᵢ [ xlogy(Δd.alpha[i], x[i]) + (d.alpha[i] - 1) * Δx[i] / x[i] ]  +  (-Δd.lmnB).
# When the primal value is not finite, the tangent is replaced by `NaN` of the same type.
#
# Returns the primal log-density and its tangent.
function ChainRulesCore.frule(
    (_, Δd, Δx)::Tuple{Any,Any,Any},
    ::typeof(_logpdf),
    d::Dirichlet,
    x::AbstractVector{<:Real}
)
    logp = _logpdf(d, x)

    # Contribution of one component to the directional derivative.
    component(δa, δx, a, xi) = xlogy(δa, xi) + (a - 1) * δx / xi

    # Lazy broadcast: reduce over the components without allocating an intermediate array.
    lazy_terms = Broadcast.instantiate(
        Broadcast.broadcasted(component, Δd.alpha, Δx, d.alpha, x)
    )
    from_alpha_and_x = sum(lazy_terms)
    from_lmnB = -Δd.lmnB
    raw = from_alpha_and_x + from_lmnB

    # A non-finite primal gets a NaN tangent of matching type.
    tangent = isfinite(logp) ? raw : oftype(raw, NaN)
    return logp, tangent
end

# Reverse-mode rule for `_logpdf(d::Dirichlet, x)`.
#
# Returns the primal log-density and a pullback. The pullback maps an output cotangent to
# `(NoTangent(), Δd, Δx)`, where `Δd` is a structural tangent carrying `alpha` and `lmnB`
# and `Δx` is the cotangent of `x`. Whether the primal is finite is decided once, up
# front; when it is not, every returned component is NaN (see the two element-wise helpers
# `_logpdf_Dirichlet_∂alphai` and `_logpdf_Dirichlet_Δxi`).
function ChainRulesCore.rrule(
    ::typeof(_logpdf),
    d::T,
    x::AbstractVector{<:Real}
) where {T<:Dirichlet}
    logp = _logpdf(d, x)
    primal_is_finite = isfinite(logp)
    conc = d.alpha

    function _logpdf_Dirichlet_pullback(_ΔΩ)
        cot = ChainRulesCore.unthunk(_ΔΩ)

        # Cotangents for the fields of `d`.
        grad_alpha = _logpdf_Dirichlet_∂alphai.(x, cot, primal_is_finite)
        cot_float = float(cot)
        grad_lmnB = if primal_is_finite
            -cot_float
        else
            oftype(cot_float, NaN)
        end
        grad_d = ChainRulesCore.Tangent{T}(; alpha=grad_alpha, lmnB=grad_lmnB)

        # Cotangent for the evaluation point.
        grad_x = _logpdf_Dirichlet_Δxi.(cot, conc, x, primal_is_finite)

        return ChainRulesCore.NoTangent(), grad_d, grad_x
    end

    return logp, _logpdf_Dirichlet_pullback
end
# Element-wise kernel for the `alpha` cotangent of `_logpdf`.
#
# Evaluates `xlogy.(ΔΩi, xi)`; when `primal_finite` is `false`, the value is replaced
# by `NaN` converted to the same type.
function _logpdf_Dirichlet_∂alphai(xi, ΔΩi, primal_finite::Bool)
    grad = xlogy.(ΔΩi, xi)
    if primal_finite
        return grad
    end
    return oftype(grad, NaN)
end
# Element-wise kernel for the `x` cotangent of `_logpdf`.
#
# Evaluates `ΔΩi * (alphai - 1) / xi`; when `primal_finite` is `false`, the value is
# replaced by `NaN` converted to the same type.
function _logpdf_Dirichlet_Δxi(ΔΩi, alphai, xi, primal_finite::Bool)
    grad = ΔΩi * (alphai - 1) / xi
    if primal_finite
        return grad
    end
    return oftype(grad, NaN)
end
33 changes: 32 additions & 1 deletion test/dirichlet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

using Distributions
using Test, Random, LinearAlgebra

using ChainRulesCore
using ChainRulesTestUtils
using FiniteDifferences

Random.seed!(34567)

Expand Down Expand Up @@ -127,3 +129,32 @@ end
@test entropy(Dirichlet(N, 1)) ≈ -loggamma(N)
@test entropy(Dirichlet(ones(N))) ≈ -loggamma(N)
end

@testset "Dirichlet: ChainRules (length=$n)" for n in (2, 10)
    alpha = rand(n)
    d = Dirichlet(alpha)

    @testset "constructor $T" for T in (Dirichlet, Dirichlet{Float64})
        # Forward differencing is used so that the finite-difference evaluations do not
        # run into entries of `alpha` that have become zero or negative.
        test_frule(T, alpha; fdm=forward_fdm(5, 1))
        test_rrule(T, alpha; fdm=forward_fdm(5, 1))
    end

    @testset "_logpdf" begin
        # `inside` is normalised to sum to one and lies in the support;
        # `outside` is the same point shifted by one and does not.
        raw = rand(n)
        inside = raw ./ sum(raw)
        outside = inside .+ 1

        # The finite-difference method is restricted to a very narrow range around the
        # input to try to avoid leaving the support. Because that makes the estimate
        # rough, the relative tolerance is loosened; NaN results are permitted.
        check_opts = (fdm=central_fdm(5, 1; max_range=1e-9), rtol=1e-5, nans=true)

        for x in (inside, outside)
            test_frule(Distributions._logpdf, d, x; check_opts...)
            test_rrule(Distributions._logpdf, d, x; check_opts...)
        end
    end
end