Skip to content

Commit b4628f8

Browse files
committed
Add tests to check LS HMM of _tskit.lshmm compared to BEAGLE
1 parent b6f9872 commit b4628f8

File tree

1 file changed

+301
-0
lines changed

1 file changed

+301
-0
lines changed

python/tests/test_imputation.py

+301
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,301 @@
1+
"""
2+
Tests for genotype imputation (forward and Baum-Welch algorithms).
3+
"""
4+
import io
5+
6+
import numpy as np
7+
8+
import _tskit
9+
import tskit
10+
11+
12+
# A toy tree sequence containing 3 diploid individuals with 5 sites and 5 mutations.
13+
# Two toy query haplotypes are targets for imputation.
14+
15+
toy_ts_nodes_text = """\
16+
id is_sample time population individual metadata
17+
0 1 0.000000 0 0
18+
1 1 0.000000 0 0
19+
2 1 0.000000 0 1
20+
3 1 0.000000 0 1
21+
4 1 0.000000 0 2
22+
5 1 0.000000 0 2
23+
6 0 0.029768 0 -1
24+
7 0 0.133017 0 -1
25+
8 0 0.223233 0 -1
26+
9 0 0.651586 0 -1
27+
10 0 0.698831 0 -1
28+
11 0 2.114867 0 -1
29+
12 0 4.322031 0 -1
30+
13 0 7.432311 0 -1
31+
"""
32+
33+
toy_ts_edges_text = """\
34+
left right parent child metadata
35+
0.000000 1000000.000000 6 0
36+
0.000000 1000000.000000 6 3
37+
0.000000 1000000.000000 7 2
38+
0.000000 1000000.000000 7 5
39+
0.000000 1000000.000000 8 1
40+
0.000000 1000000.000000 8 4
41+
0.000000 781157.000000 9 6
42+
0.000000 781157.000000 9 7
43+
0.000000 505438.000000 10 8
44+
0.000000 505438.000000 10 9
45+
505438.000000 549484.000000 11 8
46+
505438.000000 549484.000000 11 9
47+
781157.000000 1000000.000000 12 6
48+
781157.000000 1000000.000000 12 7
49+
549484.000000 1000000.000000 13 8
50+
549484.000000 781157.000000 13 9
51+
781157.000000 1000000.000000 13 12
52+
"""
53+
54+
toy_ts_sites_text = """\
55+
position ancestral_state metadata
56+
200000.000000 A
57+
300000.000000 C
58+
520000.000000 G
59+
600000.000000 T
60+
900000.000000 A
61+
"""
62+
63+
toy_ts_mutations_text = """\
64+
site node time derived_state parent metadata
65+
0 9 unknown G -1
66+
1 8 unknown A -1
67+
2 9 unknown T -1
68+
3 9 unknown C -1
69+
4 12 unknown C -1
70+
"""
71+
72+
toy_ts_individuals_text = """\
73+
flags
74+
0
75+
0
76+
0
77+
"""
78+
79+
# Two query haplotypes in 0/1 coding, one row per haplotype and one column
# per site; -1 marks the site whose allele is missing.
toy_query_haplotypes_01 = np.array(
    [
        [1, 0, -1, 0, 0],
        [0, 1, -1, 1, 0],
    ],
    dtype=np.int32,
)
98+
99+
# The same two query haplotypes with alleles coded as integers
# (0 = A, 1 = C, 2 = G, 3 = T); -1 marks the site whose allele is missing.
toy_query_haplotypes_ACGT = np.array(
    [
        # G, C, missing, T, A
        [2, 1, -1, 3, 0],
        # A, A, missing, C, A
        [0, 0, -1, 1, 0],
    ],
    dtype=np.int32,
)
106+
107+
108+
def get_toy_data():
    """
    Build the toy reference tree sequence and pair it with the query haplotypes.

    The tree sequence is loaded with ``tskit.load_text`` (non-strict parsing)
    from the module-level ``toy_ts_*_text`` tables.

    :return: A two-element list ``[ref_ts, query_h]``, where ``query_h`` is
        the module-level ``toy_query_haplotypes_ACGT`` array.
    """
    # One in-memory text stream per table, keyed by the load_text keyword name.
    table_streams = {
        "nodes": io.StringIO(toy_ts_nodes_text),
        "edges": io.StringIO(toy_ts_edges_text),
        "sites": io.StringIO(toy_ts_sites_text),
        "mutations": io.StringIO(toy_ts_mutations_text),
        "individuals": io.StringIO(toy_ts_individuals_text),
    }
    reference = tskit.load_text(strict=False, **table_streams)
    return [reference, toy_query_haplotypes_ACGT]
