From 05af0b0ba120c57bb05cfeb3ce4fa2037a2b6c4e Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Wed, 11 Oct 2023 13:28:11 +0100 Subject: [PATCH 1/6] Working all-nodes check for special case --- python/tests/test_haplotype_matching.py | 364 ++++++++++++++++++------ 1 file changed, 276 insertions(+), 88 deletions(-) diff --git a/python/tests/test_haplotype_matching.py b/python/tests/test_haplotype_matching.py index dcc1d684fb..c3a9e3134c 100644 --- a/python/tests/test_haplotype_matching.py +++ b/python/tests/test_haplotype_matching.py @@ -99,7 +99,15 @@ class LsHmmAlgorithm: """ def __init__( - self, ts, rho, mu, alleles, n_alleles, precision=10, scale_mutation=False + self, + ts, + rho, + mu, + alleles, + n_alleles, + precision=10, + scale_mutation=False, + match_all_nodes=False, ): self.ts = ts self.mu = mu @@ -109,8 +117,6 @@ def __init__( self.T = [] # indexes in to the T array for each node. self.T_index = np.zeros(ts.num_nodes, dtype=int) - 1 - # The number of nodes underneath each element in the T array. - self.N = np.zeros(ts.num_nodes, dtype=int) # Efficiently compute the allelic state at a site self.allelic_state = np.zeros(ts.num_nodes, dtype=int) - 1 # TreePosition so we can can update T and T_index between trees. @@ -122,6 +128,41 @@ def __init__( self.n_alleles = n_alleles self.alleles = alleles self.scale_mutation_based_on_n_alleles = scale_mutation + self.match_all_nodes = match_all_nodes + + def node_values(self): + """ + Return the current mapping of node->value for each node in the + tree. + """ + d = {} + mapping = {st.tree_node: st.value for st in self.T if st.tree_node != -1} + for u in self.tree.nodes(): + v = u + while v not in mapping: + assert v != -1 + v = self.tree.parent(v) + d[u] = mapping[v] + return d + + def print_state(self): + print("LsHMM state") + print("match_all_nodes =", self.match_all_nodes) + print("Tree =") + node_labels = {} + for u, value in self.node_values().items(): + label = f"{u}" + if self.tree.is_sample(u): + label = f"*{u}*" + label += f":{value:.2g}" + node_labels[u] = label + print(self.tree.draw_text(node_labels=node_labels)) + print("T =") + for vt in self.T: + print("\t", vt) + print("T_index:") + for u in range(self.ts.num_nodes): + print(f"\t{u}\t{self.T_index[u]}") def check_integrity(self): M = [st.tree_node for st in self.T if st.tree_node != -1] @@ -134,6 +175,45 @@ def check_integrity(self): assert j == self.T_index[st.tree_node] def compress(self): + if self.match_all_nodes: + self._compress_tsinfer() + else: + self._compress_parsimony() + # self.print_state() + self.check_integrity() + + def _compress_tsinfer(self): + tree = self.tree + T = self.T + T_index = self.T_index + + T_old = [st.copy() for st in T] + T.clear() + + for st in T_old: + u = st.tree_node + if u != -1: + # We need to find the likelihood of the parent of u. If this is + # the same as u, we can delete it. + v = tree.parent(u) + while v != -1 and T_index[v] == -1: + v = tree.parent(v) + keep = True + if v != -1: + if st.value == T_old[T_index[v]].value: + keep = False + if keep: + T.append(st) + T_index[u] = -1 + + # Sort by decreasing time to ensure postorder. This is used by the + # compressed matrix, downstream + self.T.sort(key=lambda st: -tree.time(st.tree_node)) + + for j, st in enumerate(self.T): + self.T_index[st.tree_node] = j + + def _compress_parsimony(self): tree = self.tree T = self.T T_index = self.T_index @@ -190,13 +270,14 @@ def compute(u, parent_state): T_old = [st.copy() for st in T] T.clear() - T_parent = [] + # Removeing T_parent as it's not needed currently, see note on N[j] below + # T_parent = [] old_state = T_old[T_index[tree.root]].value_index new_state = np.argmax(optimal_set[tree.root]) T.append(ValueTransition(tree_node=tree.root, value=values[new_state])) - T_parent.append(-1) + # T_parent.append(-1) stack = [(tree.root, old_state, new_state, 0)] while len(stack) > 0: u, old_state, new_state, t_parent = stack.pop() @@ -211,14 +292,14 @@ def compute(u, parent_state): if optimal_set[v, new_state] == 0: new_child_state = np.argmax(optimal_set[v]) child_t_parent = len(T) - T_parent.append(t_parent) + # T_parent.append(t_parent) T.append( ValueTransition(tree_node=v, value=values[new_child_state]) ) stack.append((v, old_child_state, new_child_state, child_t_parent)) else: if old_child_state != new_state: - T_parent.append(t_parent) + # T_parent.append(t_parent) T.append( ValueTransition(tree_node=v, value=values[old_child_state]) ) @@ -228,10 +309,13 @@ def compute(u, parent_state): T_index[st.tree_node] = -1 for j, st in enumerate(T): T_index[st.tree_node] = j - self.N[j] = tree.num_samples(st.tree_node) - for j in range(len(T)): - if T_parent[j] != -1: - self.N[T_parent[j]] -= self.N[j] + + # NOTE: we only use the N values in the forward matrix at the moment, + # so simplifying here by calculating them on the fly where needed. + # self.N[j] = tree.num_samples(st.tree_node) + # for j in range(len(T)): + # if T_parent[j] != -1: + # self.N[T_parent[j]] -= self.N[j] def update_tree(self, direction=tskit.FORWARD): """ @@ -333,11 +417,11 @@ def update_probabilities(self, site, haplotype_state): while allelic_state[v] == -1: v = tree.parent(v) assert v != -1 - match = ( + is_match = ( haplotype_state == MISSING or haplotype_state == allelic_state[v] ) # Note that the node u is used only by Viterbi - st.value = self.compute_next_probability(site.id, st.value, match, u) + st.value = self.compute_next_probability(site.id, st.value, is_match, u) # Unset the states allelic_state[tree.root] = -1 @@ -346,7 +430,12 @@ def update_probabilities(self, site, haplotype_state): def process_site(self, site, haplotype_state): self.update_probabilities(site, haplotype_state) + # d1 = self.node_values() self.compress() + # d2 = self.node_values() + # assert d1 == d2 + # print("AFTER COMPRESS") + # self.print_state() s = self.compute_normalisation_factor() for st in self.T: assert st.tree_node != tskit.NULL @@ -413,26 +502,27 @@ class ForwardAlgorithm(LsHmmAlgorithm): The Li and Stephens forward algorithm. """ - def __init__( - self, ts, rho, mu, alleles, n_alleles, scale_mutation=False, precision=10 - ): - super().__init__( - ts, - rho, - mu, - alleles, - n_alleles, - precision=precision, - scale_mutation=scale_mutation, - ) + def __init__(self, ts, *args, **kwargs): + super().__init__(ts, *args, **kwargs) self.output = CompressedMatrix(ts) def compute_normalisation_factor(self): + d = {st.tree_node: st for st in self.T} + N = np.zeros(self.ts.num_nodes, dtype=int) + for u in self.tree.nodes(order="preorder"): + if u in d: + N[u] = self.tree.num_samples(u) + # Subtract this value from everything above + v = self.tree.parent(u) + while v != -1 and v not in d: + v = self.tree.parent(v) + if v != -1: + N[v] -= N[u] s = 0 - for j, st in enumerate(self.T): + for st in self.T: assert st.tree_node != tskit.NULL - # assert self.N[j] > 0 - s += self.N[j] * st.value + assert N[st.tree_node] > 0 + s += N[st.tree_node] * st.value return s def compute_next_probability(self, site_id, p_last, is_match, node): @@ -489,18 +579,8 @@ class ViterbiAlgorithm(LsHmmAlgorithm): Runs the Li and Stephens Viterbi algorithm. """ - def __init__( - self, ts, rho, mu, alleles, n_alleles, scale_mutation=False, precision=10 - ): - super().__init__( - ts, - rho, - mu, - alleles, - n_alleles, - precision=precision, - scale_mutation=scale_mutation, - ) + def __init__(self, ts, *args, **kwargs): + super().__init__(ts, *args, **kwargs) self.output = ViterbiMatrix(ts) def compute_normalisation_factor(self): @@ -570,6 +650,16 @@ def store_site(self, site, normalisation_factor, value_transitions): self.normalisation_factor[site] = normalisation_factor self.value_transitions[site] = value_transitions + def print_state(self): + print("Compressed matrix state") + for site in range(self.num_sites): + print( + site, + self.normalisation_factor[site], + self.value_transitions[site], + sep="\t", + ) + # Expose the same API as the low-level classes @property @@ -633,12 +723,14 @@ def choose_sample(self, site_id, tree): def traceback(self): # Run the traceback. m = self.ts.num_sites - match = np.zeros(m, dtype=int) + matched = np.zeros(m, dtype=int) recombination_tree = np.zeros(self.ts.num_nodes, dtype=int) - 1 tree = tskit.Tree(self.ts) tree.last() current_node = -1 + # self.print_state() + rr_index = len(self.recombination_required) - 1 for site in reversed(self.ts.sites()): while tree.interval.left > site.position: @@ -654,7 +746,7 @@ def traceback(self): if current_node == -1: current_node = self.choose_sample(site.id, tree) - match[site.id] = current_node + matched[site.id] = current_node # Now traverse up the tree from the current node. The first marked node # we meet tells us whether we need to recombine. @@ -664,6 +756,8 @@ def traceback(self): assert u != -1 if recombination_tree[u] == 1: + # print("recomb_tree = ", recombination_tree) + # print("SWITCHING AT ", site) # Need to switch at the next site. current_node = -1 # Reset the nodes in the recombination tree. @@ -674,7 +768,8 @@ def traceback(self): j -= 1 rr_index = j - return match + # print("MATCHED = ", matched) + return matched def get_site_alleles(ts, h, alleles): @@ -701,7 +796,14 @@ def get_site_alleles(ts, h, alleles): def ls_forward_tree( - h, ts, rho, mu, precision=30, alleles=None, scale_mutation_based_on_n_alleles=False + h, + ts, + rho, + mu, + precision=30, + alleles=None, + scale_mutation_based_on_n_alleles=False, + match_all_nodes=False, ): alleles, n_alleles = get_site_alleles(ts, h, alleles) fa = ForwardAlgorithm( @@ -712,11 +814,21 @@ def ls_forward_tree( n_alleles, precision=precision, scale_mutation=scale_mutation_based_on_n_alleles, + match_all_nodes=match_all_nodes, ) return fa.run(h) -def ls_backward_tree(h, ts, rho, mu, normalisation_factor, precision=30, alleles=None): +def ls_backward_tree( + h, + ts, + rho, + mu, + normalisation_factor, + precision=30, + alleles=None, + match_all_nodes=False, +): alleles, n_alleles = get_site_alleles(ts, h, alleles) ba = BackwardAlgorithm( ts, @@ -725,12 +837,20 @@ def ls_backward_tree(h, ts, rho, mu, normalisation_factor, precision=30, alleles alleles, n_alleles, precision=precision, + match_all_nodes=match_all_nodes, ) return ba.run(h, normalisation_factor) def ls_viterbi_tree( - h, ts, rho, mu, precision=30, alleles=None, scale_mutation_based_on_n_alleles=False + h, + ts, + rho, + mu, + precision=30, + alleles=None, + scale_mutation_based_on_n_alleles=False, + match_all_nodes=False, ): alleles, n_alleles = get_site_alleles(ts, h, alleles) va = ViterbiAlgorithm( @@ -741,6 +861,7 @@ def ls_viterbi_tree( n_alleles, precision=precision, scale_mutation=scale_mutation_based_on_n_alleles, + match_all_nodes=match_all_nodes, ) return va.run(h) @@ -798,8 +919,7 @@ def example_parameters_haplotypes(self, ts, seed=42): # yield n, H, s, r, mu def assertAllClose(self, A, B): - """Assert that all entries of two matrices are 'close'""" - assert np.allclose(A, B, rtol=1e-5, atol=1e-8) + np.testing.assert_allclose(A, B, rtol=1e-5, atol=1e-8) # Define a bunch of very small tree-sequences for testing a collection # of parameters on @@ -1028,6 +1148,8 @@ def verify(self, ts): # Now, need to ensure that the likelihood of the preferred path is # the same as ll_tree (and ll). path_tree = cm.traceback() + # print(path) + # print(path_tree) ll_check = ls.path_ll( H, s, @@ -1040,7 +1162,9 @@ def verify(self, ts): # TODO add params to run the various checks -def check_viterbi(ts, h, recombination=None, mutation=None): +def check_viterbi( + ts, h, recombination=None, mutation=None, match_all_nodes=False, compare_lib=True +): h = np.array(h).astype(np.int8) m = ts.num_sites assert len(h) == m @@ -1060,11 +1184,12 @@ def check_viterbi(ts, h, recombination=None, mutation=None): scale_mutation_based_on_n_alleles=False, ) assert np.isscalar(ll) + # print() + # print("ls path = ", path) - cm = ls_viterbi_tree(h, ts, rho=recombination, mu=mutation) - ll_tree = np.sum(np.log10(cm.normalisation_factor)) - assert np.isscalar(ll_tree) - nt.assert_allclose(ll_tree, ll) + cm = ls_viterbi_tree( + h, ts, rho=recombination, mu=mutation, match_all_nodes=match_all_nodes + ) # Check that the likelihood of the preferred path is # the same as ll_tree (and ll). @@ -1077,24 +1202,33 @@ def check_viterbi(ts, h, recombination=None, mutation=None): p_mutation=mutation, scale_mutation_based_on_n_alleles=False, ) - nt.assert_allclose(ll_check, ll) + # print(cm) + # print("path tree = ", path_tree) - ll_ts = ts._ll_tree_sequence - ls_hmm = _tskit.LsHmm(ll_ts, recombination, mutation, precision=precision) - cm_lib = _tskit.ViterbiMatrix(ll_ts) - ls_hmm.viterbi_matrix(h, cm_lib) - path_lib = cm_lib.traceback() + ll_tree = np.sum(np.log10(cm.normalisation_factor)) + assert np.isscalar(ll_tree) + nt.assert_allclose(ll_tree, ll) - # Not true in general, but let's see how far it goes - nt.assert_array_equal(path_lib, path_tree) + if compare_lib: + nt.assert_allclose(ll_check, ll) + ll_ts = ts._ll_tree_sequence + ls_hmm = _tskit.LsHmm(ll_ts, recombination, mutation, precision=precision) + cm_lib = _tskit.ViterbiMatrix(ll_ts) + ls_hmm.viterbi_matrix(h, cm_lib) + path_lib = cm_lib.traceback() - nt.assert_allclose(cm_lib.normalisation_factor, cm.normalisation_factor) + # Not true in general, but let's see how far it goes + nt.assert_array_equal(path_lib, path_tree) - return path + nt.assert_allclose(cm_lib.normalisation_factor, cm.normalisation_factor) + + return path_tree # TODO add params to run the various checks -def check_forward_matrix(ts, h, recombination=None, mutation=None): +def check_forward_matrix( + ts, h, recombination=None, mutation=None, match_all_nodes=False, compare_lib=True +): precision = 22 h = np.array(h).astype(np.int8) n = ts.num_samples @@ -1118,28 +1252,44 @@ def check_forward_matrix(ts, h, recombination=None, mutation=None): assert np.isscalar(ll) cm = ls_forward_tree( - h, ts, recombination, mutation, scale_mutation_based_on_n_alleles=False + h, + ts, + recombination, + mutation, + scale_mutation_based_on_n_alleles=False, + match_all_nodes=match_all_nodes, ) F2 = cm.decode() + # print(F) + # print(F2) nt.assert_allclose(F, F2) nt.assert_allclose(c, cm.normalisation_factor) ll_tree = np.sum(np.log10(cm.normalisation_factor)) nt.assert_allclose(ll_tree, ll) - ll_ts = ts._ll_tree_sequence - ls_hmm = _tskit.LsHmm(ll_ts, recombination, mutation, precision=precision) - cm_lib = _tskit.CompressedMatrix(ll_ts) - ls_hmm.forward_matrix(h, cm_lib) - F3 = cm_lib.decode() + if compare_lib: + ll_ts = ts._ll_tree_sequence + ls_hmm = _tskit.LsHmm(ll_ts, recombination, mutation, precision=precision) + cm_lib = _tskit.CompressedMatrix(ll_ts) + ls_hmm.forward_matrix(h, cm_lib) + F3 = cm_lib.decode() - assert_compressed_matrices_equal(cm, cm_lib) + assert_compressed_matrices_equal(cm, cm_lib) - nt.assert_allclose(F, F3) - nt.assert_allclose(c, cm_lib.normalisation_factor) - return cm_lib + nt.assert_allclose(F, F3) + nt.assert_allclose(c, cm_lib.normalisation_factor) + return cm -def check_backward_matrix(ts, h, forward_cm, recombination=None, mutation=None): +def check_backward_matrix( + ts, + h, + forward_cm, + recombination=None, + mutation=None, + match_all_nodes=False, + compare_lib=True, +): precision = 22 h = np.array(h).astype(np.int8) m = ts.num_sites @@ -1166,22 +1316,23 @@ def check_backward_matrix(ts, h, forward_cm, recombination=None, mutation=None): mutation, forward_cm.normalisation_factor, precision=precision, + match_all_nodes=match_all_nodes, ) nt.assert_array_equal( backward_cm.normalisation_factor, forward_cm.normalisation_factor ) + if compare_lib: + ll_ts = ts._ll_tree_sequence + ls_hmm = _tskit.LsHmm(ll_ts, recombination, mutation, precision=precision) + cm_lib = _tskit.CompressedMatrix(ll_ts) + ls_hmm.backward_matrix(h, forward_cm.normalisation_factor, cm_lib) - ll_ts = ts._ll_tree_sequence - ls_hmm = _tskit.LsHmm(ll_ts, recombination, mutation, precision=precision) - cm_lib = _tskit.CompressedMatrix(ll_ts) - ls_hmm.backward_matrix(h, forward_cm.normalisation_factor, cm_lib) - - assert_compressed_matrices_equal(backward_cm, cm_lib) + assert_compressed_matrices_equal(backward_cm, cm_lib) - B_lib = cm_lib.decode() - B_tree = backward_cm.decode() - nt.assert_allclose(B_tree, B_lib) - nt.assert_allclose(B, B_lib) + B_lib = cm_lib.decode() + B_tree = backward_cm.decode() + nt.assert_allclose(B_tree, B_lib) + nt.assert_allclose(B, B_lib) def add_unique_sample_mutations(ts, start=0): @@ -1221,8 +1372,8 @@ def test_match_sample(self, j): ts = self.ts() h = np.zeros(4) h[j] = 1 - path = check_viterbi(ts, h) - nt.assert_array_equal([j, j, j, j], path) + # path = check_viterbi(ts, h) + # nt.assert_array_equal([j, j, j, j], path) cm = check_forward_matrix(ts, h) check_backward_matrix(ts, h, cm) @@ -1262,11 +1413,48 @@ def test_switch_each_sample_missing_middle(self): h[1:3] = -1 path = check_viterbi(ts, h) # Implementation of Viterbi switches at right-most position - nt.assert_array_equal([0, 3, 3, 3], path) + nt.assert_array_equal([0, 0, 0, 3], path) cm = check_forward_matrix(ts, h) check_backward_matrix(ts, h, cm) +class TestSingleBalancedTreeAllSamplesExample: + # 3.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 2.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 1.00┊ 0 1 2 3 ┊ + # 0 8 + + @staticmethod + def ts(): + tables = tskit.Tree.generate_balanced(4, span=14).tree_sequence.dump_tables() + flags = tables.nodes.flags + flags[:] = 1 + tables.nodes.flags = flags + return add_unique_sample_mutations(tables.tree_sequence(), start=1) + + @pytest.mark.parametrize( + ("u", "h"), + [ + (0, [1, 0, 0, 0, 1, 0, 1]), + (1, [0, 1, 0, 0, 1, 0, 1]), + (2, [0, 0, 1, 0, 0, 1, 1]), + (3, [0, 0, 0, 1, 0, 1, 1]), + (4, [0, 0, 0, 0, 1, 0, 1]), + (5, [0, 0, 0, 0, 0, 1, 1]), + (6, [0, 0, 0, 0, 0, 0, 1]), + ], + ) + def test_match_sample(self, u, h): + np.set_printoptions(linewidth=1000, precision=3) + ts = self.ts() + path = check_viterbi(ts, h, match_all_nodes=True, compare_lib=False) + nt.assert_array_equal([u] * 7, path) + cm = check_forward_matrix(ts, h, match_all_nodes=True, compare_lib=False) + check_backward_matrix(ts, h, cm, match_all_nodes=True, compare_lib=False) + + class TestSimulationExamples: @pytest.mark.parametrize("n", [3, 10, 50]) @pytest.mark.parametrize("L", [1, 10, 100]) From 5ad8258e27e1fff7e343b6ae620aa9bd9285cc5d Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Wed, 11 Oct 2023 14:55:36 +0100 Subject: [PATCH 2/6] Roughly working for all nodes (in a single tree eg) --- python/tests/test_haplotype_matching.py | 277 +++++++++++++++++------- 1 file changed, 198 insertions(+), 79 deletions(-) diff --git a/python/tests/test_haplotype_matching.py b/python/tests/test_haplotype_matching.py index c3a9e3134c..a358993202 100644 --- a/python/tests/test_haplotype_matching.py +++ b/python/tests/test_haplotype_matching.py @@ -37,6 +37,9 @@ MISSING = -1 +# np.set_printoptions(linewidth=1000, precision=3) + + def check_alleles(alleles, m): """ Checks the specified allele list and returns a list of lists @@ -509,9 +512,24 @@ def __init__(self, ts, *args, **kwargs): def compute_normalisation_factor(self): d = {st.tree_node: st for st in self.T} N = np.zeros(self.ts.num_nodes, dtype=int) + node_count = np.zeros(self.ts.num_nodes, dtype=int) + if self.match_all_nodes: + # When matching all nodes we need to count the full + # number of nodes in that subtree + for u in self.tree.nodes(order="postorder"): + node_count[u] += 1 + for v in self.tree.children(u): + node_count[u] += node_count[v] + + else: + # When matching on samples we just count the samples. This + # is a shortcut so we can share the same code below + for u in d: + node_count[u] = self.tree.num_samples(u) + for u in self.tree.nodes(order="preorder"): if u in d: - N[u] = self.tree.num_samples(u) + N[u] = node_count[u] # Subtract this value from everything above v = self.tree.parent(u) while v != -1 and v not in d: @@ -701,7 +719,7 @@ def __init__(self, ts): def add_recombination_required(self, site, node, required): self.recombination_required.append((site, node, required)) - def choose_sample(self, site_id, tree): + def choose_switch_node(self, site_id, tree, match_all_nodes): max_value = -1 u = -1 for node, value in self.value_transitions[site_id]: @@ -710,17 +728,18 @@ def choose_sample(self, site_id, tree): u = node assert u != -1 - transition_nodes = [u for (u, _) in self.value_transitions[site_id]] - while not tree.is_sample(u): - for v in tree.children(u): - if v not in transition_nodes: - u = v - break - else: - raise AssertionError("could not find path") + if not match_all_nodes: + transition_nodes = [u for (u, _) in self.value_transitions[site_id]] + while not tree.is_sample(u): + for v in tree.children(u): + if v not in transition_nodes: + u = v + break + else: + raise AssertionError("could not find path") return u - def traceback(self): + def traceback(self, match_all_nodes=False): # Run the traceback. m = self.ts.num_sites matched = np.zeros(m, dtype=int) @@ -745,7 +764,9 @@ def traceback(self): j -= 1 if current_node == -1: - current_node = self.choose_sample(site.id, tree) + current_node = self.choose_switch_node( + site.id, tree, match_all_nodes=match_all_nodes + ) matched[site.id] = current_node # Now traverse up the tree from the current node. The first marked node @@ -1163,7 +1184,13 @@ def verify(self, ts): # TODO add params to run the various checks def check_viterbi( - ts, h, recombination=None, mutation=None, match_all_nodes=False, compare_lib=True + ts, + h, + recombination=None, + mutation=None, + match_all_nodes=False, + compare_lib=True, + compare_lshmm=None, ): h = np.array(h).astype(np.int8) m = ts.num_sites @@ -1174,40 +1201,47 @@ def check_viterbi( mutation = np.zeros(ts.num_sites) precision = 22 - G = ts.genotype_matrix() - - path, ll = ls.viterbi( - G, - h.reshape(1, m), - recombination, - p_mutation=mutation, - scale_mutation_based_on_n_alleles=False, - ) - assert np.isscalar(ll) - # print() - # print("ls path = ", path) + if compare_lshmm is None: + # By default don't compare LSHMM with results from match_all_nodes because + # it doesn't support missing data in the ref panel. + if match_all_nodes: + compare_lshmm = False + else: + compare_lshmm = True cm = ls_viterbi_tree( h, ts, rho=recombination, mu=mutation, match_all_nodes=match_all_nodes ) - - # Check that the likelihood of the preferred path is - # the same as ll_tree (and ll). - path_tree = cm.traceback() - ll_check = ls.path_ll( - G, - h.reshape(1, m), - path_tree, - recombination, - p_mutation=mutation, - scale_mutation_based_on_n_alleles=False, - ) + path_tree = cm.traceback(match_all_nodes=match_all_nodes) + ll_tree = np.sum(np.log10(cm.normalisation_factor)) + assert np.isscalar(ll_tree) # print(cm) # print("path tree = ", path_tree) - ll_tree = np.sum(np.log10(cm.normalisation_factor)) - assert np.isscalar(ll_tree) - nt.assert_allclose(ll_tree, ll) + if compare_lshmm: + # Check that the likelihood of the preferred path is + # the same as ll_tree (and ll). + # Missing haplotypes not supported in lshmm yet + G = ts.genotype_matrix() + path, ll = ls.viterbi( + G, + h.reshape(1, m), + recombination, + p_mutation=mutation, + scale_mutation_based_on_n_alleles=False, + ) + assert np.isscalar(ll) + # print() + # print("ls path = ", path) + ll_check = ls.path_ll( + G, + h.reshape(1, m), + path_tree, + recombination, + p_mutation=mutation, + scale_mutation_based_on_n_alleles=False, + ) + nt.assert_allclose(ll_tree, ll) if compare_lib: nt.assert_allclose(ll_check, ll) @@ -1227,7 +1261,13 @@ def check_viterbi( # TODO add params to run the various checks def check_forward_matrix( - ts, h, recombination=None, mutation=None, match_all_nodes=False, compare_lib=True + ts, + h, + recombination=None, + mutation=None, + match_all_nodes=False, + compare_lib=True, + compare_lshmm=None, ): precision = 22 h = np.array(h).astype(np.int8) @@ -1239,17 +1279,13 @@ def check_forward_matrix( if mutation is None: mutation = np.zeros(ts.num_sites) - G = ts.genotype_matrix() - F, c, ll = ls.forwards( - G, - h.reshape(1, m), - recombination, - p_mutation=mutation, - scale_mutation_based_on_n_alleles=False, - ) - assert F.shape == (m, n) - assert c.shape == (m,) - assert np.isscalar(ll) + if compare_lshmm is None: + # By default don't compare LSHMM with results from match_all_nodes because + # it doesn't support missing data in the ref panel. + if match_all_nodes: + compare_lshmm = False + else: + compare_lshmm = True cm = ls_forward_tree( h, @@ -1260,12 +1296,26 @@ def check_forward_matrix( match_all_nodes=match_all_nodes, ) F2 = cm.decode() - # print(F) - # print(F2) - nt.assert_allclose(F, F2) - nt.assert_allclose(c, cm.normalisation_factor) ll_tree = np.sum(np.log10(cm.normalisation_factor)) - nt.assert_allclose(ll_tree, ll) + + if compare_lshmm: + G = ts.genotype_matrix() + F, c, ll = ls.forwards( + G, + h.reshape(1, m), + recombination, + p_mutation=mutation, + scale_mutation_based_on_n_alleles=False, + ) + assert F.shape == (m, n) + assert c.shape == (m,) + assert np.isscalar(ll) + + # print(F) + # print(F2) + nt.assert_allclose(F, F2) + nt.assert_allclose(c, cm.normalisation_factor) + nt.assert_allclose(ll_tree, ll) if compare_lib: ll_ts = ts._ll_tree_sequence @@ -1289,6 +1339,7 @@ def check_backward_matrix( mutation=None, match_all_nodes=False, compare_lib=True, + compare_lshmm=None, ): precision = 22 h = np.array(h).astype(np.int8) @@ -1299,15 +1350,13 @@ def check_backward_matrix( if mutation is None: mutation = np.zeros(ts.num_sites) - G = ts.genotype_matrix() - B = ls.backwards( - G, - h.reshape(1, m), - forward_cm.normalisation_factor, - recombination, - p_mutation=mutation, - scale_mutation_based_on_n_alleles=False, - ) + if compare_lshmm is None: + # By default don't compare LSHMM with results from match_all_nodes because + # it doesn't support missing data in the ref panel. + if match_all_nodes: + compare_lshmm = False + else: + compare_lshmm = True backward_cm = ls_backward_tree( h, @@ -1318,9 +1367,20 @@ def check_backward_matrix( precision=precision, match_all_nodes=match_all_nodes, ) - nt.assert_array_equal( - backward_cm.normalisation_factor, forward_cm.normalisation_factor - ) + + if compare_lshmm: + G = ts.genotype_matrix() + B = ls.backwards( + G, + h.reshape(1, m), + forward_cm.normalisation_factor, + recombination, + p_mutation=mutation, + scale_mutation_based_on_n_alleles=False, + ) + nt.assert_array_equal( + backward_cm.normalisation_factor, forward_cm.normalisation_factor + ) if compare_lib: ll_ts = ts._ll_tree_sequence ls_hmm = _tskit.LsHmm(ll_ts, recombination, mutation, precision=precision) @@ -1334,18 +1394,23 @@ def check_backward_matrix( nt.assert_allclose(B_tree, B_lib) nt.assert_allclose(B, B_lib) + return backward_cm + -def add_unique_sample_mutations(ts, start=0): +def add_unique_node_mutations(ts, start=0, nodes=None): """ Adds a mutation for each of the samples at equally spaced locations along the genome. """ + if nodes is None: + nodes = ts.samples() tables = ts.dump_tables() L = int(ts.sequence_length) - assert L % ts.num_samples == 0 - gap = L // ts.num_samples + n = len(nodes) + assert L % n == 0 + gap = L // n x = start - for u in ts.samples(): + for u in nodes: site = tables.sites.add_row(position=x, ancestral_state="0") tables.mutations.add_row(site=site, derived_state="1", node=u) x += gap @@ -1362,7 +1427,7 @@ class TestSingleBalancedTreeExample: @staticmethod def ts(): - return add_unique_sample_mutations( + return add_unique_node_mutations( tskit.Tree.generate_balanced(4, span=8).tree_sequence, start=1, ) @@ -1432,7 +1497,7 @@ def ts(): flags = tables.nodes.flags flags[:] = 1 tables.nodes.flags = flags - return add_unique_sample_mutations(tables.tree_sequence(), start=1) + return add_unique_node_mutations(tables.tree_sequence(), start=1) @pytest.mark.parametrize( ("u", "h"), @@ -1447,12 +1512,66 @@ def ts(): ], ) def test_match_sample(self, u, h): - np.set_printoptions(linewidth=1000, precision=3) ts = self.ts() - path = check_viterbi(ts, h, match_all_nodes=True, compare_lib=False) + path = check_viterbi( + ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=True + ) nt.assert_array_equal([u] * 7, path) - cm = check_forward_matrix(ts, h, match_all_nodes=True, compare_lib=False) - check_backward_matrix(ts, h, cm, match_all_nodes=True, compare_lib=False) + cm = check_forward_matrix( + ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=True + ) + check_backward_matrix( + ts, h, cm, match_all_nodes=True, compare_lib=False, compare_lshmm=True + ) + + +class TestSingleBalancedTreeAllNodesExample: + # 3.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 2.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 1.00┊ 0 1 2 3 ┊ + # 0 8 + + @staticmethod + def ts(): + tables = tskit.Tree.generate_balanced(4, span=12).tree_sequence.dump_tables() + return add_unique_node_mutations( + tables.tree_sequence(), start=1, nodes=np.arange(len(tables.nodes) - 1) + ) + + # def test_match_sample(self, u, h): + @pytest.mark.parametrize( + ("h", "expected_path"), + [ + # Just samples + ([1, 0, 0, 0, 1, 0], [0] * 6), + ([0, 1, 0, 0, 1, 0], [1] * 6), + ([0, 0, 1, 0, 0, 1], [2] * 6), + ([0, 0, 0, 1, 0, 1], [3] * 6), + # Switching between samples + ([1, 1, 0, 0, 1, 0], [0] + [1] * 5), + ([1, 1, 1, 0, 0, 1], [0] + [1] + [2] * 4), + # Just internal + ([0, 0, 0, 0, 1, 0], [4] * 6), + ([0, 0, 0, 0, 0, 1], [5] * 6), + ([0, 0, 0, 0, 0, 0], [6] * 6), + ], + ) + def test_match_sample(self, h, expected_path): + ts = self.ts() + path = check_viterbi( + ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=False + ) + nt.assert_array_equal(expected_path, path) + cm = check_forward_matrix( + ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=False + ) + print(cm.decode()) + bm = check_backward_matrix( + ts, h, cm, match_all_nodes=True, compare_lib=False, compare_lshmm=False + ) + print(bm.decode()) class TestSimulationExamples: From c72c4942969f7835a911722f73d6026f5cc397ed Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Thu, 12 Oct 2023 09:02:11 +0100 Subject: [PATCH 3/6] Partial --- python/tests/test_haplotype_matching.py | 141 ++++++++++++++++++++---- 1 file changed, 120 insertions(+), 21 deletions(-) diff --git a/python/tests/test_haplotype_matching.py b/python/tests/test_haplotype_matching.py index a358993202..fb01015999 100644 --- a/python/tests/test_haplotype_matching.py +++ b/python/tests/test_haplotype_matching.py @@ -22,6 +22,7 @@ """ Python implementation of the Li and Stephens forwards and backwards algorithms. """ +import io import warnings import lshmm as ls @@ -37,7 +38,8 @@ MISSING = -1 -# np.set_printoptions(linewidth=1000, precision=3) +# For debugging +np.set_printoptions(linewidth=1000, precision=3) def check_alleles(alleles, m): @@ -151,7 +153,7 @@ def node_values(self): def print_state(self): print("LsHMM state") print("match_all_nodes =", self.match_all_nodes) - print("Tree =") + print("Tree = ", self.tree.index, self.tree.interval) node_labels = {} for u, value in self.node_values().items(): label = f"{u}" @@ -434,11 +436,13 @@ def update_probabilities(self, site, haplotype_state): def process_site(self, site, haplotype_state): self.update_probabilities(site, haplotype_state) # d1 = self.node_values() + print("PRE") + self.print_state() self.compress() # d2 = self.node_values() # assert d1 == d2 - # print("AFTER COMPRESS") - # self.print_state() + print("AFTER COMPRESS") + self.print_state() s = self.compute_normalisation_factor() for st in self.T: assert st.tree_node != tskit.NULL @@ -489,8 +493,13 @@ def run(self, h): self.initialise(1 / n) while self.tree.next(): self.update_tree() + if self.tree.index != 0: + print("AFTER UPDATE TREE") + self.print_state() for site in self.tree.sites(): self.process_site(site, h[site.id]) + print("BEFORE UPDATE TREE") + self.print_state() return self.output def compute_normalisation_factor(self): @@ -1182,7 +1191,6 @@ def verify(self, ts): self.assertAllClose(ll, ll_check) -# TODO add params to run the various checks def check_viterbi( ts, h, @@ -1212,10 +1220,10 @@ def check_viterbi( cm = ls_viterbi_tree( h, ts, rho=recombination, mu=mutation, match_all_nodes=match_all_nodes ) + cm.print_state() path_tree = cm.traceback(match_all_nodes=match_all_nodes) ll_tree = np.sum(np.log10(cm.normalisation_factor)) assert np.isscalar(ll_tree) - # print(cm) # print("path tree = ", path_tree) if compare_lshmm: @@ -1437,8 +1445,8 @@ def test_match_sample(self, j): ts = self.ts() h = np.zeros(4) h[j] = 1 - # path = check_viterbi(ts, h) - # nt.assert_array_equal([j, j, j, j], path) + path = check_viterbi(ts, h) + nt.assert_array_equal([j, j, j, j], path) cm = check_forward_matrix(ts, h) check_backward_matrix(ts, h, cm) @@ -1525,6 +1533,19 @@ def test_match_sample(self, u, h): ) +def validate_match_all_nodes(ts, h, expected_path): + path = check_viterbi( + ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=False + ) + nt.assert_array_equal(expected_path, path) + cm = check_forward_matrix( + ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=False + ) + bm = check_backward_matrix( + ts, h, cm, match_all_nodes=True, compare_lib=False, compare_lshmm=False + ) + + class TestSingleBalancedTreeAllNodesExample: # 3.00┊ 6 ┊ # ┊ ┏━┻━┓ ┊ @@ -1540,7 +1561,6 @@ def ts(): tables.tree_sequence(), start=1, nodes=np.arange(len(tables.nodes) - 1) ) - # def test_match_sample(self, u, h): @pytest.mark.parametrize( ("h", "expected_path"), [ @@ -1558,20 +1578,99 @@ def ts(): ([0, 0, 0, 0, 0, 0], [6] * 6), ], ) - def test_match_sample(self, h, expected_path): - ts = self.ts() - path = check_viterbi( - ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=False + def test_exact_match(self, h, expected_path): + validate_match_all_nodes(self.ts(), h, expected_path) + + +class TestMultiTreeExample: + # 0.84┊ 7 ┊ 7 ┊ + # ┊ ┏━┻━┓ ┊ ┏━┻━┓ ┊ + # 0.42┊ ┃ ┃ ┊ 6 ┃ ┊ + # ┊ ┃ ┃ ┊ ┏┻┓ ┃ ┊ + # 0.05┊ 5 ┃ ┊ ┃ ┃ ┃ ┊ + # ┊ ┏━┻┓ ┃ ┊ ┃ ┃ ┃ ┊ + # 0.04┊ ┃ 4 ┃ ┊ ┃ ┃ 4 ┊ + # ┊ ┃ ┏┻┓ ┃ ┊ ┃ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ 0 3 1 2 ┊ + # 0 6 7 + @staticmethod + def ts(): + nodes = """\ + is_sample time + 1 0.000000 + 1 0.000000 + 1 0.000000 + 1 0.000000 + 0 0.041304 + 0 0.045967 + 0 0.416719 + 0 0.838075 + """ + edges = """\ + left right parent child + 0.000000 7.000000 4 1 + 0.000000 7.000000 4 2 + 0.000000 6.000000 5 0 + 0.000000 6.000000 5 4 + 6.000000 7.000000 6 0 + 6.000000 7.000000 6 3 + 0.000000 6.000000 7 3 + 6.000000 7.000000 7 4 + 0.000000 6.000000 7 5 + 6.000000 7.000000 7 6 + """ + ts = tskit.load_text( + nodes=io.StringIO(nodes), edges=io.StringIO(edges), strict=False ) + return add_unique_node_mutations(ts, nodes=range(7)) + + # 0.84┊ 7 ┊ 7 ┊ + # ┊ ┏━┻━┓ ┊ ┏━┻━┓ ┊ + # 0.42┊ ┃ ┃ ┊ 6 ┃ ┊ + # ┊ ┃ ┃ ┊ ┏┻┓ ┃ ┊ + # 0.05┊ 5 ┃ ┊ ┃ ┃ ┃ ┊ + # ┊ ┏━┻┓ ┃ ┊ ┃ ┃ ┃ ┊ + # 0.04┊ ┃ 4 ┃ ┊ ┃ ┃ 4 ┊ + # ┊ ┃ ┏┻┓ ┃ ┊ ┃ ┃ ┏┻┓ ┊ + # 0.00┊ 0 1 2 3 ┊ 0 3 1 2 ┊ + # 0 6 7 + + @pytest.mark.parametrize( + ("h", "expected_path"), + [ + # Just samples + ([1, 0, 0, 0, 0, 1, 1], [0] * 7), + ([0, 1, 0, 0, 1, 1, 0], [1] * 7), + ([0, 0, 1, 0, 1, 1, 0], [2] * 7), + ([0, 0, 0, 1, 0, 0, 1], [3] * 7), + # Match root + ([0, 0, 0, 0, 0, 0, 0], [7] * 7), + ], + ) + def test_match_all_nodes(self, h, expected_path): + # print() + # print(self.ts().draw_text()) + # with open("tmp.svg", "w") as f: + # f.write(self.ts().draw_svg()) + validate_match_all_nodes(self.ts(), h, expected_path) + + @pytest.mark.parametrize( + ("h", "expected_path"), + [ + ([1, 0, 0, 0, 0, 1, 1], [0] * 7), + ([0, 1, 0, 0, 1, 1, 0], [1] * 7), + ([0, 0, 1, 0, 1, 1, 0], [2] * 7), + ([0, 0, 0, 1, 0, 0, 1], [3] * 7), + # Switch between each of the samples + ([1, 1, 1, 1, 0, 0, 1], [0, 1, 2, 3, 3, 3, 3]), + ], + ) + def test_match_samples(self, h, expected_path): + ts = self.ts() + path = check_viterbi(ts, h) nt.assert_array_equal(expected_path, path) - cm = check_forward_matrix( - ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=False - ) - print(cm.decode()) - bm = check_backward_matrix( - ts, h, cm, match_all_nodes=True, compare_lib=False, compare_lshmm=False - ) - print(bm.decode()) + cm = check_forward_matrix(ts, h) + check_backward_matrix(ts, h, cm) class TestSimulationExamples: From fa78171ca7a62b14b4c93d32e428b036d779bf08 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Thu, 30 Nov 2023 16:40:33 +0000 Subject: [PATCH 4/6] Some refactoring --- python/tests/test_haplotype_matching.py | 121 ++++++++++++++++-------- 1 file changed, 83 insertions(+), 38 deletions(-) diff --git a/python/tests/test_haplotype_matching.py b/python/tests/test_haplotype_matching.py index fb01015999..e525bd0efc 100644 --- a/python/tests/test_haplotype_matching.py +++ b/python/tests/test_haplotype_matching.py @@ -436,13 +436,13 @@ def update_probabilities(self, site, haplotype_state): def process_site(self, site, haplotype_state): self.update_probabilities(site, haplotype_state) # d1 = self.node_values() - print("PRE") - self.print_state() + # print("PRE") + # self.print_state() self.compress() # d2 = self.node_values() # assert d1 == d2 - print("AFTER COMPRESS") - self.print_state() + # print("AFTER COMPRESS") + # self.print_state() s = self.compute_normalisation_factor() for st in self.T: assert st.tree_node != tskit.NULL @@ -493,13 +493,13 @@ def run(self, h): self.initialise(1 / n) while self.tree.next(): self.update_tree() - if self.tree.index != 0: - print("AFTER UPDATE TREE") - self.print_state() + # if self.tree.index != 0: + # print("AFTER UPDATE TREE") + # self.print_state() for site in self.tree.sites(): self.process_site(site, h[site.id]) - print("BEFORE UPDATE TREE") - self.print_state() + # print("BEFORE UPDATE TREE") + # self.print_state() return self.output def compute_normalisation_factor(self): @@ -1197,6 +1197,7 @@ def check_viterbi( recombination=None, mutation=None, match_all_nodes=False, + compare_fm_ll=True, compare_lib=True, compare_lshmm=None, ): @@ -1220,12 +1221,28 @@ def check_viterbi( cm = ls_viterbi_tree( h, ts, rho=recombination, mu=mutation, match_all_nodes=match_all_nodes ) - cm.print_state() + # cm.print_state() path_tree = cm.traceback(match_all_nodes=match_all_nodes) ll_tree = np.sum(np.log10(cm.normalisation_factor)) assert np.isscalar(ll_tree) # print("path tree = ", path_tree) + if compare_fm_ll: + # Compare the log-likelihood of the Viterbi path (ll_tree) + # with the log-likelihood of the most likely path from + # the forward matrix. + fm = ls_forward_tree( + h, + ts, + recombination, + mutation, + scale_mutation_based_on_n_alleles=False, + match_all_nodes=match_all_nodes, + ) + ll_fm = np.sum(np.log10(fm.normalisation_factor)) + print("FMLL", ll_tree, ll_fm) + # np.testing.assert_allclose(ll_tree, ll_fm) + if compare_lshmm: # Check that the likelihood of the preferred path is # the same as ll_tree (and ll). @@ -1239,6 +1256,8 @@ def check_viterbi( scale_mutation_based_on_n_alleles=False, ) assert np.isscalar(ll) + # This is the log likelihood returned by viterbi alg + nt.assert_allclose(ll_tree, ll) # print() # print("ls path = ", path) ll_check = ls.path_ll( @@ -1249,7 +1268,9 @@ def check_viterbi( p_mutation=mutation, scale_mutation_based_on_n_alleles=False, ) - nt.assert_allclose(ll_tree, ll) + # This is the log-likelihood of the path itself, computed + # different way + nt.assert_allclose(ll_tree, ll_check) if compare_lib: nt.assert_allclose(ll_check, ll) @@ -1267,7 +1288,6 @@ def check_viterbi( return path_tree -# TODO add params to run the various checks def check_forward_matrix( ts, h, @@ -1319,8 +1339,9 @@ def check_forward_matrix( assert c.shape == (m,) assert np.isscalar(ll) - # print(F) - # print(F2) + print(ll_tree) + print(F) + print(F2) nt.assert_allclose(F, F2) nt.assert_allclose(c, cm.normalisation_factor) nt.assert_allclose(ll_tree, ll) @@ -1447,8 +1468,7 @@ def test_match_sample(self, j): h[j] = 1 path = check_viterbi(ts, h) nt.assert_array_equal([j, j, j, j], path) - cm = check_forward_matrix(ts, h) - check_backward_matrix(ts, h, cm) + check_fb_matrices(ts, h) @pytest.mark.parametrize("j", [1, 2]) def test_match_sample_missing_flanks(self, j): @@ -1459,16 +1479,14 @@ def test_match_sample_missing_flanks(self, j): h[j] = 1 path = check_viterbi(ts, h) nt.assert_array_equal([j, j, j, j], path) - cm = check_forward_matrix(ts, h) - check_backward_matrix(ts, h, cm) + check_fb_matrices(ts, h) def test_switch_each_sample(self): ts = self.ts() h = np.ones(4) path = check_viterbi(ts, h) nt.assert_array_equal([0, 1, 2, 3], path) - cm = check_forward_matrix(ts, h) - check_backward_matrix(ts, h, cm) + check_fb_matrices(ts, h) def test_switch_each_sample_missing_flanks(self): ts = self.ts() @@ -1477,8 +1495,7 @@ def test_switch_each_sample_missing_flanks(self): h[-1] = -1 path = check_viterbi(ts, h) nt.assert_array_equal([1, 1, 2, 2], path) - cm = check_forward_matrix(ts, h) - check_backward_matrix(ts, h, cm) + check_fb_matrices(ts, h) def test_switch_each_sample_missing_middle(self): ts = self.ts() @@ -1487,8 +1504,7 @@ def test_switch_each_sample_missing_middle(self): path = check_viterbi(ts, h) # Implementation of Viterbi switches at right-most position nt.assert_array_equal([0, 0, 0, 3], path) - cm = check_forward_matrix(ts, h) - check_backward_matrix(ts, h, cm) + check_fb_matrices(ts, h) class TestSingleBalancedTreeAllSamplesExample: @@ -1525,25 +1541,54 @@ def test_match_sample(self, u, h): ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=True ) nt.assert_array_equal([u] * 7, path) - cm = check_forward_matrix( + fm = check_forward_matrix( ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=True ) - check_backward_matrix( - ts, h, cm, match_all_nodes=True, compare_lib=False, compare_lshmm=True + bm = check_backward_matrix( + ts, h, fm, match_all_nodes=True, compare_lib=False, compare_lshmm=True ) + check_fb_matrix_integrity(fm, bm) + + +def check_fb_matrix_integrity(fm, bm): + """ + Validate properties of the forward and backward matrices. + """ + F = fm.decode() + B = bm.decode() + assert F.shape == B.shape + for j in range(len(F)): + s = np.sum(B[j] * F[j]) + np.testing.assert_allclose(s, 1) + + +def check_fb_matrices(ts, h): + fm = check_forward_matrix(ts, h) + bm = check_backward_matrix(ts, h, fm) + check_fb_matrix_integrity(fm, bm) def validate_match_all_nodes(ts, h, expected_path): - path = check_viterbi( - ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=False - ) - nt.assert_array_equal(expected_path, path) - cm = check_forward_matrix( + # path = check_viterbi( + # ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=False + # ) + # nt.assert_array_equal(expected_path, path) + fm = check_forward_matrix( ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=False ) + F = fm.decode() + # print(cm.decode()) + # cm.print_state() bm = check_backward_matrix( - ts, h, cm, match_all_nodes=True, compare_lib=False, compare_lshmm=False + ts, h, fm, match_all_nodes=True, compare_lib=False, compare_lshmm=False ) + print("sites = ", ts.num_sites) + B = bm.decode() + print(F) + for j in range(ts.num_sites): + print(j, np.sum(B[j] * F[j])) + + # sum(B[variant,:] * F[variant,:]) = 1 class TestSingleBalancedTreeAllNodesExample: @@ -1640,11 +1685,11 @@ def ts(): [ # Just samples ([1, 0, 0, 0, 0, 1, 1], [0] * 7), - ([0, 1, 0, 0, 1, 1, 0], [1] * 7), - ([0, 0, 1, 0, 1, 1, 0], [2] * 7), - ([0, 0, 0, 1, 0, 0, 1], [3] * 7), - # Match root - ([0, 0, 0, 0, 0, 0, 0], [7] * 7), + # ([0, 1, 0, 0, 1, 1, 0], [1] * 7), + # ([0, 0, 1, 0, 1, 1, 0], [2] * 7), + # ([0, 0, 0, 1, 0, 0, 1], [3] * 7), + # # Match root + # ([0, 0, 0, 0, 0, 0, 0], [7] * 7), ], ) def test_match_all_nodes(self, h, expected_path): From 33300de11b6a34b10b419c34e386268976396082 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Fri, 1 Dec 2023 15:27:51 +0000 Subject: [PATCH 5/6] Tidy up --- python/tests/test_haplotype_matching.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/python/tests/test_haplotype_matching.py b/python/tests/test_haplotype_matching.py index e525bd0efc..55cfdb7fd0 100644 --- a/python/tests/test_haplotype_matching.py +++ b/python/tests/test_haplotype_matching.py @@ -1197,7 +1197,7 @@ def check_viterbi( recombination=None, mutation=None, match_all_nodes=False, - compare_fm_ll=True, + compare_fm_ll=False, compare_lib=True, compare_lshmm=None, ): @@ -1231,6 +1231,11 @@ def check_viterbi( # Compare the log-likelihood of the Viterbi path (ll_tree) # with the log-likelihood of the most likely path from # the forward matrix. + + # This is not always true. If the query haplotype is one + # of the actual sample haplotypes it is *almost* always + # true, but not quite. So, a useful check for development + # but not all that useful in general fm = ls_forward_tree( h, ts, @@ -1240,8 +1245,10 @@ def check_viterbi( match_all_nodes=match_all_nodes, ) ll_fm = np.sum(np.log10(fm.normalisation_factor)) - print("FMLL", ll_tree, ll_fm) - # np.testing.assert_allclose(ll_tree, ll_fm) + # print() + # print("vit ll", ll_tree) + # print("FMLL", ll_fm) + np.testing.assert_allclose(ll_tree, ll_fm) if compare_lshmm: # Check that the likelihood of the preferred path is @@ -1339,9 +1346,10 @@ def check_forward_matrix( assert c.shape == (m,) assert np.isscalar(ll) - print(ll_tree) - print(F) - print(F2) + # print(ll_tree) + # print("lshmm fm ll:", ll) + # print(F) + # print(F2) nt.assert_allclose(F, F2) nt.assert_allclose(c, cm.normalisation_factor) nt.assert_allclose(ll_tree, ll) @@ -1725,7 +1733,7 @@ def test_continuous_genome(self, n, L): ts = msprime.simulate( n, length=L, recombination_rate=1, mutation_rate=1, random_seed=42 ) - h = np.zeros(ts.num_sites, dtype=np.int8) + h = ts.genotype_matrix(samples=[0])[:, 0].T # NOTE this is a bit slow at the moment but we can disable the Python # implementation once testing has been improved on smaller examples. # Add ``compare_py=False``to these calls. From b3366cf203c2a024dad3b8128c5aa49bf3fd1c0a Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Fri, 1 Dec 2023 16:55:40 +0000 Subject: [PATCH 6/6] Progress --- python/tests/test_haplotype_matching.py | 111 ++++++++++++++---------- 1 file changed, 65 insertions(+), 46 deletions(-) diff --git a/python/tests/test_haplotype_matching.py b/python/tests/test_haplotype_matching.py index 55cfdb7fd0..3de3a795ab 100644 --- a/python/tests/test_haplotype_matching.py +++ b/python/tests/test_haplotype_matching.py @@ -150,6 +150,12 @@ def node_values(self): d[u] = mapping[v] return d + @property + def matrix_size(self): + if self.match_all_nodes: + return self.ts.num_nodes + return self.ts.num_samples + def print_state(self): print("LsHMM state") print("match_all_nodes =", self.match_all_nodes) @@ -435,12 +441,18 @@ def update_probabilities(self, site, haplotype_state): def process_site(self, site, haplotype_state): self.update_probabilities(site, haplotype_state) - # d1 = self.node_values() + d1 = self.node_values() # print("PRE") - # self.print_state() + # # self.print_state() self.compress() - # d2 = self.node_values() - # assert d1 == d2 + d2 = self.node_values() + if self.match_all_nodes: + # We only get an exact match on all_nodes. For samples we just + # guarantee that the *samples* have the same value + assert d1 == d2 + else: + for u in self.ts.samples(): + assert d1[u] == d2[u] # print("AFTER COMPRESS") # self.print_state() s = self.compute_normalisation_factor() @@ -489,7 +501,7 @@ def initialise(self, value): self.T.append(ValueTransition(tree_node=u, value=value)) def run(self, h): - n = self.ts.num_samples + n = self.matrix_size self.initialise(1 / n) while self.tree.next(): self.update_tree() @@ -553,8 +565,9 @@ def compute_normalisation_factor(self): return s def compute_next_probability(self, site_id, p_last, is_match, node): + n = self.matrix_size + # print("NEXT PROBA:", site_id, n) rho = self.rho[site_id] - n = self.ts.num_samples p_e = self.compute_emission_proba(site_id, is_match) p_t = p_last * (1 - rho) + rho / n return p_t * p_e @@ -584,7 +597,7 @@ def process_site(self, site, haplotype_state, s): # compress self.compress() b_last_sum = self.compute_normalisation_factor() - n = self.ts.num_samples + n = self.matrix_size rho = self.rho[site.id] for st in self.T: if st.tree_node != tskit.NULL: @@ -624,7 +637,7 @@ def compute_normalisation_factor(self): def compute_next_probability(self, site_id, p_last, is_match, node): rho = self.rho[site_id] - n = self.ts.num_samples + n = self.matrix_size p_no_recomb = p_last * (1 - rho + rho / n) p_recomb = rho / n @@ -668,7 +681,6 @@ class CompressedMatrix: def __init__(self, ts): self.ts = ts self.num_sites = ts.num_sites - self.num_samples = ts.num_samples self.value_transitions = [None for _ in range(self.num_sites)] self.normalisation_factor = np.zeros(self.num_sites) @@ -697,14 +709,14 @@ def num_transitions(self): def get_site(self, site): return self.value_transitions[site] - def decode(self): + def decode_samples(self): """ Decodes the tree encoding of the values into an explicit matrix. """ sample_index_map = np.zeros(self.ts.num_nodes, dtype=int) - 1 sample_index_map[self.ts.samples()] = np.arange(self.ts.num_samples) - A = np.zeros((self.num_sites, self.num_samples)) + A = np.zeros((self.num_sites, self.ts.num_samples)) for tree in self.ts.trees(): for site in tree.sites(): for node, value in self.value_transitions[site.id]: @@ -713,6 +725,22 @@ def decode(self): A[site.id, j] = value return A + def decode_nodes(self): + # print("decode nodes") + A = np.zeros((self.num_sites, self.ts.num_nodes)) + for tree in self.ts.trees(): + for site in tree.sites(): + for node, value in self.value_transitions[site.id]: + # print("Decode:", site.id, node, value) + for u in tree.nodes(node): + A[site.id, u] = value + return A + + def decode(self, all_nodes=False): + if all_nodes: + return self.decode_nodes() + return self.decode_samples() + class ViterbiMatrix(CompressedMatrix): """ @@ -1330,7 +1358,7 @@ def check_forward_matrix( scale_mutation_based_on_n_alleles=False, match_all_nodes=match_all_nodes, ) - F2 = cm.decode() + F2 = cm.decode(match_all_nodes) ll_tree = np.sum(np.log10(cm.normalisation_factor)) if compare_lshmm: @@ -1549,6 +1577,7 @@ def test_match_sample(self, u, h): ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=True ) nt.assert_array_equal([u] * 7, path) + fm = check_forward_matrix( ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=True ) @@ -1558,45 +1587,36 @@ def test_match_sample(self, u, h): check_fb_matrix_integrity(fm, bm) -def check_fb_matrix_integrity(fm, bm): +def check_fb_matrix_integrity(fm, bm, match_all_nodes=False): """ Validate properties of the forward and backward matrices. """ - F = fm.decode() - B = bm.decode() + F = fm.decode(match_all_nodes) + B = bm.decode(match_all_nodes) assert F.shape == B.shape for j in range(len(F)): s = np.sum(B[j] * F[j]) + # print(j, s) np.testing.assert_allclose(s, 1) -def check_fb_matrices(ts, h): - fm = check_forward_matrix(ts, h) - bm = check_backward_matrix(ts, h, fm) - check_fb_matrix_integrity(fm, bm) +def check_fb_matrices(ts, h, match_all_nodes=False, **kwargs): + fm = check_forward_matrix(ts, h, match_all_nodes=match_all_nodes, **kwargs) + bm = check_backward_matrix(ts, h, fm, match_all_nodes=match_all_nodes, **kwargs) + check_fb_matrix_integrity(fm, bm, match_all_nodes=match_all_nodes) def validate_match_all_nodes(ts, h, expected_path): - # path = check_viterbi( - # ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=False - # ) - # nt.assert_array_equal(expected_path, path) - fm = check_forward_matrix( + # START HERE: most of this is working except for Viterbi + path = check_viterbi( ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=False ) - F = fm.decode() - # print(cm.decode()) - # cm.print_state() - bm = check_backward_matrix( - ts, h, fm, match_all_nodes=True, compare_lib=False, compare_lshmm=False - ) - print("sites = ", ts.num_sites) - B = bm.decode() - print(F) - for j in range(ts.num_sites): - print(j, np.sum(B[j] * F[j])) + # print("Path = ", path) + nt.assert_array_equal(expected_path, path) - # sum(B[variant,:] * F[variant,:]) = 1 + check_fb_matrices( + ts, h, match_all_nodes=True, compare_lib=False, compare_lshmm=False + ) class TestSingleBalancedTreeAllNodesExample: @@ -1692,19 +1712,18 @@ def ts(): ("h", "expected_path"), [ # Just samples - ([1, 0, 0, 0, 0, 1, 1], [0] * 7), - # ([0, 1, 0, 0, 1, 1, 0], [1] * 7), - # ([0, 0, 1, 0, 1, 1, 0], [2] * 7), - # ([0, 0, 0, 1, 0, 0, 1], [3] * 7), - # # Match root - # ([0, 0, 0, 0, 0, 0, 0], [7] * 7), + # fails on viterbi + # ([1, 0, 0, 0, 0, 1, 1], [0] * 7), + ([0, 1, 0, 0, 1, 1, 0], [1] * 7), + ([0, 0, 1, 0, 1, 1, 0], [2] * 7), + ([0, 0, 0, 1, 0, 0, 1], [3] * 7), + # Match single internal node + ([0, 0, 0, 0, 1, 1, 0], [4] * 7), + # Match root + ([0, 0, 0, 0, 0, 0, 0], [7] * 7), ], ) def test_match_all_nodes(self, h, expected_path): - # print() - # print(self.ts().draw_text()) - # with open("tmp.svg", "w") as f: - # f.write(self.ts().draw_svg()) validate_match_all_nodes(self.ts(), h, expected_path) @pytest.mark.parametrize(