Skip to content

Differentiation Dirichlet #1534

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 30 commits into from
Jul 31, 2022
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
ae0a0ad
constructor frule
matbesancon Apr 16, 2022
d407bf2
frule tested
matbesancon Apr 16, 2022
d5a293a
rrule tests
matbesancon Apr 16, 2022
f9de7b3
logpdf test
matbesancon Apr 17, 2022
3723789
signature for conflict
matbesancon Apr 17, 2022
5e32f04
TestUtils out of Project
matbesancon Apr 24, 2022
0dde72f
ChainRules itself not needed (yet?)
matbesancon Apr 24, 2022
1348792
remove checkarg
matbesancon Apr 24, 2022
90455c8
Update src/multivariate/dirichlet.jl
matbesancon Apr 25, 2022
25a41f3
Update test/dirichlet.jl
matbesancon Apr 25, 2022
89a9346
Update test/dirichlet.jl
matbesancon Apr 25, 2022
96883e8
Update test/dirichlet.jl
matbesancon Apr 25, 2022
ab01122
Update src/multivariate/dirichlet.jl
matbesancon Apr 25, 2022
1d79fec
conflict
matbesancon Apr 25, 2022
bc29c40
Merge branch 'cr-dirichlet' of github.com:JuliaStats/Distributions.jl…
matbesancon Apr 25, 2022
4cc7509
eltype instability
matbesancon Apr 25, 2022
0500772
single loop
matbesancon Apr 25, 2022
d2f832b
fix tests
matbesancon Apr 25, 2022
77ccee6
forward finite diff
matbesancon Apr 25, 2022
feafacd
switch to broadcast
matbesancon Apr 25, 2022
1f06aa6
fix broadcast
matbesancon Apr 25, 2022
e702017
switch off-support value to NaN
matbesancon Apr 25, 2022
475a934
Update src/multivariate/dirichlet.jl
matbesancon Apr 29, 2022
7515e86
Update src/multivariate/dirichlet.jl
matbesancon Apr 29, 2022
1a3fdd9
do not assume inplace
matbesancon May 1, 2022
76cc96a
conflict
matbesancon May 1, 2022
cb4f07e
fixed temp
matbesancon May 23, 2022
9234155
Simplify implementation and tests in #1534 (#1555)
devmotion May 24, 2022
2c7100e
conflict
matbesancon May 24, 2022
7297260
Merge branch 'master' of github.com:JuliaStats/Distributions.jl into …
matbesancon Jul 31, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 58 additions & 1 deletion src/multivariate/dirichlet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ Base.show(io::IO, d::Dirichlet) = show(io, d, (:alpha,))
# Length of the distribution: the length of its `alpha` parameter vector.
length(d::Dirichlet) = length(d.alpha)
# Mean vector: every entry of `alpha` multiplied by the reciprocal of `alpha0`.
mean(d::Dirichlet) = d.alpha .* inv(d.alpha0)
# Parameters returned as a one-element tuple holding `alpha`.
params(d::Dirichlet) = (d.alpha,)
@inline partype(d::Dirichlet{T}) where {T<:Real} = T
@inline partype(::Dirichlet{T}) where {T<:Real} = T

