Skip to content

Commit b6f9872

Browse files
Merge pull request #2799 from jeromekelleher/ls-forward-native
LS model refactoring
2 parents 11dd2f5 + 64aedbf commit b6f9872

11 files changed

+854
-451
lines changed

c/tests/test_haplotype_matching.c

+36-55
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
/*
22
* MIT License
33
*
4-
* Copyright (c) 2019-2022 Tskit Developers
4+
* Copyright (c) 2019-2023 Tskit Developers
55
*
66
* Permission is hereby granted, free of charge, to any person obtaining a copy
77
* of this software and associated documentation files (the "Software"), to deal
@@ -28,46 +28,6 @@
2828
#include <unistd.h>
2929
#include <stdlib.h>
3030

31-
/****************************************************************
32-
* TestHMM
33-
****************************************************************/
34-
35-
static double
36-
tsk_ls_hmm_compute_normalisation_factor_site_test(tsk_ls_hmm_t *TSK_UNUSED(self))
37-
{
38-
return 1.0;
39-
}
40-
41-
static int
42-
tsk_ls_hmm_next_probability_test(tsk_ls_hmm_t *TSK_UNUSED(self),
43-
tsk_id_t TSK_UNUSED(site_id), double TSK_UNUSED(p_last), bool TSK_UNUSED(is_match),
44-
tsk_id_t TSK_UNUSED(node), double *result)
45-
{
46-
*result = rand();
47-
/* printf("next proba = %f\n", *result); */
48-
return 0;
49-
}
50-
51-
static int
52-
run_test_hmm(tsk_ls_hmm_t *hmm, int32_t *haplotype, tsk_compressed_matrix_t *output)
53-
{
54-
int ret = 0;
55-
56-
srand(1);
57-
58-
ret = tsk_ls_hmm_run(hmm, haplotype, tsk_ls_hmm_next_probability_test,
59-
tsk_ls_hmm_compute_normalisation_factor_site_test, output);
60-
if (ret != 0) {
61-
goto out;
62-
}
63-
out:
64-
return ret;
65-
}
66-
67-
/****************************************************************
68-
* TestHMM
69-
****************************************************************/
70-
7131
static void
7232
test_single_tree_missing_alleles(void)
7333
{
@@ -206,6 +166,7 @@ test_single_tree_match_impossible(void)
206166
tsk_treeseq_t ts;
207167
tsk_ls_hmm_t ls_hmm;
208168
tsk_compressed_matrix_t forward;
169+
tsk_compressed_matrix_t backward;
209170
tsk_viterbi_matrix_t viterbi;
210171

211172
double rho[] = { 0.0, 0.25, 0.25 };
@@ -228,8 +189,16 @@ test_single_tree_match_impossible(void)
228189
tsk_viterbi_matrix_print_state(&viterbi, _devnull);
229190
tsk_ls_hmm_print_state(&ls_hmm, _devnull);
230191

192+
ret = tsk_ls_hmm_backward(&ls_hmm, h, forward.normalisation_factor, &backward, 0);
193+
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_MATCH_IMPOSSIBLE);
194+
tsk_compressed_matrix_print_state(&backward, _devnull);
195+
/* tsk_compressed_matrix_print_state(&forward, stdout); */
196+
/* tsk_compressed_matrix_print_state(&backward, stdout); */
197+
tsk_ls_hmm_print_state(&ls_hmm, _devnull);
198+
231199
tsk_ls_hmm_free(&ls_hmm);
232200
tsk_compressed_matrix_free(&forward);
201+
tsk_compressed_matrix_free(&backward);
233202
tsk_viterbi_matrix_free(&viterbi);
234203
tsk_treeseq_free(&ts);
235204
}
@@ -275,12 +244,15 @@ test_single_tree_errors(void)
275244
ret = tsk_compressed_matrix_store_site(&forward, 4, 0, 0, NULL);
276245
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_SITE_OUT_OF_BOUNDS);
277246

278-
T[0].tree_node = -1;
279-
T[0].value = 0;
280-
ret = tsk_compressed_matrix_store_site(&forward, 0, 1, 1, T);
281-
CU_ASSERT_EQUAL_FATAL(ret, 0);
282-
ret = tsk_compressed_matrix_decode(&forward, (double *) decoded);
283-
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS);
247+
/* FIXME disabling this tests for now because we filter out negative
248+
* nodes when storing now, to accomodate some oddness in the initial
249+
* conditions of the backward matrix. */
250+
/* T[0].tree_node = -1; */
251+
/* T[0].value = 0; */
252+
/* ret = tsk_compressed_matrix_store_site(&forward, 0, 1, 1, T); */
253+
/* CU_ASSERT_EQUAL_FATAL(ret, 0); */
254+
/* ret = tsk_compressed_matrix_decode(&forward, (double *) decoded); */
255+
/* CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_NODE_OUT_OF_BOUNDS); */
284256

