Skip to content

Commit 8ed202c

Browse files
committed
Update rate variation
Update normalisation calculation and pseudocounts for --site-specific-rates
1 parent 1198ab4 commit 8ed202c

File tree

2 files changed

+62
-27
lines changed

2 files changed

+62
-27
lines changed

model/model_dna_rate_variation.cpp

+61-26
Original file line numberDiff line numberDiff line change
@@ -262,14 +262,6 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) {
262262
}
263263
}
264264
}
265-
266-
RealNumType* globalCounts = new RealNumType[matSize];
267-
for(int i = 0; i < num_states_; i++) {
268-
for(int j = 0; j < num_states_; j++) {
269-
globalCounts[row_index[i] + j] = 0;
270-
}
271-
}
272-
273265
std::stack<Index> nodeStack;
274266
const PhyloNode& root = tree->nodes[tree->root_vector_index];
275267
nodeStack.push(root.getNeighborIndex(RIGHT));
@@ -350,15 +342,14 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) {
350342
W[i][stateA] += branchLengthToObservation/2;
351343
W[i][stateB] += branchLengthToObservation/2;
352344
C[i][stateB + row_index[stateA]] += 1;
353-
globalCounts[stateB + row_index[stateA]] += 1;
354345
}
355346
} else {
356347
// Case 2: Last observation was the other side of the root.
357348
// In this case there are two further cases - the mutation happened either side of the root.
358349
// We calculate the relative likelihood of each case and use this to weight waiting times etc.
359350
RealNumType distToRoot = seqP_region->plength_observation2root + blength;
360351
RealNumType distToObserved = seqP_region->plength_observation2node;
361-
updateCountsAndWaitingTimesAcrossRoot(pos, end_pos, stateA, stateB, distToRoot, distToObserved, W, C, globalCounts);
352+
updateCountsAndWaitingTimesAcrossRoot(pos, end_pos, stateA, stateB, distToRoot, distToObserved, W, C);
362353
}
363354
} else if(seqP_region->type <= TYPE_R && seqC_region->type == TYPE_O) {
364355
StateType stateA = seqP_region->type;
@@ -371,7 +362,6 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) {
371362
RealNumType prob = seqC_region->getLH(stateB);
372363
if(stateB != stateA) {
373364
C[end_pos][stateB + row_index[stateA]] += prob;
374-
globalCounts[stateB + row_index[stateA]] += prob;
375365

376366
W[end_pos][stateA] += prob * branchLengthToObservation/2;
377367
W[end_pos][stateB] += prob * branchLengthToObservation/2;
@@ -385,7 +375,7 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) {
385375
RealNumType distToObserved = seqP_region->plength_observation2node;
386376
for(StateType stateB = 0; stateB < num_states_; stateB++) {
387377
RealNumType prob = seqC_region->getLH(stateB);
388-
updateCountsAndWaitingTimesAcrossRoot(pos, end_pos, stateA, stateB, distToRoot, distToObserved, W, C, globalCounts, prob);
378+
updateCountsAndWaitingTimesAcrossRoot(pos, end_pos, stateA, stateB, distToRoot, distToObserved, W, C, prob);
389379
}
390380
}
391381
} else if(seqP_region->type == TYPE_O && seqC_region->type <= TYPE_R) {
@@ -399,7 +389,6 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) {
399389
RealNumType prob = seqP_region->getLH(stateA);
400390
if(stateB != stateA) {
401391
C[end_pos][stateB + row_index[stateA]] += prob;
402-
globalCounts[stateB + row_index[stateA]] += prob;
403392

404393
W[end_pos][stateA] += prob * branchLengthToObservation/2;
405394
W[end_pos][stateB] += prob * branchLengthToObservation/2;
@@ -413,7 +402,7 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) {
413402
RealNumType distToObserved = seqP_region->plength_observation2node;
414403
for(StateType stateA = 0; stateA < num_states_; stateA++) {
415404
RealNumType prob = seqP_region->getLH(stateA);
416-
updateCountsAndWaitingTimesAcrossRoot(pos, end_pos, stateA, stateB, distToRoot, distToObserved, W, C, globalCounts, prob);
405+
updateCountsAndWaitingTimesAcrossRoot(pos, end_pos, stateA, stateB, distToRoot, distToObserved, W, C, prob);
417406
}
418407
}
419408
} else if(seqP_region->type == TYPE_O && seqC_region->type == TYPE_O) {
@@ -425,7 +414,6 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) {
425414
RealNumType probB = seqC_region->getLH(stateB);
426415
if(stateB != stateA) {
427416
C[end_pos][stateB + row_index[stateA]] += probA * probB;
428-
globalCounts[stateB + row_index[stateA]] += probA * probB;
429417

430418
W[end_pos][stateA] += probA * probB * branchLengthToObservation/2;
431419
W[end_pos][stateB] += probA * probB * branchLengthToObservation/2;
@@ -442,7 +430,7 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) {
442430
RealNumType probA = seqP_region->getLH(stateA);
443431
for(StateType stateB = 0; stateB < num_states_; stateB++) {
444432
RealNumType probB = seqC_region->getLH(stateB);
445-
updateCountsAndWaitingTimesAcrossRoot(pos, end_pos, stateA, stateB, distToRoot, distToObserved, W, C, globalCounts, probA * probB);
433+
updateCountsAndWaitingTimesAcrossRoot(pos, end_pos, stateA, stateB, distToRoot, distToObserved, W, C, probA * probB);
446434
}
447435
}
448436
}
@@ -451,22 +439,70 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) {
451439
}
452440
}
453441

