Skip to content

Commit ef98cff

Browse files
committed
Add weight correction for stratified sampling.
1 parent 4d5f28b commit ef98cff

File tree

4 files changed

+36
-8
lines changed

4 files changed

+36
-8
lines changed

src/initialize.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,12 @@ function pf_initialize(
9797
V = dynamic ? Trace : U # Determine trace type for particle filter
9898
traces = Vector{V}(undef, n_particles)
9999
log_weights = Vector{Float64}(undef, n_particles)
100+
n_strata = length(strata)
100101
# Generate traces in a stratified manner
101102
stratified_map!(n_particles, strata; layout=layout) do i, stratum
102103
constraints = merge(stratum, observations)
103104
(traces[i], log_weights[i]) = generate(model, model_args, constraints)
105+
log_weights[i] += log(n_strata)
104106
end
105107
return ParticleFilterState{V}(traces, Vector{V}(undef, n_particles),
106108
log_weights, 0., collect(1:n_particles))
@@ -115,11 +117,12 @@ function pf_initialize(
115117
V = dynamic ? Trace : U # Determine trace type for particle filter
116118
traces = Vector{V}(undef, n_particles)
117119
log_weights = Vector{Float64}(undef, n_particles)
120+
n_strata = length(strata)
118121
stratified_map!(n_particles, strata; layout=layout) do i, stratum
119122
(prop_choices, prop_weight, _) = propose(proposal, proposal_args)
120123
constraints = merge(stratum, observations, prop_choices)
121124
(traces[i], model_weight) = generate(model, model_args, constraints)
122-
log_weights[i] = model_weight - prop_weight
125+
log_weights[i] = model_weight - prop_weight + log(n_strata)
123126
end
124127
return ParticleFilterState{V}(traces, Vector{V}(undef, n_particles),
125128
log_weights, 0., collect(1:n_particles))

src/update.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,14 +195,15 @@ function pf_update!(state::ParticleFilterView, new_args::Tuple,
195195
layout=:interleaved)
196196
# Update traces in a stratified manner
197197
n_particles = length(state.traces)
198+
n_strata = length(strata)
198199
stratified_map!(n_particles, strata; layout=layout) do i, stratum
199200
constraints = merge(stratum, observations)
200201
state.new_traces[i], increment, _, discard =
201202
update(state.traces[i], new_args, argdiffs, constraints)
202203
if !isempty(discard)
203204
error("Choices were updated or deleted: $discard")
204205
end
205-
state.log_weights[i] += increment
206+
state.log_weights[i] += increment + log(n_strata)
206207
end
207208
update_refs!(state)
208209
return state
@@ -215,11 +216,12 @@ function pf_update!(state::ParticleFilterView, translator::TraceTranslator,
215216
observations = translator.new_observations
216217
# Update traces in a stratified manner
217218
n_particles = length(state.traces)
219+
n_strata = length(strata)
218220
stratified_map!(n_particles, strata; layout=layout) do i, stratum
219221
translator.new_observations = merge(stratum, observations)
220222
state.new_traces[i], log_weight =
221223
translator(state.traces[i]; translator_args...)
222-
state.log_weights[i] += log_weight
224+
state.log_weights[i] += log_weight + log(n_strata)
223225
end
224226
update_refs!(state)
225227
return state

test/initialize.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
@testset "Initialize with default proposal" begin
44
state = pf_initialize(line_model, (0,), choicemap(), 100)
55
@test all(-2 <= tr[:slope] <= 2 for tr in get_traces(state))
6+
@test all(w 0 for w in get_log_weights(state))
67
state = pf_initialize(line_model, (1,), line_choicemap(1), 100)
78
@test all(tr[:line => 1 => :y] == 0 for tr in get_traces(state))
89
state = pf_initialize(line_model, (10,), line_choicemap(10), 100)
@@ -17,6 +18,7 @@ end
1718
@testset "Initialize with custom proposal" begin
1819
state = pf_initialize(line_model, (0,), choicemap(), line_propose, (0,), 100)
1920
@test all(tr[:slope] == 0 for tr in get_traces(state))
21+
@test all(w log(1/5) for w in get_log_weights(state))
2022
state = pf_initialize(line_model, (1,), line_choicemap(1),
2123
outlier_propose, ([1],), 100)
2224
@test all(tr[:line => 1 => :outlier] == false for tr in get_traces(state))
@@ -38,6 +40,9 @@ end
3840
slope_strata = (slope_choicemap(s) for s in -2:1:2)
3941
observations = line_choicemap(1)
4042
# Test contiguous stratification
43+
state = pf_initialize(line_model, (0,), choicemap(),
44+
slope_strata, 100; layout=:contiguous)
45+
@test all(w 0 for w in get_log_weights(state))
4146
state = pf_initialize(line_model, (1,), observations,
4247
slope_strata, 100; layout=:contiguous)
4348
for (k, slope) in zip([20, 40, 60, 80, 100], -2:1:2)
@@ -46,6 +51,9 @@ end
4651
@test all(tr[:line => 1 => :y] == 0 for tr in traces)
4752
end
4853
# Test interleaved stratification
54+
state = pf_initialize(line_model, (0,), choicemap(),
55+
slope_strata, 100; layout=:interleaved)
56+
@test all(w 0 for w in get_log_weights(state))
4957
state = pf_initialize(line_model, (1,), observations,
5058
slope_strata, 100; layout=:interleaved)
5159
for (k, slope) in zip([1, 2, 3, 4, 5], -2:1:2)
@@ -66,6 +74,9 @@ end
6674
@test all(tr[:slope] == slope for tr in traces)
6775
@test all(tr[:line => 1 => :outlier] == false for tr in traces)
6876
@test all(tr[:line => 1 => :y] == 0 for tr in traces)
77+
expected_w = (logpdf(bernoulli, false, 0.1) +
78+
logpdf(normal, 0.0, slope, 1.0))
79+
@test all(w expected_w for w in get_log_weights(state[(k-20+1):k]))
6980
end
7081
# Test interleaved stratification
7182
state = pf_initialize(line_model, (1,), observations, slope_strata,
@@ -75,6 +86,9 @@ end
7586
@test all(tr[:slope] == slope for tr in traces)
7687
@test all(tr[:line => 1 => :outlier] == false for tr in traces)
7788
@test all(tr[:line => 1 => :y] == 0 for tr in traces)
89+
expected_w = (logpdf(bernoulli, false, 0.1) +
90+
logpdf(normal, 0.0, slope, 1.0))
91+
@test all(w expected_w for w in get_log_weights(state[k:5:100]))
7892
end
7993
end
8094

test/update.jl

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,12 @@
22

33
@testset "Update with default proposal" begin
44
state = pf_initialize(line_model, (0,), choicemap(), 100)
5-
state = pf_update!(state, (10,), (UnknownChange(),), line_choicemap(10))
6-
@test all(tr[:line => 10 => :y] == 0 for tr in get_traces(state))
7-
@test all(w != 0 for w in get_log_weights(state))
5+
state = pf_update!(state, (1,), (UnknownChange(),), line_choicemap(1))
6+
@test all(tr[:line => 1 => :y] == 0 for tr in get_traces(state))
7+
outliers = [tr[:line => 1 => :outlier] for tr in get_traces(state)]
8+
expected_ws = [logpdf(normal, 0.0, tr[:slope], o ? 10.0 : 1.0)
9+
for (o, tr) in zip(outliers, get_traces(state))]
10+
@test all(get_log_weights(state) .≈ expected_ws)
811
end
912

1013
@testset "Update with stratification" begin
@@ -17,17 +20,23 @@ end
1720
for (k, val) in zip([50, 100], [false, true])
1821
traces = get_traces(state[(k-50+1):k])
1922
@test all(tr[:line => 1 => :outlier] == val for tr in traces)
23+
std = val ? 10.0 : 1.0
24+
expected_ws = [(logpdf(bernoulli, val, 0.1) + log(2) +
25+
logpdf(normal, 0.0, tr[:slope], std)) for tr in traces]
26+
@test all(get_log_weights(state[(k-50+1):k]) .≈ expected_ws)
2027
end
21-
@test all(w != 0 for w in get_log_weights(state))
2228
# Test interleaved stratification
2329
state = pf_initialize(line_model, (0,), choicemap(), 100)
2430
state = pf_update!(state, (1,), (UnknownChange(),), observations,
2531
outlier_strata; layout=:interleaved)
2632
for (k, val) in zip(1:2, [false, true])
2733
traces = get_traces(state[k:2:100])
2834
@test all(tr[:line => 1 => :outlier] == val for tr in traces)
35+
std = val ? 10.0 : 1.0
36+
expected_ws = [(logpdf(bernoulli, val, 0.1) + log(2) +
37+
logpdf(normal, 0.0, tr[:slope], std)) for tr in traces]
38+
@test all(get_log_weights(state[k:2:100]) .≈ expected_ws)
2939
end
30-
@test all(w != 0 for w in get_log_weights(state))
3140
end
3241

3342
@gen outlier_propose(tr, idxs) =

0 commit comments

Comments
 (0)