-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcascade.py
50 lines (43 loc) · 1.65 KB
/
cascade.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import math
import numpy as np
def observe_cascade(c, source, q, method='uniform', source_includable=False):
"""graph_tool version of make_partial_cascade
"""
all_infection = np.nonzero(c != -1)[0]
if not source_includable:
all_infection = list(set(all_infection) - {source})
num_obs = int(math.ceil(len(all_infection) * q))
if num_obs < 2:
num_obs = 2
if method == 'uniform':
return np.random.permutation(all_infection)[:num_obs]
elif method == 'late':
return np.argsort(c)[-num_obs:]
# @profile
def gen_nontrivial_cascade(g, p, q, model='ic', source=None, return_tree=False, source_includable=False):
assert model in {'ic', 'si', 'sp', 'ct'}
while True:
if model == 'ic':
from ic import simulate_cascade
rts = simulate_cascade(g, p, source=source, return_tree=return_tree)
elif model == 'si':
from si import gen_cascade
rts = gen_cascade(g, p, source=source)
elif model == 'sp':
from sp import gen_cascade
rts = gen_cascade(g, source=source)
elif model == 'ct':
from ctic import gen_cascade
rts = gen_cascade(g, source=source)
source, c = rts[:2]
if return_tree:
tree = rts[2]
obs_nodes = observe_cascade(c, source, q, method='uniform',
source_includable=source_includable)
cascade_size = np.sum(c != -1)
if cascade_size >= 5: # avoid small cascade
break
if return_tree:
return c, source, obs_nodes, tree
else:
return c, source, obs_nodes