Skip to content

Simplify implementation and tests in #1534 #1555

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 40 additions & 38 deletions src/multivariate/dirichlet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -377,58 +377,60 @@ function fit_mle(::Type{<:Dirichlet}, P::AbstractMatrix{Float64},
end

## Differentiation
function ChainRulesCore.frule((_, Δalpha), ::Type{DT}, alpha::AbstractVector{T}; check_args::Bool = true) where {T <: Real, DT <: Union{Dirichlet{T}, Dirichlet}}
function ChainRulesCore.frule((_, Δalpha)::Tuple{Any,Any}, ::Type{DT}, alpha::AbstractVector{T}; check_args::Bool = true) where {T <: Real, DT <: Union{Dirichlet{T}, Dirichlet}}
d = DT(alpha; check_args=check_args)
Δalpha = ChainRulesCore.unthunk(Δalpha)
∂alpha0 = sum(Δalpha)
digamma_alpha0 = SpecialFunctions.digamma(d.alpha0)
∂lmnB = sum(Broadcast.instantiate(Broadcast.broadcasted(Δalpha, alpha) do Δalpha_i, alpha_i
Δalpha_i * (SpecialFunctions.digamma(alpha_i) - digamma_alpha0)
∂lmnB = sum(Broadcast.instantiate(Broadcast.broadcasted(Δalpha, alpha) do Δalphai, alphai
Δalphai * (SpecialFunctions.digamma(alphai) - digamma_alpha0)
end))
backing = (alpha=Δalpha, alpha0=∂alpha0, lmnB=∂lmnB)
t = ChainRulesCore.Tangent{typeof(d), NamedTuple{(:alpha, :alpha0, :lmnB), Tuple{typeof(alpha), typeof(d.alpha0), typeof(d.lmnB)}}}(backing)
return d, t
Δd = ChainRulesCore.Tangent{typeof(d)}(; alpha=Δalpha, alpha0=∂alpha0, lmnB=∂lmnB)
return d, Δd
end

function ChainRulesCore.rrule(::Type{DT}, alpha::AbstractVector{T}; check_args::Bool = true) where {T <: Real, DT <: Union{Dirichlet{T}, Dirichlet}}
d = DT(alpha; check_args=check_args)
function dirichlet_pullback(d_dir)
d_dir = ChainRulesCore.unthunk(d_dir)
digamma_alpha0 = SpecialFunctions.digamma(d.alpha0)
dalpha = d_dir.alpha .+ d_dir.alpha0 .+ d_dir.lmnB .* (SpecialFunctions.digamma.(alpha) .- digamma_alpha0)
return ChainRulesCore.NoTangent(), dalpha
digamma_alpha0 = SpecialFunctions.digamma(d.alpha0)
function Dirichlet_pullback(_Δd)
Δd = ChainRulesCore.unthunk(_Δd)
Δalpha = Δd.alpha .+ Δd.alpha0 .+ Δd.lmnB .* (SpecialFunctions.digamma.(alpha) .- digamma_alpha0)
return ChainRulesCore.NoTangent(), Δalpha
end
return d, dirichlet_pullback
return d, Dirichlet_pullback
end

function ChainRulesCore.frule((_, Δd, Δx), ::typeof(_logpdf), d::Dirichlet, x::AbstractVector{<:Real})
lp = _logpdf(d, x)
α_x = sum(Broadcast.instantiate(Broadcast.broadcasted(Δd.alpha, Δx, d.alpha, x) do Δalpha_i, Δx_i, alpha_i, x_i
xlogy(Δalpha_i, x_i) + (alpha_i - 1) * Δx_i / x_i
function ChainRulesCore.frule((_, Δd, Δx)::Tuple{Any,Any,Any}, ::typeof(_logpdf), d::Dirichlet, x::AbstractVector{<:Real})
Ω = _logpdf(d, x)
alpha = sum(Broadcast.instantiate(Broadcast.broadcasted(Δd.alpha, Δx, d.alpha, x) do Δalphai, Δxi, alphai, xi
xlogy(Δalphai, xi) + (alphai - 1) * Δxi / xi
end))
∂l = -Δd.lmnB
if !insupport(d, x)
∂α_x = oftype(∂α_x, NaN)
∂lmnB = -Δd.lmnB
ΔΩ = ∂alpha + ∂lmnB
if !isfinite(Ω)
ΔΩ = oftype(ΔΩ, NaN)
end
return (lp, ∂α_x + ∂l)
return Ω, ΔΩ
end

function ChainRulesCore.rrule(::typeof(_logpdf), d::Dirichlet, x::AbstractVector{<:Real})
y = _logpdf(d, x)
function Dirichlet_logpdf_pullback(dy)
∂alpha = xlogy.(dy, x)
∂l = -dy
∂x = dy .* (d.alpha .-1) ./ x
∂alpha0 = sum(∂alpha)
if !isfinite(y)
∂alpha = oftype(eltype(∂alpha), NaN) * ∂alpha
∂l = oftype(∂l, NaN)
∂x = oftype(eltype(∂x), NaN) * ∂x
∂alpha0 = oftype(eltype(∂alpha), NaN)
end
backing = (alpha = ∂alpha, alpha0 = ∂alpha0, lmnB=∂l)
∂d = ChainRulesCore.Tangent{typeof(d), typeof(backing)}(backing)
return (ChainRulesCore.NoTangent(), ∂d, ∂x)
function ChainRulesCore.rrule(::typeof(_logpdf), d::T, x::AbstractVector{<:Real}) where {T<:Dirichlet}
Ω = _logpdf(d, x)
isfinite_Ω = isfinite(Ω)
alpha = d.alpha
function _logpdf_Dirichlet_pullback(_ΔΩ)
ΔΩ = ChainRulesCore.unthunk(_ΔΩ)
∂alpha = _logpdf_Dirichlet_∂alphai.(x, ΔΩ, isfinite_Ω)
∂lmnB = isfinite_Ω ? -float(ΔΩ) : oftype(float(ΔΩ), NaN)
Δd = ChainRulesCore.Tangent{T}(; alpha=∂alpha, lmnB=∂lmnB)
Δx = _logpdf_Dirichlet_Δxi.(ΔΩ, alpha, x, isfinite_Ω)
return ChainRulesCore.NoTangent(), Δd, Δx
end
return (y, Dirichlet_logpdf_pullback)
return Ω, _logpdf_Dirichlet_pullback
end
function _logpdf_Dirichlet_∂alphai(xi, ΔΩi, isfinite::Bool)
∂alphai = xlogy.(ΔΩi, xi)
return isfinite ? ∂alphai : oftype(∂alphai, NaN)
end
function _logpdf_Dirichlet_Δxi(ΔΩi, alphai, xi, isfinite::Bool)
Δxi = ΔΩi * (alphai - 1) / xi
return isfinite ? Δxi : oftype(Δxi, NaN)
end
69 changes: 24 additions & 45 deletions test/dirichlet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,52 +130,31 @@ end
@test entropy(Dirichlet(ones(N))) ≈ -loggamma(N)
end

@testset "Dirichlet differentiation $n" for n in (2, 10)
@testset "Dirichlet: ChainRules (length=$n)" for n in (2, 10)
alpha = rand(n)
Δalpha = randn(n)
d, ∂d = @inferred ChainRulesCore.frule((nothing, Δalpha), Dirichlet, alpha)
ChainRulesTestUtils.test_frule(Dirichlet ⊢ ChainRulesCore.NoTangent(), alpha ⊢ Δalpha; fdm=FiniteDifferences.forward_fdm(5, 1))
ChainRulesTestUtils.test_rrule(Dirichlet{Float64}, alpha; fdm=FiniteDifferences.forward_fdm(5, 1))
x = rand(n)
x ./= sum(x)
Δx = 0.05 * rand(n)
Δx .-= mean(Δx)
# such that x ∈ Δ, x + Δx ∈ Δ
ChainRulesTestUtils.test_frule(Distributions._logpdf, d, x ⊢ Δx, fdm=FiniteDifferences.forward_fdm(5, 1))
@testset "finite diff f/r-rule logpdf" begin
for _ in 1:10
x = rand(n)
x ./= sum(x)
Δx = 0.005 * rand(n)
Δx .-= mean(Δx)
if insupport(d, x + Δx) && insupport(d, x - Δx)
y, pullback = ChainRulesCore.rrule(Distributions._logpdf, d, x)
yf, Δy = ChainRulesCore.frule(
(
ChainRulesCore.NoTangent(),
map(zero, ChainRulesTestUtils.rand_tangent(d)),
Δx,
),
Distributions._logpdf,
d, x,
)
y2 = Distributions._logpdf(d, x + Δx)
y1 = Distributions._logpdf(d, x - Δx)
@test isfinite(y)
@test y == yf
@test Δy ≈ y2 - y atol=5e-3
_, ∂d, ∂x = pullback(1.0)
@test y2 - y1 ≈ dot(2Δx, ∂x) atol=5e-3 rtol=1e-6
# mutating alpha only to compute a new y, changing only this term and not the others in Dirichlet
Δalpha = 0.03 * rand(n)
Δalpha .-= mean(Δalpha)
@assert all(>=(0), alpha + Δalpha)
d.alpha .+= Δalpha
ya = Distributions._logpdf(d, x)
# resetting alpha
d.alpha .-= Δalpha
@test ya - y ≈ dot(Δalpha, ∂d.alpha) atol=1e-6 rtol=1e-6
end
d = Dirichlet(alpha)

@testset "constructor $T" for T in (Dirichlet, Dirichlet{Float64})
# Avoid issues with finite differencing if values in `alpha` become negative or zero
# by using forward differencing
test_frule(T, alpha; fdm=forward_fdm(5, 1))
test_rrule(T, alpha; fdm=forward_fdm(5, 1))
end

@testset "_logpdf" begin
# `x1` is in the support, `x2` isn't
x1 = rand(n)
x1 ./= sum(x1)
x2 = x1 .+ 1

# Use special finite differencing method that tries to avoid moving outside of the
# support by limiting the range of the points around the input that are evaluated
fdm = central_fdm(5, 1; max_range=1e-9)

for x in (x1, x2)
# We have to adjust the tolerance since the finite differencing method is rough
test_frule(Distributions._logpdf, d, x; fdm=fdm, rtol=1e-5, nans=true)
test_rrule(Distributions._logpdf, d, x; fdm=fdm, rtol=1e-5, nans=true)
end
end
end