Skip to content

Commit 6d07b03

Browse files
committed
Fixes to rate-variation
Fix positional errors causing incorrect likelihood calculations when using rate-variation.
1 parent d2e8929 commit 6d07b03

File tree

5 files changed

+147
-74
lines changed

5 files changed

+147
-74
lines changed

alignment/seqregions.h

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,7 @@ auto updateLHwithModel(const ModelBase* model,
564564
RealNumType tot = 0;
565565
if (total_blength > 0) // TODO: avoid
566566
{
567+
567568
const RealNumType* mutation_mat_row = model->getMutationMatrixRow(i, pos);
568569
tot += dotProduct<num_states>(&(prior)[0], mutation_mat_row);
569570

@@ -630,7 +631,6 @@ void merge_N_O(const RealNumType lower_plength,
630631
total_blength = reg_o.plength_observation2node +
631632
(lower_plength > 0 ? lower_plength : 0);
632633
}
633-
634634
auto new_lh =
635635
cmaple::make_unique<SeqRegion::LHType>(); // = new RealNumType[num_states];
636636
RealNumType sum_lh = updateLHwithModel<num_states>(model, *reg_o.likelihood,
@@ -765,7 +765,7 @@ void merge_RACGT_O(const SeqRegion& seq2_region,
765765
assert(seq2_region.type == TYPE_O);
766766
assert(model);
767767
assert(aln);
768-
768+
769769
RealNumType sum_new_lh = updateMultLHwithMat<num_states>(
770770
model->getMutationMatrix(end_pos), *(seq2_region.likelihood), new_lh, total_blength_2);
771771

@@ -1043,7 +1043,7 @@ auto merge_O_O_TwoLowers(const SeqRegion& seq2_region,
10431043
assert(seq2_region.type == TYPE_O);
10441044
assert(model);
10451045
assert(aln);
1046-
1046+
10471047
RealNumType sum_lh = updateMultLHwithMat<num_states>(
10481048
model->getMutationMatrix(end_pos), *seq2_region.likelihood, new_lh, total_blength_2);
10491049

@@ -1179,7 +1179,7 @@ auto merge_RACGT_O_TwoLowers(const SeqRegion& seq2_region,
11791179
assert(seq2_region.type == TYPE_O);
11801180
assert(model);
11811181
assert(aln);
1182-
1182+
11831183
RealNumType sum_lh = updateMultLHwithMat<num_states>(
11841184
model->getMutationMatrix(end_pos), *(seq2_region.likelihood), new_lh, total_blength_2);
11851185

@@ -1363,7 +1363,7 @@ auto merge_notN_notN_TwoLowers(const SeqRegion& seq1_region,
13631363
aln, model, threshold_prob, log_lh, merged_regions, return_log_lh);
13641364
return ret;
13651365
}
1366-
1366+
13671367
// no error
13681368
return true;
13691369
}
@@ -1447,12 +1447,6 @@ RealNumType SeqRegions::mergeTwoLowers(
14471447
}
14481448
// neither seq1_entry nor seq2_entry = N
14491449
else {
1450-
//RealNumType* vec = (&seq1_region->likelihood)[0];
1451-
/*if(!seq1_region->likelihood.get()) {
1452-
std::cout << "[mergeTwoLowers] Likelihood pointer: " << seq1_region->likelihood.get() << std::endl;
1453-
return MIN_NEGATIVE;
1454-
}*/
1455-
//std::cout << "[mergeTwoLowers] Prior: " << vec[0] << " " << vec[1] << " " << vec[2] << " " << vec[3] << std::endl;
14561450
if (!merge_notN_notN_TwoLowers<num_states>(
14571451
*seq1_region, *seq2_region, plength1, plength2, end_pos, pos, aln,
14581452
model, cumulative_rate, threshold_prob, log_lh, merged_regions,
@@ -1480,7 +1474,6 @@ RealNumType SeqRegions::mergeTwoLowers(
14801474
max_elements); // ensure we did the correct reserve, otherwise it was
14811475
// a pessimization
14821476
#endif
1483-
14841477
return log_lh;
14851478
}
14861479

@@ -1505,7 +1498,7 @@ auto SeqRegions::computeAbsoluteLhAtRoot(
15051498
for (StateType i = 0; i < num_states; ++i) {
15061499
log_lh += model->getRootLogFreq(i) *
15071500
(cumulative_base[static_cast<size_t>(region.position) + 1][i] -
1508-
cumulative_base[static_cast<size_t>(start_pos)][i]);
1501+
cumulative_base[static_cast<size_t>(start_pos)][i]);
15091502
}
15101503
}
15111504
// type ACGT
@@ -1737,7 +1730,7 @@ bool calSiteLhs_O_O(std::vector<RealNumType>& site_lh_contributions,
17371730
assert(seq2_region.type == TYPE_O);
17381731
assert(aln);
17391732
assert(model);
1740-
1733+
17411734
RealNumType sum_lh = updateMultLHwithMat<num_states>(
17421735
model->getMutationMatrix(end_pos), *seq2_region.likelihood, new_lh, total_blength_2);
17431736

@@ -1869,7 +1862,6 @@ bool calSiteLhs_RACGT_O(std::vector<RealNumType>& site_lh_contributions,
18691862
assert(seq2_region.type == TYPE_O);
18701863
assert(aln);
18711864
assert(model);
1872-
18731865
RealNumType sum_lh = updateMultLHwithMat<num_states>(
18741866
model->getMutationMatrix(end_pos), *(seq2_region.likelihood), new_lh, total_blength_2);
18751867

model/model_dna_rate_variation.cpp

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ void ModelDNARateVariation::printCountsAndWaitingTimes(const RealNumType* counts
5959
}
6060

6161
bool ModelDNARateVariation::updateMutationMatEmpirical() {
62+
if(ratesEstimated) {
63+
std::cout << "[ModelDNARateVariation] Warning: Overwriting estimated rate matrices with single empirical mutation matrix." << std::endl;
64+
}
6265
bool val = ModelDNA::updateMutationMatEmpirical();
6366
for(int i = 0; i < genomeSize; i++) {
6467
for(int j = 0; j < matSize; j++) {
@@ -76,6 +79,7 @@ bool ModelDNARateVariation::updateMutationMatEmpirical() {
7679
}
7780

7881
void ModelDNARateVariation::estimateRates(cmaple::Tree* tree) {
82+
ratesEstimated = true;
7983
if(useSiteRates) {
8084
estimateRatePerSite(tree);
8185

@@ -109,21 +113,22 @@ void ModelDNARateVariation::estimateRates(cmaple::Tree* tree) {
109113
outFile << "Rate Matrix: " << std::endl;
110114
printMatrix(getMutationMatrix(i), &outFile);
111115
outFile << std::endl;
116+
112117
}
113118
outFile.close();
114119
}
115120
}
116121

117122
void ModelDNARateVariation::estimateRatePerSite(cmaple::Tree* tree){
118123
std::cout << "Estimating mutation rate per site..." << std::endl;
119-
RealNumType** totals = new RealNumType*[num_states_];
124+
RealNumType** waitingTimes = new RealNumType*[num_states_];
120125
for(int j = 0; j < num_states_; j++) {
121-
totals[j] = new RealNumType[genomeSize];
126+
waitingTimes[j] = new RealNumType[genomeSize];
122127
}
123128
RealNumType* numSubstitutions = new RealNumType[genomeSize];
124129
for(int i = 0; i < genomeSize; i++) {
125130
for(int j = 0; j < num_states_; j++) {
126-
totals[j][i] = 0;
131+
waitingTimes[j][i] = 0;
127132
}
128133
numSubstitutions[i] = 0;
129134
}
@@ -170,12 +175,12 @@ void ModelDNARateVariation::estimateRatePerSite(cmaple::Tree* tree){
170175
// both states are type REF
171176
for(int i = pos; i <= end_pos; i++) {
172177
int state = tree->aln->ref_seq[static_cast<std::vector<cmaple::StateType>::size_type>(i)];
173-
totals[state][i] += blength;
178+
waitingTimes[state][i] += blength;
174179
}
175180
} else if(seq1_region->type == seq2_region->type && seq1_region->type < TYPE_R) {
176181
// both states are equal but not of type REF
177182
for(int i = pos; i <= end_pos; i++) {
178-
totals[seq1_region->type][i] += blength;
183+
waitingTimes[seq1_region->type][i] += blength;
179184
}
180185
} else if(seq1_region->type <= TYPE_R && seq2_region->type <= TYPE_R) {
181186
// both states are not equal
@@ -194,7 +199,7 @@ void ModelDNARateVariation::estimateRatePerSite(cmaple::Tree* tree){
194199
} else {
195200
RealNumType expectedRateNoSubstitution = 0;
196201
for(int j = 0; j < num_states_; j++) {
197-
RealNumType summand = totals[j][i] * abs(diagonal_mut_mat[j]);
202+
RealNumType summand = waitingTimes[j][i] * abs(diagonal_mut_mat[j]);
198203
expectedRateNoSubstitution += summand;
199204
}
200205
if(expectedRateNoSubstitution <= 0.01) {
@@ -210,7 +215,6 @@ void ModelDNARateVariation::estimateRatePerSite(cmaple::Tree* tree){
210215
RealNumType averageRate = rateCount / genomeSize;
211216
for(int i = 0; i < genomeSize; i++) {
212217
rates[i] /= averageRate;
213-
//std::cout << rates[i] << " ";
214218
for(int stateA = 0; stateA < num_states_; stateA++) {
215219
RealNumType rowSum = 0;
216220
for(int stateB = 0; stateB < num_states_; stateB++) {
@@ -236,9 +240,9 @@ void ModelDNARateVariation::estimateRatePerSite(cmaple::Tree* tree){
236240
}
237241

238242
for(int j = 0; j < num_states_; j++) {
239-
delete[] totals[j];
243+
delete[] waitingTimes[j];
240244
}
241-
delete[] totals;
245+
delete[] waitingTimes;
242246
delete[] numSubstitutions;
243247
}
244248

@@ -556,4 +560,14 @@ void ModelDNARateVariation::setAllMatricesToDefault() {
556560

557561
}
558562
}
563+
}
564+
565+
// Copies one caller-supplied rate matrix into this model's per-position storage
// for position `i`, filling three arrays: the diagonal entries, the matrix
// itself, and its transpose.
//
// matrix: flattened matrix, read at index (column + row_index[row]).
//         NOTE(review): presumably row_index[s] is the offset of row `s` in the
//         flattened layout and the buffer holds at least matSize entries — confirm.
// i:      position whose matrices are overwritten. No bounds check is performed
//         here; presumably the caller guarantees 0 <= i < genomeSize — confirm.
void ModelDNARateVariation::setMatrixAtPosition(RealNumType* matrix, PositionType i) {
566+
for(int stateA = 0; stateA < num_states_; stateA++) {
567+
// Diagonal entry (row stateA, column stateA) goes into the per-position diagonal array.
diagonalMutationMatrices[i * num_states_ + stateA] = matrix[stateA + row_index[stateA]];
568+
for(int stateB = 0; stateB < num_states_; stateB++) {
569+
// Straight copy of entry (row stateA, column stateB) into this position's slot.
mutationMatrices[i * matSize + (stateB + row_index[stateA])] = matrix[stateB + row_index[stateA]];
570+
// Transpose: the slot for (row stateA, column stateB) receives the source entry (row stateB, column stateA).
transposedMutationMatrices[i * matSize + (stateB + row_index[stateA])] = matrix[stateA + row_index[stateB]];
571+
}
572+
}
559573
}

model/model_dna_rate_variation.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ class ModelDNARateVariation : public ModelDNA {
7272
virtual bool updateMutationMatEmpirical() override;
7373

7474
void setAllMatricesToDefault();
75+
void setMatrixAtPosition(RealNumType* matrix, PositionType i);
7576

7677
void printMatrix(const RealNumType* matrix, std::ostream* outStream);
7778
void printCountsAndWaitingTimes(const RealNumType* counts, const RealNumType* waitingTImes, std::ostream* outStream);
@@ -94,6 +95,7 @@ class ModelDNARateVariation : public ModelDNA {
9495
cmaple::RealNumType* rates = nullptr;
9596
uint16_t matSize;
9697
bool useSiteRates = false;
98+
bool ratesEstimated = false;
9799

98100
};
99101
}

0 commit comments

Comments
 (0)