119+
120+
121+
def get_tskit_forward_backward_matrices(ts, h):
    """
    Run the low-level ``_tskit.LsHmm`` on one query haplotype.

    :param ts: Reference tree sequence; its ``num_sites`` and
        ``_ll_tree_sequence`` attributes are used.
    :param h: Query haplotype passed unchanged to ``forward_matrix`` and
        ``backward_matrix``.
    :return: A two-element list ``[forward, backward]`` holding the decoded
        forward and backward matrices.
    """
    num_sites = ts.num_sites
    ll_ts = ts._ll_tree_sequence
    forward = _tskit.CompressedMatrix(ll_ts)
    backward = _tskit.CompressedMatrix(ll_ts)
    # Both per-site rate arrays are constant at 0.1.
    # NOTE(review): presumably recombination then mutation rates -- confirm
    # against the _tskit.LsHmm signature.
    first_rates = np.full(num_sites, 0.1)
    second_rates = np.full(num_sites, 0.1)
    hmm = _tskit.LsHmm(ll_ts, first_rates, second_rates, acgt_alleles=True)
    hmm.forward_matrix(h, forward)
    # The backward pass reuses the normalisation factors of the forward pass.
    hmm.backward_matrix(h, forward.normalisation_factor, backward)
    return [forward.decode(), backward.decode()]
