Skip to content

Commit dbf7643

Browse files
authored
Merge pull request #133 from mschauer/incscore
Incrementally compute score + test
2 parents 7459a96 + 1ee2e30 commit dbf7643

File tree

4 files changed

+66
-19
lines changed

4 files changed

+66
-19
lines changed

src/sampler.jl

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ ndown(g, total) = ne(g)
6868
Return
6969
"""
7070
function count_moves(g, κ, balance, prior, score, coldness, total, dir=:both)
71-
s1 = s2 = 0.0
71+
s1 = s2 = Δscorevalue1 = Δscorevalue2 = 0.0
7272
x1 = y1 = x2 = y2 = 0
7373
T1 = Int[]
7474
H2 = Int[]
@@ -85,14 +85,16 @@ function count_moves(g, κ, balance, prior, score, coldness, total, dir=:both)
8585
valid = (isclique(g, NAyxT) && isblocked(g, y, x, NAyxT))
8686
if valid
8787
PAy = parents(g, y)
88-
s = balance(prior(total, total+1)*exp(coldness*Δscoreinsert(score, NAyxT ∪ PAy, x, y, T)))
88+
Δscorevalue = Δscoreinsert(score, NAyxT ∪ PAy, x, y, T)
89+
s = balance(prior(total, total+1)*exp(coldness*Δscorevalue))
8990
else
9091
s = 0.0
9192
end
9293
@assert s >= 0
9394
if valid && rand() > s1/(s1 + s) # sequentially draw sample
9495
x1, y1 = x, y
9596
T1 = T
97+
Δscorevalue1 = Δscorevalue
9698
end
9799
s1 = s1 + s
98100
end
@@ -108,25 +110,27 @@ function count_moves(g, κ, balance, prior, score, coldness, total, dir=:both)
108110
if valid
109111
PAy = parents(g, y)
110112
PAy⁻ = setdiff(PAy, x)
111-
s = balance(prior(total, total-1)*exp(coldness*Δscoredelete(score, NAyx_H ∪ PAy⁻, x, y, H)))
113+
Δscorevalue = Δscoredelete(score, NAyx_H ∪ PAy⁻, x, y, H)
114+
s = balance(prior(total, total-1)*exp(coldness*Δscorevalue))
112115
else
113116
s = 0.0
114117
end
115118
@assert s >= 0
116119
if valid && rand() > s2/(s2 + s)
117120
x2, y2 = x, y
118121
H2 = H
122+
Δscorevalue2 = Δscorevalue
119123
end
120124
s2 = s2 + s
121125
end
122126
end
123127
end
124128
end
125-
s1, s2, (x1, y1, T1), (x2, y2, H2)
129+
s1, s2, Δscorevalue1, Δscorevalue2, (x1, y1, T1), (x2, y2, H2)
126130
end
127131

128132
function count_moves_new(g, κ, balance, prior, score, coldness, total, dir=:both)
129-
s1 = s2 = 0.0
133+
s1 = s2 = Δscorevalue1 = Δscorevalue2 = 0.0
130134
x1 = y1 = x2 = y2 = 0
131135
T1 = Int[]
132136
H2 = Int[]
@@ -147,10 +151,12 @@ function count_moves_new(g, κ, balance, prior, score, coldness, total, dir=:bot
147151
# to hide complexity
148152
# or just Δscoreinsert(score, g, op)
149153
# and op contains all necessary stuff e.g. NAyxT and so on
150-
s = balance(prior(total, total+1)*exp(coldness*Δscoreinsert(score, NAyxT ∪ PAy, x, y, T)))
154+
Δscorevalue = Δscoreinsert(score, NAyxT ∪ PAy, x, y, T)
155+
s = balance(prior(total, total+1)*exp(coldness*Δscorevalue))
151156
if rand() > s1/(s1 + s) # sequentially draw sample
152157
x1, y1 = x, y
153158
T1 = T
159+
Δscorevalue1 = Δscorevalue
154160
end
155161
s1 = s1 + s
156162
end
@@ -161,17 +167,19 @@ function count_moves_new(g, κ, balance, prior, score, coldness, total, dir=:bot
161167
PAy⁻ = setdiff(PAy, x)
162168
# I would prefer Δscoredelete(score, g, x, y, H) as above
163169
NAyx_H = setdiff(adj_neighbors(g, x, y), H)
164-
s = balance(prior(total, total-1)*exp(coldness*Δscoredelete(score, NAyx_H ∪ PAy⁻, x, y, H)))
170+
Δscorevalue = Δscoredelete(score, NAyx_H ∪ PAy⁻, x, y, H)
171+
s = balance(prior(total, total-1)*exp(coldness*Δscorevalue))
165172
if rand() > s2/(s2 + s)
166173
x2, y2 = x, y
167174
H2 = H
175+
Δscorevalue2 = Δscorevalue
168176
end
169177
s2 = s2 + s
170178
end
171179
end
172180
end
173181
end
174-
s1, s2, (x1, y1, T1), (x2, y2, H2)
182+
s1, s2, Δscorevalue1, Δscorevalue2, (x1, y1, T1), (x2, y2, H2)
175183
end
176184

177185
"""
@@ -182,6 +190,8 @@ end
182190
Run the causal zigzag algorithm starting in a cpdag `(G, t)` with `t` oriented or unoriented edges,
183191
the balance function `balance ∈ {metropolis_balance, barker_balance, sqrt}`, `score` function (see `ges` algorithm)
184192
coldness parameter for iterations. `σ = 1.0, ρ = 0.0` gives purely diffusive behaviour, `σ = 0.0, ρ = 1.0` gives Zig-Zag behaviour.
193+
194+
Returns a vector of tuples with information, each containing a graph, spent time, current direction, number of edges and the score.
185195
"""
186196
function causalzigzag(n, G = (DiGraph(n), 0); balance = metropolis_balance, prior = (_,_)->1.0, score=UniformScore(),
187197
coldness = 1.0, σ = 0.0, ρ = 1.0, naive=false,
@@ -191,13 +201,14 @@ function causalzigzag(n, G = (DiGraph(n), 0); balance = metropolis_balance, prio
191201
κ = n - 1
192202
@warn "Truncate κ to "
193203
end
194-
gs = Vector{Tuple{typeof(g),Float64,Int,Int}}()
204+
gs = Vector{Tuple{typeof(g),Float64,Int,Int,Float64}}()
195205
dir = 1
196206
global traversals = 0
197207
global tempty = 0.0
198208
τ = 0.0
199209
secs = 0.0
200210
emax = n*κ÷2
211+
scorevalue = 0.0
201212
@showprogress for iter in 1:iterations
202213
τ = 0.0
203214
total_old = total
@@ -208,17 +219,18 @@ function causalzigzag(n, G = (DiGraph(n), 0); balance = metropolis_balance, prio
208219
traversals += 1
209220
end
210221

222+
Δscorevalue1 = Δscorevalue2 = 0.0
211223
if !naive
212224
if score isa UniformScore
213225
s1, s2, up1, down1 = count_moves_uniform(g, κ)
214226
total < emax && (s1 *= balance(prior(total, total+1)))
215227
total > 0 && (s2 *= balance(prior(total, total-1)))
216228

217229
else
218-
s1, s2, up1, down1 = count_moves_new(g, κ, balance, prior, score, coldness, total)
230+
s1, s2, Δscorevalue1, Δscorevalue2, up1, down1 = count_moves_new(g, κ, balance, prior, score, coldness, total)
219231
end
220232
else
221-
s1, s2, up1, down1 = count_moves(g, κ, balance, prior, score, coldness, total)
233+
s1, s2, Δscorevalue1, Δscorevalue2, up1, down1 = count_moves(g, κ, balance, prior, score, coldness, total)
222234
end
223235
λbar = max(dir*(-s1 + s2), 0.0)
224236
λrw = (s1 + s2)
@@ -247,7 +259,8 @@ function causalzigzag(n, G = (DiGraph(n), 0); balance = metropolis_balance, prio
247259
x, y, T = up1
248260
@assert x != y
249261
total == 0 && (tempty += τ)
250-
save && push!(gs, (g, τ, dir, total))
262+
save && push!(gs, (g, τ, dir, total, scorevalue))
263+
scorevalue += Δscorevalue1
251264
total += 1
252265
secs += @elapsed begin
253266
if !naive
@@ -264,7 +277,8 @@ function causalzigzag(n, G = (DiGraph(n), 0); balance = metropolis_balance, prio
264277
x, y, H = down1
265278
@assert x != y
266279
total == 0 && (tempty += τ)
267-
save && push!(gs, (g, τ, dir, total))
280+
save && push!(gs, (g, τ, dir, total, scorevalue))
281+
scorevalue += Δscorevalue2
268282
total -= 1
269283
secs += @elapsed begin
270284
if !naive
@@ -282,7 +296,7 @@ function causalzigzag(n, G = (DiGraph(n), 0); balance = metropolis_balance, prio
282296
x = y = 0
283297
dir *= -1
284298
total == 0 && (tempty += τ)
285-
save && push!(gs, (g, τ, dir, total))
299+
save && push!(gs, (g, τ, dir, total, scorevalue))
286300
break
287301
end # break
288302
verbose && println(total_old, dir_old == 1 ? "↑" : "↓", total, " $x => $y ", round(τ, digits=8))
@@ -299,11 +313,12 @@ end
299313
function unzipgs(gs)
300314
graphs = first.(gs)
301315
graph_pairs = vpairs.(graphs)
302-
hs = map(last, gs)
316+
scs = map(last, gs)
317+
hs = map(x->getindex(x, 4), gs)
303318
τs = map(x->getindex(x, 2), gs)
304319
ws = normalize(τs, 1)
305320
ts = cumsum(ws)
306-
(;graphs, graph_pairs, hs, τs, ws, ts)
321+
(;graphs, graph_pairs, hs, τs, ws, ts, scs)
307322
end
308323

309-
const randcpdag = causalzigzag
324+
const randcpdag = causalzigzag

test/gesvsR.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ using Random
3232
@test s ≈ sb
3333
#g2c, sc, (t1c, t2c) = ges(X; penalty, parallel=true)
3434
@test score_R ≈ score_dag(DiGraph(d), GaussianScore(C, n, penalty)) + s
35-
@show score_R score_dag(pdag2dag!(copy(g2)), GaussianScore(C, n, penalty))
36-
@show score_R score_dag(pdag2dag!(copy(g3)), GaussianScore(C, n, penalty))
35+
@test score_R ≈ score_dag(pdag2dag!(copy(g2)), GaussianScore(C, n, penalty))
36+
@test score_R ≈ score_dag(pdag2dag!(copy(g3)), GaussianScore(C, n, penalty))
3737

3838
@test isempty(symdiff(vpairs(g2), vpairs(g2b)))
3939

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ include("exact.jl")
66
include("operators.jl")
77
include("ges.jl")
88
include("gesvsR.jl")
9+
include("sampler.jl")
910
include("gensearch.jl")
1011
include("cpdag.jl")
1112
include("skeleton.jl")

test/sampler.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
using Random, CausalInference, Statistics, Test, Graphs
2+
@testset "Zig-Zag" begin
3+
Random.seed!(1)
4+
5+
N = 2000 # number of data points
6+
7+
# define simple linear model with added noise
8+
x = randn(N)
9+
v = x + randn(N)*0.25
10+
w = x + randn(N)*0.25
11+
z = v + w + randn(N)*0.25
12+
s = z + randn(N)*0.25
13+
14+
df = (x=x, v=v, w=w, z=z, s=s)
15+
iterations = 5_000
16+
n = length(df) # vertices
17+
κ = n - 1 # max degree
18+
penalty = 2.0 # increase to get more edges in truth
19+
Random.seed!(101)
20+
C = cor(CausalInference.Tables.matrix(df))
21+
score = GaussianScore(C, N, penalty)
22+
gs = @time causalzigzag(n; score, κ, iterations)
23+
graphs, graph_pairs, hs, τs, ws, ts, scores = CausalInference.unzipgs(gs)
24+
posterior = sort(keyedreduce(+, graph_pairs, ws); byvalue=true, rev=true)
25+
26+
# maximum aposteriori estimate
27+
@test first(posterior).first == [1=>2, 1=>3, 2=>1, 2=>4, 3=>1, 3=>4, 4=>5]
28+
# score of last sample
29+
@test score_dag(pdag2dag!(copy(graphs[end])), score) ≈ scores[end] + score_dag(DiGraph(n), score)
30+
31+
end #testset

0 commit comments

Comments
 (0)