Skip to content

Commit 7aa1fdf

Browse files
committed
Support updates and rejuvenation for sub-views of particle filters.
1 parent 0269780 commit 7aa1fdf

File tree

8 files changed

+109
-48
lines changed

8 files changed

+109
-48
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "GenParticleFilters"
22
uuid = "56b76ac4-72ef-411e-b419-6d312ed86a6f"
33
authors = ["Xuan <[email protected]>"]
4-
version = "0.1.4"
4+
version = "0.1.5"
55

66
[deps]
77
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"

src/GenParticleFilters.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@ module GenParticleFilters
33
using Gen, Distributions
44
using Gen: ParticleFilterState
55

6-
export ParticleFilterState
6+
export ParticleFilterState, ParticleFilterSubState, ParticleFilterView
77

8+
include("view.jl")
89
include("utils.jl")
910
include("initialize.jl")
1011
include("update.jl")

src/rejuvenate.jl

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ a tuple with a trace as the first return value. `method` specifies the
1212
rejuvenation method: `:move` for MCMC moves without a reweighting step,
1313
and `:reweight` for rejuvenation with a reweighting step.
1414
"""
15-
function pf_rejuvenate!(state::ParticleFilterState, kern, kern_args::Tuple=(),
15+
function pf_rejuvenate!(state::ParticleFilterView, kern, kern_args::Tuple=(),
1616
n_iters::Int=1; method::Symbol=:move)
1717
if method == :move
1818
return pf_move_accept!(state, kern, kern_args, n_iters)
@@ -34,7 +34,7 @@ a tuple `(trace, accept)`, where `trace` is the (potentially) new trace, and
3434
can be supplied with `kern_args`. The kernel is repeatedly applied to each trace
3535
for `n_iters`.
3636
"""
37-
function pf_move_accept!(state::ParticleFilterState,
37+
function pf_move_accept!(state::ParticleFilterView,
3838
kern, kern_args::Tuple=(), n_iters::Int=1)
3939
# Potentially rejuvenate each trace
4040
for (i, trace) in enumerate(state.traces)
@@ -44,10 +44,7 @@ function pf_move_accept!(state::ParticleFilterState,
4444
end
4545
state.new_traces[i] = trace
4646
end
47-
# Swap references
48-
tmp = state.traces
49-
state.traces = state.new_traces
50-
state.new_traces = tmp
47+
update_refs!(state)
5148
return state
5249
end
5350

@@ -66,7 +63,7 @@ accumulated accordingly.
6663
[1] R. A. G. Marques and G. Storvik, "Particle move-reweighting strategies for
6764
online inference," Preprint series. Statistical Research Report, 2013.
6865
"""
69-
function pf_move_reweight!(state::ParticleFilterState,
66+
function pf_move_reweight!(state::ParticleFilterView,
7067
kern, kern_args::Tuple=(), n_iters::Int=1)
7168
# Move and reweight each trace
7269
for (i, trace) in enumerate(state.traces)
@@ -79,10 +76,7 @@ function pf_move_reweight!(state::ParticleFilterState,
7976
state.new_traces[i] = trace
8077
state.log_weights[i] += weight
8178
end
82-
# Swap references
83-
tmp = state.traces
84-
state.traces = state.new_traces
85-
state.new_traces = tmp
79+
update_refs!(state)
8680
return state
8781
end
8882

src/resample.jl

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,7 @@ function pf_multinomial_resample!(state::ParticleFilterState;
6060
ws = state.log_weights[state.parents] .- log_priorities[state.parents]
6161
state.log_weights = ws .+ (log(n_particles) - logsumexp(ws))
6262
end
63-
# Swap references
64-
tmp = state.traces
65-
state.traces = state.new_traces
66-
state.new_traces = tmp
63+
update_refs!(state)
6764
return state
6865
end
6966

@@ -115,10 +112,7 @@ function pf_residual_resample!(state::ParticleFilterState;
115112
ws = state.log_weights[state.parents] .- log_priorities[state.parents]
116113
state.log_weights = ws .+ (log(n_particles) - logsumexp(ws))
117114
end
118-
# Swap references
119-
tmp = state.traces
120-
state.traces = state.new_traces
121-
state.new_traces = tmp
115+
update_refs!(state)
122116
return state
123117
end
124118

@@ -170,9 +164,6 @@ function pf_stratified_resample!(state::ParticleFilterState;
170164
ws = state.log_weights[state.parents] .- log_priorities[state.parents]
171165
state.log_weights = ws .+ (log(n_particles) - logsumexp(ws))
172166
end
173-
# Swap references
174-
tmp = state.traces
175-
state.traces = state.new_traces
176-
state.new_traces = tmp
167+
update_refs!(state)
177168
return state
178169
end

src/update.jl

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ Perform a particle filter update, where the model arguments are adjusted and
99
new observations are conditioned upon. New latent choices are sampled from
1010
the model's default proposal.
1111
"""
12-
function pf_update!(state::ParticleFilterState, new_args::Tuple,
12+
function pf_update!(state::ParticleFilterView, new_args::Tuple,
1313
argdiffs::Tuple, observations::ChoiceMap)
1414
n_particles = length(state.traces)
1515
for i=1:n_particles
@@ -20,10 +20,7 @@ function pf_update!(state::ParticleFilterState, new_args::Tuple,
2020
end
2121
state.log_weights[i] += increment
2222
end
23-
# Swap references
24-
tmp = state.traces
25-
state.traces = state.new_traces
26-
state.new_traces = tmp
23+
update_refs!(state)
2724
return state
2825
end
2926

@@ -51,7 +48,7 @@ that occur in `a` also occur in `b`, and the values at those addresses are
5148
equal. It is an error if no trace `t_new` satisfying the above conditions
5249
exists in the support of the model (with the new arguments).
5350
"""
54-
function pf_update!(state::ParticleFilterState, new_args::Tuple,
51+
function pf_update!(state::ParticleFilterView, new_args::Tuple,
5552
argdiffs::Tuple, observations::ChoiceMap,
5653
proposal::GenerativeFunction, proposal_args::Tuple)
5754
n_particles = length(state.traces)
@@ -66,10 +63,7 @@ function pf_update!(state::ParticleFilterState, new_args::Tuple,
6663
end
6764
state.log_weights[i] += up_weight - prop_weight
6865
end
69-
# Swap references
70-
tmp = state.traces
71-
state.traces = state.new_traces
72-
state.new_traces = tmp
66+
update_refs!(state)
7367
return state
7468
end
7569

@@ -107,7 +101,7 @@ calls to `pf_update!`).
107101
Similar functionality is provided by [`move_reweight`](@ref), except that
108102
`pf_update!` also allows model arguments to be updated.
109103
"""
110-
function pf_update!(state::ParticleFilterState, new_args::Tuple,
104+
function pf_update!(state::ParticleFilterView, new_args::Tuple,
111105
argdiffs::Tuple, observations::ChoiceMap,
112106
fwd_proposal::GenerativeFunction, fwd_args::Tuple,
113107
bwd_proposal::GenerativeFunction, bwd_args::Tuple)
@@ -122,9 +116,6 @@ function pf_update!(state::ParticleFilterState, new_args::Tuple,
122116
assess(bwd_proposal, (state.new_traces[i], bwd_args...), discard)
123117
state.log_weights[i] += up_weight - fwd_weight + bwd_weight
124118
end
125-
# Swap references
126-
tmp = state.traces
127-
state.traces = state.new_traces
128-
state.new_traces = tmp
119+
update_refs!(state)
129120
return state
130121
end

src/utils.jl

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,17 @@ export mean, var
66
using Gen: effective_sample_size
77
using Statistics
88

9+
@inline function update_refs!(state::ParticleFilterState)
10+
# Swap references
11+
tmp = state.traces
12+
state.traces = state.new_traces
13+
state.new_traces = tmp
14+
end
15+
16+
@inline function update_refs!(state::ParticleFilterSubState)
17+
state.traces[:] = state.new_traces
18+
end
19+
920
lognorm(v::AbstractVector) = v .- logsumexp(v)
1021

1122
"""
@@ -14,41 +25,41 @@ lognorm(v::AbstractVector) = v .- logsumexp(v)
1425
Return the vector of normalized log weights for the current state,
1526
one for each particle.
1627
"""
17-
get_log_norm_weights(state::ParticleFilterState) = lognorm(state.log_weights)
28+
get_log_norm_weights(state::ParticleFilterView) = lognorm(state.log_weights)
1829

1930
"""
2031
get_norm_weights(state::ParticleFilterState)
2132
2233
Return the vector of normalized weights for the current state,
2334
one for each particle.
2435
"""
25-
get_norm_weights(state::ParticleFilterState) = exp.(get_log_norm_weights(state))
36+
get_norm_weights(state::ParticleFilterView) = exp.(get_log_norm_weights(state))
2637

2738
"""
2839
effective_sample_size(state::ParticleFilterState)
2940
3041
Computes the effective sample size of the particles in the filter.
3142
"""
32-
Gen.effective_sample_size(state::ParticleFilterState) =
43+
Gen.effective_sample_size(state::ParticleFilterView) =
3344
Gen.effective_sample_size(get_log_norm_weights(state))
3445

3546
"""
3647
get_ess(state::ParticleFilterState)
3748
3849
Alias for `effective_sample_size`(@ref). Computes the effective sample size.
3950
"""
40-
get_ess(state::ParticleFilterState) = Gen.effective_sample_size(state)
51+
get_ess(state::ParticleFilterView) = Gen.effective_sample_size(state)
4152

4253
"""
4354
mean(state::ParticleFilterState[, addr])
4455
4556
Returns the weighted empirical mean for a particular trace address `addr`.
4657
If `addr` is not provided, returns the empirical mean of the return value.
4758
"""
48-
Statistics.mean(state::ParticleFilterState, addr) =
59+
Statistics.mean(state::ParticleFilterView, addr) =
4960
sum(get_norm_weights(state) .* getindex.(state.traces, addr))
5061

51-
Statistics.mean(state::ParticleFilterState) =
62+
Statistics.mean(state::ParticleFilterView) =
5263
sum(get_norm_weights(state) .* get_retval.(state.traces))
5364

5465
"""
@@ -57,10 +68,10 @@ Statistics.mean(state::ParticleFilterState) =
5768
Returns the empirical variance for a particular trace address `addr`.
5869
If `addr` is not provided, returns the empirical variance of the return value.
5970
"""
60-
Statistics.var(state::ParticleFilterState, addr) =
71+
Statistics.var(state::ParticleFilterView, addr) =
6172
sum(get_norm_weights(state) .*
6273
(getindex.(state.traces, addr) .- mean(state, addr)).^2)
6374

64-
Statistics.var(state::ParticleFilterState) =
75+
Statistics.var(state::ParticleFilterView) =
6576
sum(get_norm_weights(state) .*
6677
(get_retval.(state.traces) .- mean(state)).^2)

src/view.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
struct ParticleFilterSubState{U,I,L}
2+
traces::SubArray{U,1,Vector{U},I,L}
3+
new_traces::SubArray{U,1,Vector{U},I,L}
4+
log_weights::SubArray{Float64,1,Vector{Float64},I,L}
5+
parents::SubArray{Int,1,Vector{Int},I,L}
6+
end
7+
8+
Gen.get_traces(state::ParticleFilterSubState) = state.traces
9+
Gen.get_log_weights(state::ParticleFilterSubState) = state.log_weights
10+
11+
const ParticleFilterView{U} =
12+
Union{ParticleFilterState{U}, ParticleFilterSubState{U}} where {U}
13+
14+
function Base.view(state::ParticleFilterState{U},
15+
indices::AbstractVector) where {U}
16+
L = Base.viewindexing((indices,)) == IndexLinear()
17+
return ParticleFilterSubState{U,Tuple{typeof(indices)},L}(
18+
view(state.traces, indices),
19+
view(state.new_traces, indices),
20+
view(state.log_weights, indices),
21+
view(state.parents, indices)
22+
)
23+
end
24+
25+
Base.getindex(state::ParticleFilterState, indices) =
26+
Base.view(state, indices)
27+
28+
Base.firstindex(state::ParticleFilterState) = 1
29+
Base.lastindex(state::ParticleFilterState) = length(state.traces)

test/runtests.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,18 @@ state = pf_update!(state, (10,), (UnknownChange(),), choicemap(),
8787
@test all([w != 0 for w in get_log_weights(state)])
8888
end
8989

90+
@testset "Update with different proposals per view" begin
91+
state = pf_initialize(line_model, (0,), choicemap(), 100)
92+
substate = pf_update!(state[1:50], (10,), (UnknownChange(),), generate_line(10))
93+
@test all([tr[:line => 10 => :y] == 0 for tr in get_traces(substate)])
94+
@test all([w != 0 for w in get_log_weights(substate)])
95+
substate = pf_update!(state[51:end], (10,), (UnknownChange(),),
96+
generate_line(10), outlier_propose, (10,))
97+
@test all([tr[:line => 10 => :y] == 0 for tr in get_traces(substate)])
98+
@test all([tr[:line => 10 => :outlier] == false for tr in get_traces(substate)])
99+
@test all([w != 0 for w in get_log_weights(state)])
100+
end
101+
90102
end
91103

92104
@testset "Particle resampling" begin
@@ -240,6 +252,38 @@ rel_weights = parse.(Float64, rel_weights)
240252
@test all(isapprox.(new_weights, old_weights .+ rel_weights; atol=1e-3))
241253
end
242254

255+
@testset "Rejuvenation on separate views" begin
256+
# Log which particles were rejuvenated
257+
buffer = IOBuffer()
258+
logger = SimpleLogger(buffer, Logging.Debug)
259+
state = pf_initialize(line_model, (10,), generate_line(10, 1.), 100)
260+
old_traces = get_traces(state)[1:50]
261+
old_weights = get_log_weights(state)[51:end]
262+
263+
with_logger(logger) do
264+
pf_move_accept!(state[1:50], metropolis_hastings, (select(:slope),), 1)
265+
pf_move_reweight!(state[51:end], move_reweight, (select(:slope),), 1)
266+
end
267+
268+
# Extract acceptances and relative weights from debug log
269+
lines = split(String(take!(buffer)), "\n")
270+
a_lines = filter(s -> occursin("Accepted: ", s), lines)
271+
accepts = [match(r".*Accepted: (\w+).*", l).captures[1] for l in a_lines]
272+
accepts = parse.(Bool, accepts)
273+
r_lines = filter(s -> occursin("Rel. Weight: ", s), lines)
274+
rel_weights = [match(r".*Rel\. Weight: (.+)\s*", l).captures[1] for l in r_lines]
275+
rel_weights = parse.(Float64, rel_weights)
276+
277+
# Check that only traces that were accepted are rejuvenated
278+
new_traces = get_traces(state)[1:50]
279+
@test all(a ? t1 !== t2 : t1 === t2
280+
for (a, t1, t2) in zip(accepts, old_traces, new_traces))
281+
# Check that weights are adjusted accordingly
282+
new_weights = get_log_weights(state)[51:end]
283+
@test all(isapprox.(new_weights, old_weights .+ rel_weights; atol=1e-3))
284+
285+
end
286+
243287
end
244288

245289
@testset "Utility functions" begin

0 commit comments

Comments
 (0)