function var(d::Dirichlet)
α0 = d.alpha0
Expand Down Expand Up @@ -375,3 +375,60 @@ function fit_mle(::Type{<:Dirichlet}, P::AbstractMatrix{Float64},
elogp = mean_logp(suffstats(Dirichlet, P, w))
fit_dirichlet!(elogp, α; maxiter=maxiter, tol=tol, debug=debug)
end

## Differentiation
# Forward-mode rule for constructing a `Dirichlet` from its parameter vector.
#
# Arguments:
#   (_, Δalpha) : tangents of the constructor type (ignored) and of `alpha`.
#   DT          : `Dirichlet` or `Dirichlet{T}`, forwarded to the real constructor.
#   alpha       : primal parameter vector.
#   check_args  : forwarded unchanged to the constructor.
#
# Returns `(d, Δd)` where `d` is the constructed distribution and `Δd` is a
# structural `Tangent` with fields `alpha`, `alpha0` and `lmnB`.
function ChainRulesCore.frule((_, Δalpha), ::Type{DT}, alpha::AbstractVector{T}; check_args::Bool = true) where {T <: Real, DT <: Union{Dirichlet{T}, Dirichlet}}
    d = DT(alpha; check_args=check_args)
    Δα = ChainRulesCore.unthunk(Δalpha)
    # Tangent of `alpha0`; presumably `alpha0 == sum(alpha)` (constructor not shown here).
    ∂alpha0 = sum(Δα)
    digamma_alpha0 = SpecialFunctions.digamma(d.alpha0)
    # Directional derivative of `sum(loggamma, alpha) - loggamma(alpha0)` along Δα,
    # accumulated lazily so that no intermediate array is allocated.
    ∂lmnB = sum(Broadcast.instantiate(Broadcast.broadcasted(Δα, alpha) do Δalpha_i, alpha_i
        Δalpha_i * (SpecialFunctions.digamma(alpha_i) - digamma_alpha0)
    end))
    backing = (alpha=Δα, alpha0=∂alpha0, lmnB=∂lmnB)
    # Type the tangent by what it actually stores rather than by the primal field
    # types: forcing tangents into the primal types can fail (e.g. integer-valued
    # `alpha` with a floating-point `Δalpha`).
    t = ChainRulesCore.Tangent{typeof(d), typeof(backing)}(backing)
    return d, t
end

# Reverse-mode rule for constructing a `Dirichlet` from its parameter vector.
#
# Returns the constructed distribution together with a pullback that maps a
# structural cotangent (fields `alpha`, `alpha0`, `lmnB`) to
# `(NoTangent(), cotangent of alpha)`.
function ChainRulesCore.rrule(::Type{DT}, alpha::AbstractVector{T}; check_args::Bool = true) where {T <: Real, DT <: Union{Dirichlet{T}, Dirichlet}}
    dist = DT(alpha; check_args=check_args)
    function Dirichlet_pullback(d_dir)
        Δdist = ChainRulesCore.unthunk(d_dir)
        ψ_alpha0 = SpecialFunctions.digamma(dist.alpha0)
        # Sensitivity of `lmnB` to each entry of `alpha`.
        ∂lmnB_∂alpha = SpecialFunctions.digamma.(alpha) .- ψ_alpha0
        # Sum the contributions flowing back through the three fields.
        Δalpha = Δdist.alpha .+ Δdist.alpha0 .+ Δdist.lmnB .* ∂lmnB_∂alpha
        return ChainRulesCore.NoTangent(), Δalpha
    end
    return dist, Dirichlet_pullback
end

# Forward-mode rule for `_logpdf(d::Dirichlet, x)`.
#
# Arguments:
#   (_, Δd, Δx) : tangents of the function (ignored), of the distribution
#                 (only the `alpha` and `lmnB` fields are read) and of `x`.
#   d, x        : primal distribution and evaluation point.
#
# Returns `(lp, Δlp)`: the log-density and its directional derivative. `Δlp` is
# NaN whenever the log-density is not finite or `x` lies outside the support.
function ChainRulesCore.frule((_, Δd, Δx), ::typeof(_logpdf), d::Dirichlet, x::AbstractVector{<:Real})
    lp = _logpdf(d, x)
    # Σ_i [ Δα_i * log(x_i) + (α_i - 1) * Δx_i / x_i ], accumulated lazily.
    ∂α_x = sum(Broadcast.instantiate(Broadcast.broadcasted(Δd.alpha, Δx, d.alpha, x) do Δalpha_i, Δx_i, alpha_i, x_i
        xlogy(Δalpha_i, x_i) + (alpha_i - 1) * Δx_i / x_i
    end))
    # `lmnB` contributes with a negative sign.
    ∂l = -Δd.lmnB
    Δlp = ∂α_x + ∂l
    # Invalidate the whole tangent (not just one summand). The `isfinite` test
    # mirrors the guard used by the matching rrule; the `insupport` test keeps
    # the original off-support behaviour.
    if !(isfinite(lp) && insupport(d, x))
        Δlp = oftype(Δlp, NaN)
    end
    return (lp, Δlp)
end

# Reverse-mode rule for `_logpdf(d::Dirichlet, x)`.
#
# Returns `(y, pullback)`. The pullback maps the output cotangent `dy` to
# `(NoTangent(), ∂d, ∂x)`, where `∂d` is a structural `Tangent` with fields
# `alpha`, `alpha0` and `lmnB`. When the primal `y` is not finite every
# cotangent is NaN.
function ChainRulesCore.rrule(::typeof(_logpdf), d::Dirichlet, x::AbstractVector{<:Real})
    y = _logpdf(d, x)
    function Dirichlet_logpdf_pullback(dy)
        # The incoming cotangent may be lazy; force it before doing arithmetic.
        Δy = ChainRulesCore.unthunk(dy)
        ∂alpha = xlogy.(Δy, x)
        # `float` guarantees that NaN is representable even for an integer `dy`.
        ∂l = -float(Δy)
        # Fully fused broadcast: one pass, one allocation.
        ∂x = Δy .* (d.alpha .- 1) ./ x
        # `alpha0` does not enter the pullback directly; its cotangent is zero.
        ∂alpha0 = 0.0
        if !isfinite(y)
            ∂alpha .= NaN
            ∂l = oftype(∂l, NaN)
            ∂x .= NaN
            ∂alpha0 = NaN
        end
        backing = (alpha = ∂alpha, alpha0 = ∂alpha0, lmnB = ∂l)
        ∂d = ChainRulesCore.Tangent{typeof(d), typeof(backing)}(backing)
        return (ChainRulesCore.NoTangent(), ∂d, ∂x)
    end
    return (y, Dirichlet_logpdf_pullback)
end
54 changes: 53 additions & 1 deletion test/dirichlet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

using Distributions
using Test, Random, LinearAlgebra

using ChainRulesCore
using ChainRulesTestUtils
using FiniteDifferences

Random.seed!(34567)

Expand Down Expand Up @@ -127,3 +129,53 @@ end
@test entropy(Dirichlet(N, 1)) ≈ -loggamma(N)
@test entropy(Dirichlet(ones(N))) ≈ -loggamma(N)
end

# Checks the ChainRules forward and reverse rules of the `Dirichlet` constructor
# and of `_logpdf`, both through ChainRulesTestUtils and through explicit
# finite-difference comparisons.
@testset "Dirichlet differentiation $n" for n in (2, 10)
    alpha = rand(n)
    Δalpha = randn(n)
    d, ∂d = @inferred ChainRulesCore.frule((ChainRulesCore.NoTangent(), Δalpha), Dirichlet, alpha)
    ChainRulesTestUtils.test_frule(Dirichlet ⊢ ChainRulesCore.NoTangent(), alpha ⊢ Δalpha; fdm=FiniteDifferences.forward_fdm(5, 1))
    ChainRulesTestUtils.test_rrule(Dirichlet{Float64}, alpha; fdm=FiniteDifferences.forward_fdm(5, 1))
    x = rand(n)
    x ./= sum(x)
    Δx = 0.05 * rand(n)
    Δx .-= mean(Δx)
    # such that x ∈ Δ, x + Δx ∈ Δ
    ChainRulesTestUtils.test_frule(Distributions._logpdf, d, x ⊢ Δx)
    @testset "finite diff f/r-rule logpdf" begin
        for _ in 1:10
            x = rand(n)
            x ./= sum(x)
            Δx = 0.005 * rand(n)
            Δx .-= mean(Δx)
            if insupport(d, x + Δx) && insupport(d, x - Δx)
                y, pullback = ChainRulesCore.rrule(Distributions._logpdf, d, x)
                yf, Δy = ChainRulesCore.frule(
                    (
                        ChainRulesCore.NoTangent(),
                        map(zero, ChainRulesTestUtils.rand_tangent(d)),
                        Δx,
                    ),
                    Distributions._logpdf,
                    d, x,
                )
                y2 = Distributions._logpdf(d, x + Δx)
                y1 = Distributions._logpdf(d, x - Δx)
                @test isfinite(y)
                # Both primals come from the same `_logpdf(d, x)` call, so exact
                # equality is intended here.
                @test y == yf
                @test Δy ≈ y2 - y atol=5e-3
                _, ∂d, ∂x = pullback(1.0)
                @test y2 - y1 ≈ dot(2Δx, ∂x) atol=5e-3 rtol=1e-6
                # mutating alpha only to compute a new y, changing only this term and not the others in Dirichlet
                Δalpha = 0.03 * rand(n)
                Δalpha .-= mean(Δalpha)
                # Report a violated precondition as a test failure rather than
                # through `@assert`.
                @test all(>=(0), alpha + Δalpha)
                alpha_saved = copy(d.alpha)
                d.alpha .+= Δalpha
                ya = Distributions._logpdf(d, x)
                # resetting alpha from the saved copy: `(a + Δ) - Δ` is not exact
                # in floating point and would drift over the iterations
                d.alpha .= alpha_saved
                @test ya - y ≈ dot(Δalpha, ∂d.alpha) atol=5e-5 rtol=1e-6
            end
        end
    end
end