@@ -72,7 +72,7 @@ Base.show(io::IO, d::Dirichlet) = show(io, d, (:alpha,))
72
72
length (d:: Dirichlet ) = length (d. alpha)
73
73
mean (d:: Dirichlet ) = d. alpha .* inv (d. alpha0)
74
74
params (d:: Dirichlet ) = (d. alpha,)
75
- @inline partype (d :: Dirichlet{T} ) where {T<: Real } = T
75
+ @inline partype (:: Dirichlet{T} ) where {T<: Real } = T
76
76
77
77
function var (d:: Dirichlet )
78
78
α0 = d. alpha0
@@ -375,3 +375,62 @@ function fit_mle(::Type{<:Dirichlet}, P::AbstractMatrix{Float64},
375
375
elogp = mean_logp (suffstats (Dirichlet, P, w))
376
376
fit_dirichlet! (elogp, α; maxiter= maxiter, tol= tol, debug= debug)
377
377
end
378
+
379
+ # # Differentiation
380
+ function ChainRulesCore. frule ((_, Δalpha):: Tuple{Any,Any} , :: Type{DT} , alpha:: AbstractVector{T} ; check_args:: Bool = true ) where {T <: Real , DT <: Union{Dirichlet{T}, Dirichlet} }
381
+ d = DT (alpha; check_args= check_args)
382
+ ∂alpha0 = sum (Δalpha)
383
+ digamma_alpha0 = SpecialFunctions. digamma (d. alpha0)
384
+ ∂lmnB = sum (Broadcast. instantiate (Broadcast. broadcasted (Δalpha, alpha) do Δalphai, alphai
385
+ Δalphai * (SpecialFunctions. digamma (alphai) - digamma_alpha0)
386
+ end ))
387
+ Δd = ChainRulesCore. Tangent {typeof(d)} (; alpha= Δalpha, alpha0= ∂alpha0, lmnB= ∂lmnB)
388
+ return d, Δd
389
+ end
390
+
391
+ function ChainRulesCore. rrule (:: Type{DT} , alpha:: AbstractVector{T} ; check_args:: Bool = true ) where {T <: Real , DT <: Union{Dirichlet{T}, Dirichlet} }
392
+ d = DT (alpha; check_args= check_args)
393
+ digamma_alpha0 = SpecialFunctions. digamma (d. alpha0)
394
+ function Dirichlet_pullback (_Δd)
395
+ Δd = ChainRulesCore. unthunk (_Δd)
396
+ Δalpha = Δd. alpha .+ Δd. alpha0 .+ Δd. lmnB .* (SpecialFunctions. digamma .(alpha) .- digamma_alpha0)
397
+ return ChainRulesCore. NoTangent (), Δalpha
398
+ end
399
+ return d, Dirichlet_pullback
400
+ end
401
+
402
+ function ChainRulesCore. frule ((_, Δd, Δx):: Tuple{Any,Any,Any} , :: typeof (_logpdf), d:: Dirichlet , x:: AbstractVector{<:Real} )
403
+ Ω = _logpdf (d, x)
404
+ ∂alpha = sum (Broadcast. instantiate (Broadcast. broadcasted (Δd. alpha, Δx, d. alpha, x) do Δalphai, Δxi, alphai, xi
405
+ xlogy (Δalphai, xi) + (alphai - 1 ) * Δxi / xi
406
+ end ))
407
+ ∂lmnB = - Δd. lmnB
408
+ ΔΩ = ∂alpha + ∂lmnB
409
+ if ! isfinite (Ω)
410
+ ΔΩ = oftype (ΔΩ, NaN )
411
+ end
412
+ return Ω, ΔΩ
413
+ end
414
+
415
+ function ChainRulesCore. rrule (:: typeof (_logpdf), d:: T , x:: AbstractVector{<:Real} ) where {T<: Dirichlet }
416
+ Ω = _logpdf (d, x)
417
+ isfinite_Ω = isfinite (Ω)
418
+ alpha = d. alpha
419
+ function _logpdf_Dirichlet_pullback (_ΔΩ)
420
+ ΔΩ = ChainRulesCore. unthunk (_ΔΩ)
421
+ ∂alpha = _logpdf_Dirichlet_∂alphai .(x, ΔΩ, isfinite_Ω)
422
+ ∂lmnB = isfinite_Ω ? - float (ΔΩ) : oftype (float (ΔΩ), NaN )
423
+ Δd = ChainRulesCore. Tangent {T} (; alpha= ∂alpha, lmnB= ∂lmnB)
424
+ Δx = _logpdf_Dirichlet_Δxi .(ΔΩ, alpha, x, isfinite_Ω)
425
+ return ChainRulesCore. NoTangent (), Δd, Δx
426
+ end
427
+ return Ω, _logpdf_Dirichlet_pullback
428
+ end
429
+ function _logpdf_Dirichlet_∂alphai (xi, ΔΩi, isfinite:: Bool )
430
+ ∂alphai = xlogy .(ΔΩi, xi)
431
+ return isfinite ? ∂alphai : oftype (∂alphai, NaN )
432
+ end
433
+ function _logpdf_Dirichlet_Δxi (ΔΩi, alphai, xi, isfinite:: Bool )
434
+ Δxi = ΔΩi * (alphai - 1 ) / xi
435
+ return isfinite ? Δxi : oftype (Δxi, NaN )
436
+ end
0 commit comments