285257
T[0].tree_node = 7;
286258
T[0].value = 0;
@@ -443,7 +415,7 @@ test_multi_tree_exact_match(void)
443415
int ret = 0;
444416
tsk_treeseq_t ts;
445417
tsk_ls_hmm_t ls_hmm;
446-
tsk_compressed_matrix_t forward;
418+
tsk_compressed_matrix_t forward, backward;
447419
tsk_viterbi_matrix_t viterbi;
448420

449421
double rho[] = { 0.0, 0.25, 0.25 };
@@ -465,6 +437,13 @@ test_multi_tree_exact_match(void)
465437
ret = tsk_compressed_matrix_decode(&forward, decoded_compressed_matrix);
466438
CU_ASSERT_EQUAL_FATAL(ret, 0);
467439

440+
ret = tsk_ls_hmm_backward(&ls_hmm, h, forward.normalisation_factor, &backward, 0);
441+
CU_ASSERT_EQUAL_FATAL(ret, 0);
442+
tsk_ls_hmm_print_state(&ls_hmm, _devnull);
443+
tsk_compressed_matrix_print_state(&backward, _devnull);
444+
ret = tsk_compressed_matrix_decode(&backward, decoded_compressed_matrix);
445+
CU_ASSERT_EQUAL_FATAL(ret, 0);
446+
468447
ret = tsk_ls_hmm_viterbi(&ls_hmm, h, &viterbi, 0);
469448
CU_ASSERT_EQUAL_FATAL(ret, 0);
470449
tsk_viterbi_matrix_print_state(&viterbi, _devnull);
@@ -492,6 +471,7 @@ test_multi_tree_exact_match(void)
492471

493472
tsk_ls_hmm_free(&ls_hmm);
494473
tsk_compressed_matrix_free(&forward);
474+
tsk_compressed_matrix_free(&backward);
495475
tsk_viterbi_matrix_free(&viterbi);
496476
tsk_treeseq_free(&ts);
497477
}
@@ -529,7 +509,8 @@ test_caterpillar_tree_many_values(void)
529509
int ret = 0;
530510
tsk_ls_hmm_t ls_hmm;
531511
tsk_compressed_matrix_t matrix;
532-
double unused[] = { 0, 0, 0, 0, 0 };
512+
double rho[] = { 0.1, 0.1, 0.1, 0.1, 0.1 };
513+
double mu[] = { 0.0, 0.0, 0.0, 0.0, 0.0 };
533514
int32_t h[] = { 0, 0, 0, 0, 0 };
534515
tsk_size_t n[] = {
535516
8,
@@ -542,11 +523,11 @@ test_caterpillar_tree_many_values(void)
542523

543524
for (j = 0; j < sizeof(n) / sizeof(*n); j++) {
544525
ts = caterpillar_tree(n[j], 5, n[j] - 2);
545-
ret = tsk_ls_hmm_init(&ls_hmm, ts, unused, unused, 0);
526+
ret = tsk_ls_hmm_init(&ls_hmm, ts, rho, mu, 0);
546527
CU_ASSERT_EQUAL_FATAL(ret, 0);
547528
ret = tsk_compressed_matrix_init(&matrix, ts, 1 << 10, 0);
548529
CU_ASSERT_EQUAL_FATAL(ret, 0);
549-
ret = run_test_hmm(&ls_hmm, h, &matrix);
530+
ret = tsk_ls_hmm_forward(&ls_hmm, h, &matrix, TSK_NO_INIT);
550531
CU_ASSERT_EQUAL_FATAL(ret, 0);
551532
tsk_compressed_matrix_print_state(&matrix, _devnull);
552533
tsk_ls_hmm_print_state(&ls_hmm, _devnull);
@@ -559,13 +540,13 @@ test_caterpillar_tree_many_values(void)
559540

560541
j = 40;
561542
ts = caterpillar_tree(j, 5, j - 2);
562-
ret = tsk_ls_hmm_init(&ls_hmm, ts, unused, unused, 0);
543+
ret = tsk_ls_hmm_init(&ls_hmm, ts, rho, mu, 0);
563544
CU_ASSERT_EQUAL_FATAL(ret, 0);
564545
ret = tsk_compressed_matrix_init(&matrix, ts, 1 << 20, 0);
565546
CU_ASSERT_EQUAL_FATAL(ret, 0);
566-
/* Short circuit this value so we can run the test in reasonable time */
567-
ls_hmm.max_parsimony_words = 1;
568-
ret = run_test_hmm(&ls_hmm, h, &matrix);
547+
/* Short circuit this value so we can run the test */
548+
ls_hmm.max_parsimony_words = 0;
549+
ret = tsk_ls_hmm_forward(&ls_hmm, h, &matrix, TSK_NO_INIT);
569550
CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_TOO_MANY_VALUES);
570551

571552
tsk_ls_hmm_free(&ls_hmm);

0 commit comments

Comments
 (0)