131+
132+
133+
# BEAGLE 4.1 was run on the toy data set above using default parameters.
134+
#
135+
# In the query VCF, the site at position 520,000 was redacted and then imputed.
136+
# Note that the ancestral allele in the simulated tree sequence is
137+
# treated as the REF in the VCFs.
138+
#
139+
# The following are the forward probability matrices and backward probability
140+
# matrices calculated when imputing into the third individual above. There are
141+
# two sets of matrices, one for each haplotype.
142+
#
143+
# Notes about calculations:
144+
# n = number of haplotypes in ref. panel
145+
# M = number of markers
146+
# m = index of marker (site)
147+
# h = index of haplotype in ref. panel
148+
#
149+
# In forward probability matrix,
150+
# fwd[m][h] = emission prob., if m = 0 (first marker)
151+
# fwd[m][h] = emission prob. * (scale * fwd[m - 1][h] + shift), otherwise
152+
# where scale = (1 - switch prob.)/sum of fwd[m - 1],
153+
# and shift = switch prob./n.
154+
#
155+
# In backward probability matrix,
156+
# bwd[m][h] = 1, if m = M - 1 (last marker) // DON'T SEE THIS IN BEAGLE
157+
# unadj. bwd[m][h] = emission prob. / n
158+
# bwd[m][h] = (unadj. bwd[m][h] + shift) * scale, otherwise
159+
# where scale = (1 - switch prob.)/sum of unadj. bwd[m],
160+
# and shift = switch prob./n.
161+
#
162+
# For each site, the sum of backward values over all haplotypes is calculated
163+
# before scaling and shifting.
164+
165+
beagle_fwd_matrix_text_1 = """
166+
m,h,probRec,probNoRec,noErrProb,errProb,refAl,queryAl,shiftFac,scaleFac,sumSite,val,
167+
0,0,0.000000,1.000000,0.999900,0.000100,1,1,0.000000,1.000000,0.999900,0.999900,
168+
0,1,0.000000,1.000000,0.999900,0.000100,0,1,0.000000,1.000000,1.000000,0.000100,
169+
0,2,0.000000,1.000000,0.999900,0.000100,1,1,0.000000,1.000000,1.999900,0.999900,
170+
0,3,0.000000,1.000000,0.999900,0.000100,1,1,0.000000,1.000000,2.999800,0.999900,
171+
0,4,0.000000,1.000000,0.999900,0.000100,0,1,0.000000,1.000000,2.999900,0.000100,
172+
0,5,0.000000,1.000000,0.999900,0.000100,1,1,0.000000,1.000000,3.999800,0.999900,
173+
1,0,1.000000,0.000000,0.999900,0.000100,0,0,0.166667,0.000000,0.166650,0.166650,
174+
1,1,1.000000,0.000000,0.999900,0.000100,1,0,0.166667,0.000000,0.166667,0.000017,
175+
1,2,1.000000,0.000000,0.999900,0.000100,0,0,0.166667,0.000000,0.333317,0.166650,
176+
1,3,1.000000,0.000000,0.999900,0.000100,0,0,0.166667,0.000000,0.499967,0.166650,
177+
1,4,1.000000,0.000000,0.999900,0.000100,1,0,0.166667,0.000000,0.499983,0.000017,
178+
1,5,1.000000,0.000000,0.999900,0.000100,0,0,0.166667,0.000000,0.666633,0.166650,
179+
2,0,1.000000,0.000000,0.999900,0.000100,1,0,0.166667,0.000000,0.000017,0.000017,
180+
2,1,1.000000,0.000000,0.999900,0.000100,0,0,0.166667,0.000000,0.166667,0.166650,
181+
2,2,1.000000,0.000000,0.999900,0.000100,1,0,0.166667,0.000000,0.166683,0.000017,
182+
2,3,1.000000,0.000000,0.999900,0.000100,1,0,0.166667,0.000000,0.166700,0.000017,
183+
2,4,1.000000,0.000000,0.999900,0.000100,0,0,0.166667,0.000000,0.333350,0.166650,
184+
2,5,1.000000,0.000000,0.999900,0.000100,1,0,0.166667,0.000000,0.333367,0.000017,
185+
3,0,1.000000,0.000000,0.999900,0.000100,1,0,0.166667,0.000000,0.000017,0.000017,
186+
3,1,1.000000,0.000000,0.999900,0.000100,0,0,0.166667,0.000000,0.166667,0.166650,
187+
3,2,1.000000,0.000000,0.999900,0.000100,1,0,0.166667,0.000000,0.166683,0.000017,
188+
3,3,1.000000,0.000000,0.999900,0.000100,1,0,0.166667,0.000000,0.166700,0.000017,
189+
3,4,1.000000,0.000000,0.999900,0.000100,0,0,0.166667,0.000000,0.333350,0.166650,
190+
3,5,1.000000,0.000000,0.999900,0.000100,1,0,0.166667,0.000000,0.333367,0.000017,
191+
"""
192+
193+
beagle_bwd_matrix_text_1 = """
194+
m,h,probRec,probNoRec,noErrProb,errProb,refAl,queryAl,shiftFac,scaleFac,sumSite,val,
195+
3,0,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000,
196+
3,1,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000,
197+
3,2,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000,
198+
3,3,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000,
199+
3,4,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000,
200+
3,5,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000,
201+
2,0,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.166667,0.333367,0.166667,
202+
2,1,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.166667,0.333367,0.166667,
203+
2,2,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.166667,0.333367,0.166667,
204+
2,3,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.166667,0.333367,0.166667,
205+
2,4,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.166667,0.333367,0.166667,
206+
2,5,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.166667,0.333367,0.166667,
207+
1,0,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.166667,0.333367,0.166667,
208+
1,1,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.166667,0.333367,0.166667,
209+
1,2,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.166667,0.333367,0.166667,
210+
1,3,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.166667,0.333367,0.166667,
211+
1,4,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.166667,0.333367,0.166667,
212+
1,5,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.166667,0.333367,0.166667,
213+
0,0,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.166667,0.666633,0.166667,
214+
0,1,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.166667,0.666633,0.166667,
215+
0,2,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.166667,0.666633,0.166667,
216+
0,3,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.166667,0.666633,0.166667,
217+
0,4,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.166667,0.666633,0.166667,
218+
0,5,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.166667,0.666633,0.166667,
219+
"""
220+
221+
beagle_fwd_matrix_text_2 = """
222+
m,h,probRec,probNoRec,noErrProb,errProb,refAl,queryAl,shiftFac,scaleFac,sumSite,val,
223+
0,0,0.000000,1.000000,0.999900,0.000100,1,0,0.000000,1.000000,0.000100,0.000100,
224+
0,1,0.000000,1.000000,0.999900,0.000100,0,0,0.000000,1.000000,1.000000,0.999900,
225+
0,2,0.000000,1.000000,0.999900,0.000100,1,0,0.000000,1.000000,1.000100,0.000100,
226+
0,3,0.000000,1.000000,0.999900,0.000100,1,0,0.000000,1.000000,1.000200,0.000100,
227+
0,4,0.000000,1.000000,0.999900,0.000100,0,0,0.000000,1.000000,2.000100,0.999900,
228+
0,5,0.000000,1.000000,0.999900,0.000100,1,0,0.000000,1.000000,2.000200,0.000100,
229+
1,0,1.000000,0.000000,0.999900,0.000100,0,1,0.166667,0.000000,0.000017,0.000017,
230+
1,1,1.000000,0.000000,0.999900,0.000100,1,1,0.166667,0.000000,0.166667,0.166650,
231+
1,2,1.000000,0.000000,0.999900,0.000100,0,1,0.166667,0.000000,0.166683,0.000017,
232+
1,3,1.000000,0.000000,0.999900,0.000100,0,1,0.166667,0.000000,0.166700,0.000017,
233+
1,4,1.000000,0.000000,0.999900,0.000100,1,1,0.166667,0.000000,0.333350,0.166650,
234+
1,5,1.000000,0.000000,0.999900,0.000100,0,1,0.166667,0.000000,0.333367,0.000017,
235+
2,0,1.000000,0.000000,0.999900,0.000100,1,1,0.166667,0.000000,0.166650,0.166650,
236+
2,1,1.000000,0.000000,0.999900,0.000100,0,1,0.166667,0.000000,0.166667,0.000017,
237+
2,2,1.000000,0.000000,0.999900,0.000100,1,1,0.166667,0.000000,0.333317,0.166650,
238+
2,3,1.000000,0.000000,0.999900,0.000100,1,1,0.166667,0.000000,0.499967,0.166650,
239+
2,4,1.000000,0.000000,0.999900,0.000100,0,1,0.166667,0.000000,0.499983,0.000017,
240+
2,5,1.000000,0.000000,0.999900,0.000100,1,1,0.166667,0.000000,0.666633,0.166650,
241+
3,0,1.000000,0.000000,0.999900,0.000100,1,0,0.166667,0.000000,0.000017,0.000017,
242+
3,1,1.000000,0.000000,0.999900,0.000100,0,0,0.166667,0.000000,0.166667,0.166650,
243+
3,2,1.000000,0.000000,0.999900,0.000100,1,0,0.166667,0.000000,0.166683,0.000017,
244+
3,3,1.000000,0.000000,0.999900,0.000100,1,0,0.166667,0.000000,0.166700,0.000017,
245+
3,4,1.000000,0.000000,0.999900,0.000100,0,0,0.166667,0.000000,0.333350,0.166650,
246+
3,5,1.000000,0.000000,0.999900,0.000100,1,0,0.166667,0.000000,0.333367,0.000017,
247+
"""
248+
249+
beagle_bwd_matrix_text_2 = """
250+
m,h,probRec,probNoRec,noErrProb,errProb,refAl,queryAl,shiftFac,scaleFac,sumSite,val,
251+
3,0,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000,
252+
3,1,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000,
253+
3,2,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000,
254+
3,3,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000,
255+
3,4,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000,
256+
3,5,-1,-1,-1,-1,-1,-1,-1,-1,-1,1.000000,
257+
2,0,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.166667,0.333367,0.166667,
258+
2,1,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.166667,0.333367,0.166667,
259+
2,2,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.166667,0.333367,0.166667,
260+
2,3,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.166667,0.333367,0.166667,
261+
2,4,1.000000,0.000000,0.999900,0.000100,0,0,0.000000,0.166667,0.333367,0.166667,
262+
2,5,1.000000,0.000000,0.999900,0.000100,1,0,0.000000,0.166667,0.333367,0.166667,
263+
1,0,1.000000,0.000000,0.999900,0.000100,0,1,0.000000,0.166667,0.666633,0.166667,
264+
1,1,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.166667,0.666633,0.166667,
265+
1,2,1.000000,0.000000,0.999900,0.000100,0,1,0.000000,0.166667,0.666633,0.166667,
266+
1,3,1.000000,0.000000,0.999900,0.000100,0,1,0.000000,0.166667,0.666633,0.166667,
267+
1,4,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.166667,0.666633,0.166667,
268+
1,5,1.000000,0.000000,0.999900,0.000100,0,1,0.000000,0.166667,0.666633,0.166667,
269+
0,0,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.166667,0.333367,0.166667,
270+
0,1,1.000000,0.000000,0.999900,0.000100,0,1,0.000000,0.166667,0.333367,0.166667,
271+
0,2,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.166667,0.333367,0.166667,
272+
0,3,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.166667,0.333367,0.166667,
273+
0,4,1.000000,0.000000,0.999900,0.000100,0,1,0.000000,0.166667,0.333367,0.166667,
274+
0,5,1.000000,0.000000,0.999900,0.000100,1,1,0.000000,0.166667,0.333367,0.166667,
275+
"""
276+
277+
278+
def parse_matrix(csv_text):
    """
    Parse a comma-separated table with a header row into a NumPy record array.

    The columns can then be accessed by name, e.g. ``table["m"]`` or
    ``table.val``, much like a pandas dataframe.

    ``np.genfromtxt`` is called directly with the settings that
    ``recfromcsv`` applied (comma delimiter, names taken from the header,
    per-column inferred dtypes, lower-cased field names), because
    ``np.lib.npyio.recfromcsv`` is deprecated and not reachable through
    ``numpy.lib.npyio`` on NumPy >= 2.0.

    :param str csv_text: The CSV text; leading blank lines before the header
        are skipped by ``genfromtxt``.
    :return: A ``numpy.recarray`` with one field per header column.
        NOTE(review): the BEAGLE dumps end every row with a trailing comma,
        which appears to yield one extra auto-named trailing field -- confirm
        if the full field list ever matters.
    """
    table = np.genfromtxt(
        io.StringIO(csv_text),
        delimiter=",",
        names=True,
        dtype=None,
        case_sensitive="lower",
    )
    return table.view(np.recarray)
282+
283+
284+
def test_toy_example():
    """
    Print the tskit and BEAGLE forward/backward matrices for the toy data.

    Only the first query haplotype and the first pair of BEAGLE matrices are
    used. Nothing is asserted yet; the matrices are printed side by side for
    manual comparison.
    """
    ref_ts, query = get_toy_data()
    print(list(ref_ts.haplotypes()))
    print(ref_ts)
    print(query)
    tskit_fwd, tskit_bwd = get_tskit_forward_backward_matrices(ref_ts, query[0])
    beagle_fwd = parse_matrix(beagle_fwd_matrix_text_1)
    beagle_bwd = parse_matrix(beagle_bwd_matrix_text_1)
    comparisons = (
        ("Forward probability matrix", tskit_fwd, beagle_fwd),
        ("Backward probability matrix", tskit_bwd, beagle_bwd),
    )
    for heading, tskit_matrix, beagle_table in comparisons:
        print(heading)
        print("tskit")
        print(tskit_matrix)
        print("beagle")
        # The BEAGLE dump is one row per (marker, haplotype) pair; lay the
        # values out as a 4 x 6 grid.
        print(beagle_table["val"].reshape((4, 6)))

0 commit comments

Comments
 (0)