@@ -68,7 +68,7 @@ ndown(g, total) = ne(g)
6868Return
6969"""
7070function count_moves (g, κ, balance, prior, score, coldness, total, dir= :both )
71- s1 = s2 = 0.0
71+ s1 = s2 = Δscorevalue1 = Δscorevalue2 = 0.0
7272 x1 = y1 = x2 = y2 = 0
7373 T1 = Int[]
7474 H2 = Int[]
@@ -85,14 +85,16 @@ function count_moves(g, κ, balance, prior, score, coldness, total, dir=:both)
8585 valid = (isclique (g, NAyxT) && isblocked (g, y, x, NAyxT))
8686 if valid
8787 PAy = parents (g, y)
88- s = balance (prior (total, total+ 1 )* exp (coldness* Δscoreinsert (score, NAyxT ∪ PAy, x, y, T)))
88+ Δscorevalue = Δscoreinsert (score, NAyxT ∪ PAy, x, y, T)
89+ s = balance (prior (total, total+ 1 )* exp (coldness* Δscorevalue))
8990 else
9091 s = 0.0
9192 end
9293 @assert s >= 0
9394 if valid && rand () > s1/ (s1 + s) # sequentially draw sample
9495 x1, y1 = x, y
9596 T1 = T
97+ Δscorevalue1 = Δscorevalue
9698 end
9799 s1 = s1 + s
98100 end
@@ -108,25 +110,27 @@ function count_moves(g, κ, balance, prior, score, coldness, total, dir=:both)
108110 if valid
109111 PAy = parents (g, y)
110112 PAy⁻ = setdiff (PAy, x)
111- s = balance (prior (total, total- 1 )* exp (coldness* Δscoredelete (score, NAyx_H ∪ PAy⁻, x, y, H)))
113+ Δscorevalue = Δscoredelete (score, NAyx_H ∪ PAy⁻, x, y, H)
114+ s = balance (prior (total, total- 1 )* exp (coldness* Δscorevalue))
112115 else
113116 s = 0.0
114117 end
115118 @assert s >= 0
116119 if valid && rand () > s2/ (s2 + s)
117120 x2, y2 = x, y
118121 H2 = H
122+ Δscorevalue2 = Δscorevalue
119123 end
120124 s2 = s2 + s
121125 end
122126 end
123127 end
124128 end
125- s1, s2, (x1, y1, T1), (x2, y2, H2)
129+ s1, s2, Δscorevalue1, Δscorevalue2, (x1, y1, T1), (x2, y2, H2)
126130end
127131
128132function count_moves_new (g, κ, balance, prior, score, coldness, total, dir= :both )
129- s1 = s2 = 0.0
133+ s1 = s2 = Δscorevalue1 = Δscorevalue2 = 0.0
130134 x1 = y1 = x2 = y2 = 0
131135 T1 = Int[]
132136 H2 = Int[]
@@ -147,10 +151,12 @@ function count_moves_new(g, κ, balance, prior, score, coldness, total, dir=:bot
147151 # to hide complexity
148152 # or just Δscoreinsert(score, g, op)
149153 # and op contains all necessary stuff e.g. NAyxT and so on
150- s = balance (prior (total, total+ 1 )* exp (coldness* Δscoreinsert (score, NAyxT ∪ PAy, x, y, T)))
154+ Δscorevalue = Δscoreinsert (score, NAyxT ∪ PAy, x, y, T)
155+ s = balance (prior (total, total+ 1 )* exp (coldness* Δscorevalue))
151156 if rand () > s1/ (s1 + s) # sequentially draw sample
152157 x1, y1 = x, y
153158 T1 = T
159+ Δscorevalue1 = Δscorevalue
154160 end
155161 s1 = s1 + s
156162 end
@@ -161,17 +167,19 @@ function count_moves_new(g, κ, balance, prior, score, coldness, total, dir=:bot
161167 PAy⁻ = setdiff (PAy, x)
162168 # I would prefer Δscoredelete(score, g, x, y, H) as above
163169 NAyx_H = setdiff (adj_neighbors (g, x, y), H)
164- s = balance (prior (total, total- 1 )* exp (coldness* Δscoredelete (score, NAyx_H ∪ PAy⁻, x, y, H)))
170+ Δscorevalue = Δscoredelete (score, NAyx_H ∪ PAy⁻, x, y, H)
171+ s = balance (prior (total, total- 1 )* exp (coldness* Δscorevalue))
165172 if rand () > s2/ (s2 + s)
166173 x2, y2 = x, y
167174 H2 = H
175+ Δscorevalue2 = Δscorevalue
168176 end
169177 s2 = s2 + s
170178 end
171179 end
172180 end
173181 end
174- s1, s2, (x1, y1, T1), (x2, y2, H2)
182+ s1, s2, Δscorevalue1, Δscorevalue2, (x1, y1, T1), (x2, y2, H2)
175183end
176184
177185"""
182190Run the causal zigzag algorithm starting in a cpdag `(G, t)` with `t` oriented or unoriented edges,
183191the balance function `balance ∈ {metropolis_balance, barker_balance, sqrt}`, `score` function (see `ges` algorithm)
184192coldness parameter for iterations. `σ = 1.0, ρ = 0.0` gives purely diffusive behaviour, `σ = 0.0, ρ = 1.0` gives Zig-Zag behaviour.
193+
194+ Returns a vector of tuples with information, each containing a graph, spent time, current direction, number of edges and the score.
185195"""
186196function causalzigzag (n, G = (DiGraph (n), 0 ); balance = metropolis_balance, prior = (_,_)-> 1.0 , score= UniformScore (),
187197 coldness = 1.0 , σ = 0.0 , ρ = 1.0 , naive= false ,
@@ -191,13 +201,14 @@ function causalzigzag(n, G = (DiGraph(n), 0); balance = metropolis_balance, prio
191201 κ = n - 1
192202 @warn " Truncate κ to $κ "
193203 end
194- gs = Vector {Tuple{typeof(g),Float64,Int,Int}} ()
204+ gs = Vector {Tuple{typeof(g),Float64,Int,Int,Float64 }} ()
195205 dir = 1
196206 global traversals = 0
197207 global tempty = 0.0
198208 τ = 0.0
199209 secs = 0.0
200210 emax = n* κ÷ 2
211+ scorevalue = 0.0
201212 @showprogress for iter in 1 : iterations
202213 τ = 0.0
203214 total_old = total
@@ -208,17 +219,18 @@ function causalzigzag(n, G = (DiGraph(n), 0); balance = metropolis_balance, prio
208219 traversals += 1
209220 end
210221
222+ Δscorevalue1 = Δscorevalue2 = 0.0
211223 if ! naive
212224 if score isa UniformScore
213225 s1, s2, up1, down1 = count_moves_uniform (g, κ)
214226 total < emax && (s1 *= balance (prior (total, total+ 1 )))
215227 total > 0 && (s2 *= balance (prior (total, total- 1 )))
216228
217229 else
218- s1, s2, up1, down1 = count_moves_new (g, κ, balance, prior, score, coldness, total)
230+ s1, s2, Δscorevalue1, Δscorevalue2, up1, down1 = count_moves_new (g, κ, balance, prior, score, coldness, total)
219231 end
220232 else
221- s1, s2, up1, down1 = count_moves (g, κ, balance, prior, score, coldness, total)
233+ s1, s2, Δscorevalue1, Δscorevalue2, up1, down1 = count_moves (g, κ, balance, prior, score, coldness, total)
222234 end
223235 λbar = max (dir* (- s1 + s2), 0.0 )
224236 λrw = (s1 + s2)
@@ -247,7 +259,8 @@ function causalzigzag(n, G = (DiGraph(n), 0); balance = metropolis_balance, prio
247259 x, y, T = up1
248260 @assert x != y
249261 total == 0 && (tempty += τ)
250- save && push! (gs, (g, τ, dir, total))
262+ save && push! (gs, (g, τ, dir, total, scorevalue))
263+ scorevalue += Δscorevalue1
251264 total += 1
252265 secs += @elapsed begin
253266 if ! naive
@@ -264,7 +277,8 @@ function causalzigzag(n, G = (DiGraph(n), 0); balance = metropolis_balance, prio
264277 x, y, H = down1
265278 @assert x != y
266279 total == 0 && (tempty += τ)
267- save && push! (gs, (g, τ, dir, total))
280+ save && push! (gs, (g, τ, dir, total, scorevalue))
281+ scorevalue += Δscorevalue2
268282 total -= 1
269283 secs += @elapsed begin
270284 if ! naive
@@ -282,7 +296,7 @@ function causalzigzag(n, G = (DiGraph(n), 0); balance = metropolis_balance, prio
282296 x = y = 0
283297 dir *= - 1
284298 total == 0 && (tempty += τ)
285- save && push! (gs, (g, τ, dir, total))
299+ save && push! (gs, (g, τ, dir, total, scorevalue ))
286300 break
287301 end # break
288302 verbose && println (total_old, dir_old == 1 ? " ↑" : " ↓" , total, " $x => $y " , round (τ, digits= 8 ))
@@ -299,11 +313,12 @@ end
299313function unzipgs (gs)
300314 graphs = first .(gs)
301315 graph_pairs = vpairs .(graphs)
302- hs = map (last, gs)
316+ scs = map (last, gs)
317+ hs = map (x-> getindex (x, 4 ), gs)
303318 τs = map (x-> getindex (x, 2 ), gs)
304319 ws = normalize (τs, 1 )
305320 ts = cumsum (ws)
306- (;graphs, graph_pairs, hs, τs, ws, ts)
321+ (;graphs, graph_pairs, hs, τs, ws, ts, scs )
307322end
308323
309- const randcpdag = causalzigzag
324+ const randcpdag = causalzigzag
0 commit comments