Skip to content

Commit c061744

Browse files
committed
initial API idea monte carlo
1 parent ccd0a38 commit c061744

File tree

3 files changed

+126
-20
lines changed

3 files changed

+126
-20
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1111
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1212
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1313
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
14+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1415
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1516

1617
[compat]

src/SimpleProbabilisticPrograms.jl

Lines changed: 81 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ export iid
2424
export UniformCategorical, Dirac # specific distributions
2525
export BetaGeometric, DirCat, symdircat
2626
export DictCond # simple conditional distribution
27+
export Conditioned, condition, mergetraces, logdensity
28+
export montecarlo, MonteCarloMethod, LikelihoodWeighting, likelihood_weighting
2729

2830
using SpecialFunctions: digamma, logbeta
2931
using LogExpFunctions: logsumexp, logaddexp
@@ -32,7 +34,7 @@ using Distributions: Beta, Binomial, BetaBinomial, Geometric
3234
using Random: AbstractRNG, default_rng
3335
using MacroTools: @capture, splitdef, combinedef, postwalk, prewalk
3436
using Accessors: insert, PropertyLens, @set
35-
using StaticArrays: @SVector
37+
using StaticArrays: @SVector, sacollect
3638

3739
import Base: show, rand
3840
import Distributions: logpdf, insupport
@@ -48,7 +50,7 @@ struct ProbProg{NAME, A, KW}
4850
kwargs::KW
4951
end
5052

51-
show(io::IO, ::ProbProg{NAME}) where NAME = print(io, "ProbProg{$NAME}(...)")
53+
show(io::IO, ::ProbProg{NAME}) where NAME = print(io, "ProbProg{:$NAME}(...)")
5254
recover_trace(::ProbProg, trace) = trace
5355

5456
function logpdf(model::ProbProg, x)
@@ -120,7 +122,7 @@ For example, `heads ~ Bernoulli(0.5)` is transformed into
120122
The lens is used as getter and setter by the interpreters of the program.
121123
"""
122124
macro probprog(ex)
123-
interpreter(i) = Symbol('i')
125+
interpreter(i) = Symbol("interpreter")
124126
i = 0
125127

126128
function rewrite_return_expr(ex)
@@ -418,6 +420,73 @@ end
418420
(cond::DictCond)(x) = cond.dists[x]
419421
DictCond(dists...) = DictCond(Dict(dists...))
420422

423+
#################################
424+
### Conditioned distributions ###
425+
#################################
426+
427+
struct Conditioned{D, C}
428+
joint :: D # joint distribution
429+
conds :: C # values conditioned on (as a trace)
430+
end
431+
432+
condition(joint; on) = Conditioned(joint, on)
433+
function rand(rng::AbstractRNG, cond::Conditioned{<:ProbProg})
434+
interpret(cond.joint, RandTrace(rng, cond.conds)).interpreter.trace
435+
end
436+
437+
logdensity(dist, x) = logpdf(dist, x)
438+
439+
function logdensity(cond::Conditioned, trace)
440+
logdensity(cond.joint, mergetraces(cond.conds, trace))
441+
end
442+
443+
mergetraces(t1::NamedTuple, t2::NamedTuple) = (; t1..., t2...)
444+
445+
#########################
446+
### Inference Methods ###
447+
#########################
448+
449+
abstract type MonteCarloMethod end
450+
451+
function montecarlo(f, dist, mcm::MonteCarloMethod, rng::AbstractRNG=default_rng())
452+
montecarlo(dist, mcm, rng)(f)
453+
end
454+
455+
function montecarlo(dist, mcm::MonteCarloMethod, rng::AbstractRNG=default_rng())
456+
samples, logweights = weighted_samples(mcm, dist, rng)
457+
function 𝔼(f)
458+
# Dispatch on whether f is multivariate.
459+
function calc_expectation(::Type{<:AbstractArray{T}}) where T
460+
weighted = Matrix{T}(undef, length(y), length(samples))
461+
for j in eachindex(samples)
462+
weighted[:, j] = log.(samples[j]) .+ logweights[j]
463+
end
464+
exp(logsumexp(weighted, dims=2) .- logsumexp(logweights))
465+
end
466+
function calc_expectation(::Type{<:Any})
467+
weighted = map(samples, logweights) do x, logweight
468+
log(f(x)) + logweight
469+
end
470+
exp(logsumexp(weighted) - logsumexp(logweights))
471+
end
472+
calc_expectation(typeof(f(first(samples))))
473+
end
474+
end
475+
476+
struct LikelihoodWeighting <: MonteCarloMethod
477+
num_samples :: Int
478+
end
479+
480+
function weighted_samples(mcm::LikelihoodWeighting, dist, rng)
481+
likelihood_weighting(dist, mcm.num_samples, rng)
482+
end
483+
484+
function likelihood_weighting(model::Conditioned{<:ProbProg}, num_samples, rng::AbstractRNG=default_rng())
485+
traces = rand(rng, iid(model, num_samples))
486+
logweights = logdensity.((model,), traces)
487+
traces, logweights
488+
end
489+
421490
##############################################
422491
### Interpreters of probabilistic programs ###
423492
##############################################
@@ -495,22 +564,25 @@ function sample(i::EvalTrace, dist, lens)
495564
end
496565

497566
"""
498-
RandTrace([rng])
567+
RandTrace([rng], [conds])
499568
500569
Interpreter that runs a probabilistic program and traces the random choices
501570
made in the sample statements in a named tuple.
502571
"""
503-
struct RandTrace{R,T} <: Interpreter
572+
struct RandTrace{R<:AbstractRNG,T<:NamedTuple} <: Interpreter
504573
rng::R
505574
trace::T
506575
end
507576

508577
RandTrace(rng::AbstractRNG) = RandTrace(rng, (;))
509-
RandTrace() = RandTrace(default_rng())
510578

511-
function sample(i::RandTrace, dist, lens)
512-
x = rand(i.rng, dist)
513-
return RandTrace(i.rng, insert(i.trace, lens, x)), x
579+
function sample(i::RandTrace, dist, lens::PropertyLens{name}) where name
580+
if hasproperty(i.trace, name)
581+
i, lens(i.trace)
582+
else
583+
x = rand(i.rng, dist)
584+
RandTrace(i.rng, insert(i.trace, lens, x)), x
585+
end
514586
end
515587

516588
"""

