Dirichlet: simplify the ChainRules definitions and their tests.

src/multivariate/dirichlet.jl
  * Constructor frule: constrain the tangent tuple to two entries, drop the
    `unthunk` of the input tangent, and build the output tangent with the
    keyword form of `Tangent{typeof(d)}`.
  * Constructor rrule: compute `digamma(d.alpha0)` once, outside the pullback.
  * `_logpdf` frule: return NaN as the tangent whenever the primal is not finite.
  * `_logpdf` rrule: `unthunk` the incoming cotangent, no longer emit an `alpha0`
    component, and compute the elementwise partials in two scalar helpers.
    The helpers call `xlogy` on scalars directly (no broadcast) and name their
    flag `isfinite_Ω` so that `Base.isfinite` is not shadowed.

test/dirichlet.jl
  * Replace the hand-written finite-difference checks with
    `test_frule`/`test_rrule`, covering one point inside and one point outside
    the support (`nans=true`).

diff --git a/src/multivariate/dirichlet.jl b/src/multivariate/dirichlet.jl
index 77ff5dd0b0..d77d4f5d0d 100644
--- a/src/multivariate/dirichlet.jl
+++ b/src/multivariate/dirichlet.jl
@@ -377,58 +377,60 @@ function fit_mle(::Type{<:Dirichlet}, P::AbstractMatrix{Float64},
 end
 
 ## Differentiation
-function ChainRulesCore.frule((_, Δalpha), ::Type{DT}, alpha::AbstractVector{T}; check_args::Bool = true) where {T <: Real, DT <: Union{Dirichlet{T}, Dirichlet}}
+function ChainRulesCore.frule((_, Δalpha)::Tuple{Any,Any}, ::Type{DT}, alpha::AbstractVector{T}; check_args::Bool = true) where {T <: Real, DT <: Union{Dirichlet{T}, Dirichlet}}
     d = DT(alpha; check_args=check_args)
-    Δalpha = ChainRulesCore.unthunk(Δalpha)
     ∂alpha0 = sum(Δalpha)
     digamma_alpha0 = SpecialFunctions.digamma(d.alpha0)
-    ∂lmnB = sum(Broadcast.instantiate(Broadcast.broadcasted(Δalpha, alpha) do Δalpha_i, alpha_i
-        Δalpha_i * (SpecialFunctions.digamma(alpha_i) - digamma_alpha0)
+    ∂lmnB = sum(Broadcast.instantiate(Broadcast.broadcasted(Δalpha, alpha) do Δalphai, alphai
+        Δalphai * (SpecialFunctions.digamma(alphai) - digamma_alpha0)
     end))
-    backing = (alpha=Δalpha, alpha0=∂alpha0, lmnB=∂lmnB)
-    t = ChainRulesCore.Tangent{typeof(d), NamedTuple{(:alpha, :alpha0, :lmnB), Tuple{typeof(alpha), typeof(d.alpha0), typeof(d.lmnB)}}}(backing)
-    return d, t
+    Δd = ChainRulesCore.Tangent{typeof(d)}(; alpha=Δalpha, alpha0=∂alpha0, lmnB=∂lmnB)
+    return d, Δd
 end
 
 function ChainRulesCore.rrule(::Type{DT}, alpha::AbstractVector{T}; check_args::Bool = true) where {T <: Real, DT <: Union{Dirichlet{T}, Dirichlet}}
     d = DT(alpha; check_args=check_args)
-    function dirichlet_pullback(d_dir)
-        d_dir = ChainRulesCore.unthunk(d_dir)
-        digamma_alpha0 = SpecialFunctions.digamma(d.alpha0)
-        dalpha = d_dir.alpha .+ d_dir.alpha0 .+ d_dir.lmnB .* (SpecialFunctions.digamma.(alpha) .- digamma_alpha0)
-        return ChainRulesCore.NoTangent(), dalpha
+    digamma_alpha0 = SpecialFunctions.digamma(d.alpha0)
+    function Dirichlet_pullback(_Δd)
+        Δd = ChainRulesCore.unthunk(_Δd)
+        Δalpha = Δd.alpha .+ Δd.alpha0 .+ Δd.lmnB .* (SpecialFunctions.digamma.(alpha) .- digamma_alpha0)
+        return ChainRulesCore.NoTangent(), Δalpha
     end
-    return d, dirichlet_pullback
+    return d, Dirichlet_pullback
 end
 
-function ChainRulesCore.frule((_, Δd, Δx), ::typeof(_logpdf), d::Dirichlet, x::AbstractVector{<:Real})
-    lp = _logpdf(d, x)
-    ∂α_x = sum(Broadcast.instantiate(Broadcast.broadcasted(Δd.alpha, Δx, d.alpha, x) do Δalpha_i, Δx_i, alpha_i, x_i
-        xlogy(Δalpha_i, x_i) + (alpha_i - 1) * Δx_i / x_i
+function ChainRulesCore.frule((_, Δd, Δx)::Tuple{Any,Any,Any}, ::typeof(_logpdf), d::Dirichlet, x::AbstractVector{<:Real})
+    Ω = _logpdf(d, x)
+    ∂alpha = sum(Broadcast.instantiate(Broadcast.broadcasted(Δd.alpha, Δx, d.alpha, x) do Δalphai, Δxi, alphai, xi
+        xlogy(Δalphai, xi) + (alphai - 1) * Δxi / xi
     end))
-    ∂l = -Δd.lmnB
-    if !insupport(d, x)
-        ∂α_x = oftype(∂α_x, NaN)
+    ∂lmnB = -Δd.lmnB
+    ΔΩ = ∂alpha + ∂lmnB
+    if !isfinite(Ω)
+        ΔΩ = oftype(ΔΩ, NaN)
     end
-    return (lp, ∂α_x + ∂l)
+    return Ω, ΔΩ
 end
 
-function ChainRulesCore.rrule(::typeof(_logpdf), d::Dirichlet, x::AbstractVector{<:Real})
-    y = _logpdf(d, x)
-    function Dirichlet_logpdf_pullback(dy)
-        ∂alpha = xlogy.(dy, x)
-        ∂l = -dy
-        ∂x = dy .* (d.alpha .-1) ./ x
-        ∂alpha0 = sum(∂alpha)
-        if !isfinite(y)
-            ∂alpha = oftype(eltype(∂alpha), NaN) * ∂alpha
-            ∂l = oftype(∂l, NaN)
-            ∂x = oftype(eltype(∂x), NaN) * ∂x
-            ∂alpha0 = oftype(eltype(∂alpha), NaN)
-        end
-        backing = (alpha = ∂alpha, alpha0 = ∂alpha0, lmnB=∂l)
-        ∂d = ChainRulesCore.Tangent{typeof(d), typeof(backing)}(backing)
-        return (ChainRulesCore.NoTangent(), ∂d, ∂x)
+function ChainRulesCore.rrule(::typeof(_logpdf), d::T, x::AbstractVector{<:Real}) where {T<:Dirichlet}
+    Ω = _logpdf(d, x)
+    isfinite_Ω = isfinite(Ω)
+    alpha = d.alpha
+    function _logpdf_Dirichlet_pullback(_ΔΩ)
+        ΔΩ = ChainRulesCore.unthunk(_ΔΩ)
+        ∂alpha = _logpdf_Dirichlet_∂alphai.(x, ΔΩ, isfinite_Ω)
+        ∂lmnB = isfinite_Ω ? -float(ΔΩ) : oftype(float(ΔΩ), NaN)
+        Δd = ChainRulesCore.Tangent{T}(; alpha=∂alpha, lmnB=∂lmnB)
+        Δx = _logpdf_Dirichlet_Δxi.(ΔΩ, alpha, x, isfinite_Ω)
+        return ChainRulesCore.NoTangent(), Δd, Δx
     end
-    return (y, Dirichlet_logpdf_pullback)
+    return Ω, _logpdf_Dirichlet_pullback
+end
+function _logpdf_Dirichlet_∂alphai(xi, ΔΩi, isfinite_Ω::Bool)
+    ∂alphai = xlogy(ΔΩi, xi)
+    return isfinite_Ω ? ∂alphai : oftype(∂alphai, NaN)
+end
+function _logpdf_Dirichlet_Δxi(ΔΩi, alphai, xi, isfinite_Ω::Bool)
+    Δxi = ΔΩi * (alphai - 1) / xi
+    return isfinite_Ω ? Δxi : oftype(Δxi, NaN)
 end
diff --git a/test/dirichlet.jl b/test/dirichlet.jl
index 98825a7ba1..78de162dca 100644
--- a/test/dirichlet.jl
+++ b/test/dirichlet.jl
@@ -130,52 +130,31 @@ end
     @test entropy(Dirichlet(ones(N))) ≈ -loggamma(N)
 end
 
-@testset "Dirichlet differentiation $n" for n in (2, 10)
+@testset "Dirichlet: ChainRules (length=$n)" for n in (2, 10)
     alpha = rand(n)
-    Δalpha = randn(n)
-    d, ∂d = @inferred ChainRulesCore.frule((nothing, Δalpha), Dirichlet, alpha)
-    ChainRulesTestUtils.test_frule(Dirichlet ⊢ ChainRulesCore.NoTangent(), alpha ⊢ Δalpha; fdm=FiniteDifferences.forward_fdm(5, 1))
-    ChainRulesTestUtils.test_rrule(Dirichlet{Float64}, alpha; fdm=FiniteDifferences.forward_fdm(5, 1))
-    x = rand(n)
-    x ./= sum(x)
-    Δx = 0.05 * rand(n)
-    Δx .-= mean(Δx)
-    # such that x ∈ Δ, x + Δx ∈ Δ
-    ChainRulesTestUtils.test_frule(Distributions._logpdf, d, x ⊢ Δx, fdm=FiniteDifferences.forward_fdm(5, 1))
-    @testset "finite diff f/r-rule logpdf" begin
-        for _ in 1:10
-            x = rand(n)
-            x ./= sum(x)
-            Δx = 0.005 * rand(n)
-            Δx .-= mean(Δx)
-            if insupport(d, x + Δx) && insupport(d, x - Δx)
-                y, pullback = ChainRulesCore.rrule(Distributions._logpdf, d, x)
-                yf, Δy = ChainRulesCore.frule(
-                    (
-                        ChainRulesCore.NoTangent(),
-                        map(zero, ChainRulesTestUtils.rand_tangent(d)),
-                        Δx,
-                    ),
-                    Distributions._logpdf,
-                    d, x,
-                )
-                y2 = Distributions._logpdf(d, x + Δx)
-                y1 = Distributions._logpdf(d, x - Δx)
-                @test isfinite(y)
-                @test y == yf
-                @test Δy ≈ y2 - y atol=5e-3
-                _, ∂d, ∂x = pullback(1.0)
-                @test y2 - y1 ≈ dot(2Δx, ∂x) atol=5e-3 rtol=1e-6
-                # mutating alpha only to compute a new y, changing only this term and not the others in Dirichlet
-                Δalpha = 0.03 * rand(n)
-                Δalpha .-= mean(Δalpha)
-                @assert all(>=(0), alpha + Δalpha)
-                d.alpha .+= Δalpha
-                ya = Distributions._logpdf(d, x)
-                # resetting alpha
-                d.alpha .-= Δalpha
-                @test ya - y ≈ dot(Δalpha, ∂d.alpha) atol=1e-6 rtol=1e-6
-            end
+    d = Dirichlet(alpha)
+
+    @testset "constructor $T" for T in (Dirichlet, Dirichlet{Float64})
+        # Avoid issues with finite differencing if values in `alpha` become negative or zero
+        # by using forward differencing
+        test_frule(T, alpha; fdm=forward_fdm(5, 1))
+        test_rrule(T, alpha; fdm=forward_fdm(5, 1))
+    end
+
+    @testset "_logpdf" begin
+        # `x1` is in the support, `x2` isn't
+        x1 = rand(n)
+        x1 ./= sum(x1)
+        x2 = x1 .+ 1
+
+        # Use special finite differencing method that tries to avoid moving outside of the
+        # support by limiting the range of the points around the input that are evaluated
+        fdm = central_fdm(5, 1; max_range=1e-9)
+
+        for x in (x1, x2)
+            # We have to adjust the tolerance since the finite differencing method is rough
+            test_frule(Distributions._logpdf, d, x; fdm=fdm, rtol=1e-5, nans=true)
+            test_rrule(Distributions._logpdf, d, x; fdm=fdm, rtol=1e-5, nans=true)
         end
     end
 end