4
4
5
5
using namespace cmaple ;
6
6
7
- ModelDNARateVariation::ModelDNARateVariation (const cmaple::ModelBase::SubModel sub_model, PositionType _genomeSize, bool _useSiteRates)
7
+ ModelDNARateVariation::ModelDNARateVariation ( const cmaple::ModelBase::SubModel sub_model, PositionType _genomeSize,
8
+ bool _useSiteRates, cmaple::RealNumType _wtPseudocount)
8
9
: ModelDNA(sub_model) {
9
10
10
11
genomeSize = _genomeSize;
11
12
useSiteRates = _useSiteRates;
12
13
matSize = row_index[num_states_];
14
+ waitingTimePseudoCount = _wtPseudocount;
13
15
14
16
mutationMatrices = new RealNumType[matSize * genomeSize]();
15
17
transposedMutationMatrices = new RealNumType[matSize * genomeSize]();
@@ -248,21 +250,26 @@ void ModelDNARateVariation::estimateRatePerSite(cmaple::Tree* tree){
248
250
249
251
void ModelDNARateVariation::estimateRatesPerSitePerEntry (cmaple::Tree* tree) {
250
252
251
- // Possibly better to keep this memory allocated and reuse
252
- // since this will be called several times?
253
253
RealNumType** C = new RealNumType*[genomeSize];
254
254
RealNumType** W = new RealNumType*[genomeSize];
255
255
for (int i = 0 ; i < genomeSize; i++) {
256
256
C[i] = new RealNumType[matSize];
257
257
W[i] = new RealNumType[num_states_];
258
258
for (int j = 0 ; j < num_states_; j++) {
259
- W[i][j] = 0.0001 ;
259
+ W[i][j] = waitingTimePseudoCount ;
260
260
for (int k = 0 ; k < num_states_; k++) {
261
261
C[i][row_index[j] + k] = 0 ;
262
262
}
263
263
}
264
264
}
265
265
266
+ RealNumType* globalCounts = new RealNumType[matSize];
267
+ for (int i = 0 ; i < num_states_; i++) {
268
+ for (int j = 0 ; j < num_states_; j++) {
269
+ globalCounts[row_index[i] + j] = 0 ;
270
+ }
271
+ }
272
+
266
273
std::stack<Index> nodeStack;
267
274
const PhyloNode& root = tree->nodes [tree->root_vector_index ];
268
275
nodeStack.push (root.getNeighborIndex (RIGHT));
@@ -343,14 +350,15 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) {
343
350
W[i][stateA] += branchLengthToObservation/2 ;
344
351
W[i][stateB] += branchLengthToObservation/2 ;
345
352
C[i][stateB + row_index[stateA]] += 1 ;
353
+ globalCounts[stateB + row_index[stateA]] += 1 ;
346
354
}
347
355
} else {
348
356
// Case 2: Last observation was the other side of the root.
349
357
// In this case there are two further cases - the mutation happened either side of the root.
350
358
// We calculate the relative likelihood of each case and use this to weight waiting times etc.
351
359
RealNumType distToRoot = seqP_region->plength_observation2root + blength;
352
360
RealNumType distToObserved = seqP_region->plength_observation2node ;
353
- updateCountsAndWaitingTimesAcrossRoot (pos, end_pos, stateA, stateB, distToRoot, distToObserved, W, C);
361
+ updateCountsAndWaitingTimesAcrossRoot (pos, end_pos, stateA, stateB, distToRoot, distToObserved, W, C, globalCounts );
354
362
}
355
363
} else if (seqP_region->type <= TYPE_R && seqC_region->type == TYPE_O) {
356
364
StateType stateA = seqP_region->type ;
@@ -363,6 +371,8 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) {
363
371
RealNumType prob = seqC_region->getLH (stateB);
364
372
if (stateB != stateA) {
365
373
C[end_pos][stateB + row_index[stateA]] += prob;
374
+ globalCounts[stateB + row_index[stateA]] += prob;
375
+
366
376
W[end_pos][stateA] += prob * branchLengthToObservation/2 ;
367
377
W[end_pos][stateB] += prob * branchLengthToObservation/2 ;
368
378
} else {
@@ -375,7 +385,7 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) {
375
385
RealNumType distToObserved = seqP_region->plength_observation2node ;
376
386
for (StateType stateB = 0 ; stateB < num_states_; stateB++) {
377
387
RealNumType prob = seqC_region->getLH (stateB);
378
- updateCountsAndWaitingTimesAcrossRoot (pos, end_pos, stateA, stateB, distToRoot, distToObserved, W, C, prob);
388
+ updateCountsAndWaitingTimesAcrossRoot (pos, end_pos, stateA, stateB, distToRoot, distToObserved, W, C, globalCounts, prob);
379
389
}
380
390
}
381
391
} else if (seqP_region->type == TYPE_O && seqC_region->type <= TYPE_R) {
@@ -389,6 +399,8 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) {
389
399
RealNumType prob = seqP_region->getLH (stateA);
390
400
if (stateB != stateA) {
391
401
C[end_pos][stateB + row_index[stateA]] += prob;
402
+ globalCounts[stateB + row_index[stateA]] += prob;
403
+
392
404
W[end_pos][stateA] += prob * branchLengthToObservation/2 ;
393
405
W[end_pos][stateB] += prob * branchLengthToObservation/2 ;
394
406
} else {
@@ -401,7 +413,7 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) {
401
413
RealNumType distToObserved = seqP_region->plength_observation2node ;
402
414
for (StateType stateA = 0 ; stateA < num_states_; stateA++) {
403
415
RealNumType prob = seqP_region->getLH (stateA);
404
- updateCountsAndWaitingTimesAcrossRoot (pos, end_pos, stateA, stateB, distToRoot, distToObserved, W, C, prob);
416
+ updateCountsAndWaitingTimesAcrossRoot (pos, end_pos, stateA, stateB, distToRoot, distToObserved, W, C, globalCounts, prob);
405
417
}
406
418
}
407
419
} else if (seqP_region->type == TYPE_O && seqC_region->type == TYPE_O) {
@@ -413,6 +425,8 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) {
413
425
RealNumType probB = seqC_region->getLH (stateB);
414
426
if (stateB != stateA) {
415
427
C[end_pos][stateB + row_index[stateA]] += probA * probB;
428
+ globalCounts[stateB + row_index[stateA]] += probA * probB;
429
+
416
430
W[end_pos][stateA] += probA * probB * branchLengthToObservation/2 ;
417
431
W[end_pos][stateB] += probA * probB * branchLengthToObservation/2 ;
418
432
} else {
@@ -428,7 +442,7 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) {
428
442
RealNumType probA = seqP_region->getLH (stateA);
429
443
for (StateType stateB = 0 ; stateB < num_states_; stateB++) {
430
444
RealNumType probB = seqC_region->getLH (stateB);
431
- updateCountsAndWaitingTimesAcrossRoot (pos, end_pos, stateA, stateB, distToRoot, distToObserved, W, C, probA * probB);
445
+ updateCountsAndWaitingTimesAcrossRoot (pos, end_pos, stateA, stateB, distToRoot, distToObserved, W, C, globalCounts, probA * probB);
432
446
}
433
447
}
434
448
}
@@ -437,26 +451,40 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) {
437
451
}
438
452
}
439
453
440
- RealNumType totalRate = 0 ;
454
+ for (int j = 0 ; j < num_states_; j++) {
455
+ for (int k = 0 ; k < num_states_; k++) {
456
+ (globalCounts[row_index[j] + k] /= genomeSize);
457
+ }
458
+ }
441
459
460
+ printMatrix (globalCounts, &std::cout);
461
+ // add pseudocount of average rate across genome * waitingTime pseudocount for counts
462
+ for (int i = 0 ; i < genomeSize; i++) {
463
+ for (int j = 0 ; j < num_states_; j++) {
464
+ for (int k = 0 ; k < num_states_; k++) {
465
+ C[i][row_index[j] + k] += globalCounts[row_index[j] + k] * waitingTimePseudoCount ;
466
+ }
467
+ }
468
+ }
469
+
470
+ RealNumType totalRate = 0 ;
442
471
// Update mutation matrices with new rate estimation
443
472
for (int i = 0 ; i < genomeSize; i++) {
444
473
RealNumType* Ci = C[i];
445
474
RealNumType* Wi = W[i];
475
+ StateType refState = tree->aln ->ref_seq [static_cast <std::vector<cmaple::StateType>::size_type>(i)];
476
+
446
477
for (int stateA = 0 ; stateA < num_states_; stateA++) {
447
478
for (int stateB = 0 ; stateB < num_states_; stateB++) {
448
479
if (stateA != stateB) {
449
- RealNumType newRate;
450
- if (Ci[stateB + row_index[stateA]] == 0 ) {
451
- newRate = 0.001 ;
452
- } else if (Wi[stateA] <= 0.01 ) {
453
- newRate = 1.0 ;
454
- } else {
455
- newRate = Ci[stateB + row_index[stateA]] / Wi[stateA];
456
- newRate = MIN (100.0 , MAX (0.0001 , newRate));
457
- }
480
+ RealNumType newRate = Ci[stateB + row_index[stateA]] / Wi[stateA];
481
+ newRate = MIN (250.0 , MAX (0.01 , newRate));
458
482
mutationMatrices[i * matSize + (stateB + row_index[stateA])] = newRate;
459
- totalRate += newRate;
483
+
484
+ // Approximate total rate by considering rates from reference nucleotide
485
+ if (refState == stateA) {
486
+ totalRate += newRate;
487
+ }
460
488
}
461
489
}
462
490
}
@@ -513,14 +541,15 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) {
513
541
}
514
542
delete[] C;
515
543
delete[] W;
544
+ delete[] globalCounts;
516
545
517
546
}
518
547
519
548
void ModelDNARateVariation::updateCountsAndWaitingTimesAcrossRoot ( PositionType start, PositionType end,
520
549
StateType parentState, StateType childState,
521
550
RealNumType distToRoot, RealNumType distToObserved,
522
551
RealNumType** waitingTimes, RealNumType** counts,
523
- RealNumType weight)
552
+ RealNumType* globalCounts, RealNumType weight)
524
553
{
525
554
if (parentState != childState) {
526
555
for (int i = start; i <= end; i++) {
@@ -530,6 +559,7 @@ void ModelDNARateVariation::updateCountsAndWaitingTimesAcrossRoot( PositionType
530
559
waitingTimes[i][parentState] += weight * relativeRootIsStateParent * distToRoot/2 ;
531
560
waitingTimes[i][childState] += weight * relativeRootIsStateParent * distToRoot/2 ;
532
561
counts[i][childState + row_index[parentState]] += weight * relativeRootIsStateParent;
562
+ globalCounts[childState + row_index[parentState]] += weight * relativeRootIsStateParent;
533
563
534
564
RealNumType relativeRootIsStateChild = 1 - relativeRootIsStateParent;
535
565
waitingTimes[i][childState] += weight * relativeRootIsStateChild * distToRoot;
0 commit comments