test/runtests.jl

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
using SimpleProbabilisticPrograms
22
using Test
3-
using Random
3+
4+
using Random: MersenneTwister
45
using Distributions: Beta, Bernoulli
6+
using Statistics: mean
57

68
@testset "basic tests" begin
79
@probprog function beta_bernoulli_model(a, b, n)
@@ -10,8 +12,8 @@ using Distributions: Beta, Bernoulli
1012
return (; bias, coins)
1113
end
1214
model = beta_bernoulli_model(3, 4, 10)
13-
trace = rand(Random.MersenneTwister(42), model)
14-
trace_static = rand(Random.MersenneTwister(42), beta_bernoulli_model(3, 4, Val(10)))
15+
trace = rand(MersenneTwister(42), model)
16+
trace_static = rand(MersenneTwister(42), beta_bernoulli_model(3, 4, Val(10)))
1517
@test trace_static == trace
1618
@test -Inf < logpdf(model, trace) < 0
1719
@test insupport(model, trace)
@@ -109,12 +111,43 @@ end
109111
@test rand(cond('b')) in 10:15
110112
end
111113

112-
using Distributions: Beta, Bernoulli
113-
@probprog function beta_bernoulli_model(a, b, n)
114-
bias ~ Beta(a, b)
115-
coins ~ iid(Bernoulli(bias), n)
116-
return (; bias, coins)
114+
@testset "Monte Carlo" begin
115+
@probprog function bbm(a, b, n) # beta bernoulli model
116+
bias ~ Beta(a, b)
117+
coins ~ iid(Bernoulli(bias), n)
118+
return (; bias, coins)
119+
end
120+
121+
N = 1000
122+
model = bbm(1, 1, N)
123+
data = rand(iid(Bernoulli(0.4), N))
124+
E = montecarlo(condition(model, on=(; coins=data)), LikelihoodWeighting(10_000))
125+
@test E(trace -> trace.bias) mean(data) atol=0.01
126+
127+
# TODO: Test montecarlo with multivariate function
117128
end
118-
@time model = beta_bernoulli_model(3, 4, 1000)
119-
@time trace = rand(model)
120-
@time logpdf(model, trace)
129+
130+
131+
# t1 = (a=1, b=(c=2, d=3, e=4))
132+
# t2 = (b=(c=5, d=6), f=7)
133+
# @test mergetraces(t1, t2) == (a=1, b=(c=5, d=6, e=4), f=7)
134+
135+
# T1 = typeof(t1)
136+
# T2 = typeof(t2)
137+
# fieldnames(T1)
138+
# fieldtypes(T1)
139+
# fieldnames(T2)
140+
# fieldtypes(T2)
141+
142+
# all_names = union(fieldnames(T1), fieldnames(T2))
143+
# map(all_names) do n
144+
# if n in fieldnames(T1)
145+
# if n in fieldnames(T2)
146+
# if
147+
# else
148+
# :($n = t1.$n)
149+
# end
150+
# else
151+
# :($n = t2.$n)
152+
# end
153+
# end

0 commit comments

Comments
 (0)