442+
// Get genome-wide average mutation counts and waiting times
443+
RealNumType* globalCounts = new RealNumType[matSize];
444+
RealNumType* globalWaitingTimes = new RealNumType[num_states_];
445+
for(int i = 0; i < num_states_; i++) {
446+
globalWaitingTimes[i] = 0;
447+
for(int j = 0; j < num_states_; j++) {
448+
globalCounts[row_index[i] + j] = 0;
449+
}
450+
}
451+
452+
for(int i = 0; i < genomeSize; i++) {
453+
for(int j = 0; j < num_states_; j++) {
454+
globalWaitingTimes[j] += W[i][j];
455+
for(int k = 0; k < num_states_; k++) {
456+
globalCounts[row_index[j] + k] += C[i][row_index[j] + k];
457+
}
458+
}
459+
}
460+
454461
for(int j = 0; j < num_states_; j++) {
462+
globalWaitingTimes[j] /= genomeSize;
455463
for(int k = 0; k < num_states_; k++) {
456-
(globalCounts[row_index[j] + k] /= genomeSize);
464+
globalCounts[row_index[j] + k] /= genomeSize;
457465
}
458466
}
459467

460-
printMatrix(globalCounts, &std::cout);
461468
// add pseudocount of average rate across genome * waitingTime pseudocount for counts
462469
for(int i = 0; i < genomeSize; i++) {
463470
for(int j = 0; j < num_states_; j++) {
464471
for(int k = 0; k < num_states_; k++) {
465-
C[i][row_index[j] + k] += globalCounts[row_index[j] + k] * waitingTimePseudoCount ;
472+
C[i][row_index[j] + k] += globalCounts[row_index[j] + k] * waitingTimePseudoCount / globalWaitingTimes[j];
466473
}
467474
}
468475
}
469476

