@@ -24,6 +24,8 @@ export iid
2424export UniformCategorical, Dirac # specific distributions
2525export BetaGeometric, DirCat, symdircat
2626export DictCond # simple conditional distribution
27+ export Conditioned, condition, mergetraces, logdensity
28+ export montecarlo, MonteCarloMethod, LikelihoodWeighting, likelihood_weighting
2729
2830using SpecialFunctions: digamma, logbeta
2931using LogExpFunctions: logsumexp, logaddexp
@@ -32,7 +34,7 @@ using Distributions: Beta, Binomial, BetaBinomial, Geometric
3234using Random: AbstractRNG, default_rng
3335using MacroTools: @capture , splitdef, combinedef, postwalk, prewalk
3436using Accessors: insert, PropertyLens, @set
35- using StaticArrays: @SVector
37+ using StaticArrays: @SVector , sacollect
3638
3739import Base: show, rand
3840import Distributions: logpdf, insupport
@@ -48,7 +50,7 @@ struct ProbProg{NAME, A, KW}
4850 kwargs:: KW
4951end
5052
51- show (io:: IO , :: ProbProg{NAME} ) where NAME = print (io, " ProbProg{$NAME }(...)" )
53+ show (io:: IO , :: ProbProg{NAME} ) where NAME = print (io, " ProbProg{: $NAME }(...)" )
5254recover_trace (:: ProbProg , trace) = trace
5355
5456function logpdf (model:: ProbProg , x)
@@ -120,7 +122,7 @@ For example, `heads ~ Bernoulli(0.5)` is transformed into
120122The lens is used as getter and setter by the interpreters of the program.
121123"""
122124macro probprog (ex)
123- interpreter (i) = Symbol (' i ' )
125+ interpreter (i) = Symbol (" interpreter " )
124126 i = 0
125127
126128 function rewrite_return_expr (ex)
418420(cond:: DictCond )(x) = cond. dists[x]
419421DictCond (dists... ) = DictCond (Dict (dists... ))
420422
423+ # ################################
424+ # ## Conditioned distributions ###
425+ # ################################
426+
427+ struct Conditioned{D, C}
428+ joint :: D # joint distribution
429+ conds :: C # values conditioned on (as a trace)
430+ end
431+
432+ condition (joint; on) = Conditioned (joint, on)
433+ function rand (rng:: AbstractRNG , cond:: Conditioned{<:ProbProg} )
434+ interpret (cond. joint, RandTrace (rng, cond. conds)). interpreter. trace
435+ end
436+
437+ logdensity (dist, x) = logpdf (dist, x)
438+
439+ function logdensity (cond:: Conditioned , trace)
440+ logdensity (cond. joint, mergetraces (cond. conds, trace))
441+ end
442+
443+ mergetraces (t1:: NamedTuple , t2:: NamedTuple ) = (; t1... , t2... )
444+
445+ # ########################
446+ # ## Inference Methods ###
447+ # ########################
448+
449+ abstract type MonteCarloMethod end
450+
451+ function montecarlo (f, dist, mcm:: MonteCarloMethod , rng:: AbstractRNG = default_rng ())
452+ montecarlo (dist, mcm, rng)(f)
453+ end
454+
455+ function montecarlo (dist, mcm:: MonteCarloMethod , rng:: AbstractRNG = default_rng ())
456+ samples, logweights = weighted_samples (mcm, dist, rng)
457+ function 𝔼 (f)
458+ # Dispatch on whether f is multivariate.
459+ function calc_expectation (:: Type{<:AbstractArray{T}} ) where T
460+ weighted = Matrix {T} (undef, length (y), length (samples))
461+ for j in eachindex (samples)
462+ weighted[:, j] = log .(samples[j]) .+ logweights[j]
463+ end
464+ exp (logsumexp (weighted, dims= 2 ) .- logsumexp (logweights))
465+ end
466+ function calc_expectation (:: Type{<:Any} )
467+ weighted = map (samples, logweights) do x, logweight
468+ log (f (x)) + logweight
469+ end
470+ exp (logsumexp (weighted) - logsumexp (logweights))
471+ end
472+ calc_expectation (typeof (f (first (samples))))
473+ end
474+ end
475+
476+ struct LikelihoodWeighting <: MonteCarloMethod
477+ num_samples :: Int
478+ end
479+
480+ function weighted_samples (mcm:: LikelihoodWeighting , dist, rng)
481+ likelihood_weighting (dist, mcm. num_samples, rng)
482+ end
483+
484+ function likelihood_weighting (model:: Conditioned{<:ProbProg} , num_samples, rng:: AbstractRNG = default_rng ())
485+ traces = rand (rng, iid (model, num_samples))
486+ logweights = logdensity .((model,), traces)
487+ traces, logweights
488+ end
489+
421490# #############################################
422491# ## Interpreters of probabilistic programs ###
423492# #############################################
@@ -495,22 +564,25 @@ function sample(i::EvalTrace, dist, lens)
495564end
496565
497566"""
498- RandTrace([rng])
567+ RandTrace([rng], [conds] )
499568
500569Interpreter that runs a probabilistic program and traces the random choices
501570made in the sample statements in a named tuple.
502571"""
503- struct RandTrace{R,T } <: Interpreter
572+ struct RandTrace{R<: AbstractRNG ,T <: NamedTuple } <: Interpreter
504573 rng:: R
505574 trace:: T
506575end
507576
508577RandTrace (rng:: AbstractRNG ) = RandTrace (rng, (;))
509- RandTrace () = RandTrace (default_rng ())
510578
511- function sample (i:: RandTrace , dist, lens)
512- x = rand (i. rng, dist)
513- return RandTrace (i. rng, insert (i. trace, lens, x)), x
579+ function sample (i:: RandTrace , dist, lens:: PropertyLens{name} ) where name
580+ if hasproperty (i. trace, name)
581+ i, lens (i. trace)
582+ else
583+ x = rand (i. rng, dist)
584+ RandTrace (i. rng, insert (i. trace, lens, x)), x
585+ end
514586end
515587
516588"""
0 commit comments