477+
if(cmaple::verbose_mode > VB_MIN)
478+
{
479+
RealNumType* referenceFreqs = new RealNumType[num_states_];
480+
for(int j = 0; j < num_states_; j++) {
481+
referenceFreqs[j] = 0;
482+
}
483+
for(int i = 0; i < genomeSize; i++) {
484+
referenceFreqs[tree->aln->ref_seq[i]]++;
485+
}
486+
for(int j = 0; j < num_states_; j++) {
487+
referenceFreqs[j] /= genomeSize;
488+
}
489+
490+
std::cout << "Genome-wide average waiting times:\t\t";
491+
for(int j = 0; j < num_states_; j++) {
492+
std::cout << globalWaitingTimes[j] << "\t" ;
493+
}
494+
std::cout << std::endl;
495+
std::cout << "Reference nucleotide frequencies:\t\t";
496+
for(int j = 0; j < num_states_; j++) {
497+
std::cout << referenceFreqs[j] << "\t" ;
498+
}
499+
std::cout << std::endl;
500+
delete[] referenceFreqs;
501+
}
502+
503+
delete[] globalCounts;
504+
delete[] globalWaitingTimes;
505+
470506
RealNumType totalRate = 0;
471507
// Update mutation matrices with new rate estimation
472508
for(int i = 0; i < genomeSize; i++) {
@@ -478,7 +514,7 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) {
478514
for(int stateB = 0; stateB < num_states_; stateB++) {
479515
if(stateA != stateB) {
480516
RealNumType newRate = Ci[stateB + row_index[stateA]] / Wi[stateA];
481-
newRate = MIN(250.0, MAX(0.01, newRate));
517+
newRate = MIN(250.0, MAX(0.001, newRate));
482518
mutationMatrices[i * matSize + (stateB + row_index[stateA])] = newRate;
483519

484520
// Approximate total rate by considering rates from reference nucleotide
@@ -505,14 +541,16 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) {
505541
}
506542

507543
// Normalise entries of mutation matrices
508-
RealNumType averageRate = totalRate / genomeSize;
544+
totalRate /= genomeSize;
545+
//RealNumType averageRate = totalRate / genomeSize;
509546
for(int i = 0; i < genomeSize; i++) {
510547
for(int stateA = 0; stateA < num_states_; stateA++) {
511548
RealNumType rowSum = 0;
512549
for(int stateB = 0; stateB < num_states_; stateB++) {
513550
if(stateA != stateB) {
514551
RealNumType val = mutationMatrices[i * matSize + (stateB + row_index[stateA])];
515-
val /= averageRate;
552+
//val /= averageRate;
553+
val /= totalRate;
516554

517555
mutationMatrices[i * matSize + (stateB + row_index[stateA])] = val;
518556
transposedMutationMatrices[i * matSize + (stateA + row_index[stateB])] = val;
@@ -541,15 +579,13 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) {
541579
}
542580
delete[] C;
543581
delete[] W;
544-
delete[] globalCounts;
545-
546582
}
547583

548584
void ModelDNARateVariation::updateCountsAndWaitingTimesAcrossRoot( PositionType start, PositionType end,
549585
StateType parentState, StateType childState,
550586
RealNumType distToRoot, RealNumType distToObserved,
551587
RealNumType** waitingTimes, RealNumType** counts,
552-
RealNumType* globalCounts, RealNumType weight)
588+
RealNumType weight)
553589
{
554590
if(parentState != childState) {
555591
for(int i = start; i <= end; i++) {
@@ -559,7 +595,6 @@ void ModelDNARateVariation::updateCountsAndWaitingTimesAcrossRoot( PositionType
559595
waitingTimes[i][parentState] += weight * relativeRootIsStateParent * distToRoot/2;
560596
waitingTimes[i][childState] += weight * relativeRootIsStateParent * distToRoot/2;
561597
counts[i][childState + row_index[parentState]] += weight * relativeRootIsStateParent;
562-
globalCounts[childState + row_index[parentState]] += weight * relativeRootIsStateParent;
563598

564599
RealNumType relativeRootIsStateChild = 1 - relativeRootIsStateParent;
565600
waitingTimes[i][childState] += weight * relativeRootIsStateChild * distToRoot;

model/model_dna_rate_variation.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ class ModelDNARateVariation : public ModelDNA {
8383
StateType parentState, StateType childState,
8484
RealNumType distToRoot, RealNumType distToObserved,
8585
RealNumType** waitingTimes, RealNumType** counts,
86-
RealNumType* globalCounts, RealNumType weight = 1.);
86+
RealNumType weight = 1.);
8787

8888
cmaple::PositionType genomeSize;
8989

0 commit comments

Comments
 (0)