Skip to content

Commit 1198ab4

Browse files
committed
Site-specific rate matrix updates
Add pseudocounts when calculating site-specific rates. User can specify with option --waiting-time-pseudocount
1 parent 6d07b03 commit 1198ab4

File tree

7 files changed

+78
-25
lines changed

7 files changed

+78
-25
lines changed

maple/cmaple.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ void cmaple::runCMAPLE(cmaple::Params &params)
123123
}
124124
assert(sub_model != cmaple::ModelBase::UNKNOWN);
125125
bool useRateVariationModel = params.rate_variation || params.site_specific_rates;
126-
Model model(aln.ref_seq.size(), useRateVariationModel, params.rate_variation, sub_model, aln.getSeqType());
126+
Model model(aln.ref_seq.size(), useRateVariationModel, params.rate_variation, params.wt_pseudocount, sub_model, aln.getSeqType());
127127

128128
// If users only want to convert the alignment to another format -> convert it and terminate
129129
if (params.output_aln.length())

model/model.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using namespace cmaple;
99
cmaple::Model::Model( cmaple::PositionType ref_genome_size,
1010
bool _rate_variation,
1111
bool _siteRates,
12+
cmaple::RealNumType wt_pseudocount,
1213
const cmaple::ModelBase::SubModel sub_model,
1314
const cmaple::SeqRegion::SeqType seqtype)
1415
: model_base(nullptr) {
@@ -65,7 +66,7 @@ cmaple::Model::Model( cmaple::PositionType ref_genome_size,
6566
}
6667
case cmaple::SeqRegion::SEQ_DNA: {
6768
if(rate_variation){
68-
model_base = new ModelDNARateVariation(n_sub_model, ref_genome_size, _siteRates);
69+
model_base = new ModelDNARateVariation(n_sub_model, ref_genome_size, _siteRates, wt_pseudocount);
6970
} else {
7071
model_base = new ModelDNA(n_sub_model);
7172
}

model/model.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ class Model {
4949
cmaple::PositionType ref_genome_size,
5050
bool _rate_variation,
5151
bool _siteRates,
52+
cmaple::RealNumType wt_pseudocount,
5253
const cmaple::ModelBase::SubModel sub_model = cmaple::ModelBase::DEFAULT,
5354
const cmaple::SeqRegion::SeqType seqtype = cmaple::SeqRegion::SEQ_AUTO);
5455

model/model_dna_rate_variation.cpp

Lines changed: 50 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44

55
using namespace cmaple;
66

7-
ModelDNARateVariation::ModelDNARateVariation(const cmaple::ModelBase::SubModel sub_model, PositionType _genomeSize, bool _useSiteRates)
7+
ModelDNARateVariation::ModelDNARateVariation( const cmaple::ModelBase::SubModel sub_model, PositionType _genomeSize,
8+
bool _useSiteRates, cmaple::RealNumType _wtPseudocount)
89
: ModelDNA(sub_model) {
910

1011
genomeSize = _genomeSize;
1112
useSiteRates = _useSiteRates;
1213
matSize = row_index[num_states_];
14+
waitingTimePseudoCount = _wtPseudocount;
1315

1416
mutationMatrices = new RealNumType[matSize * genomeSize]();
1517
transposedMutationMatrices = new RealNumType[matSize * genomeSize]();
@@ -248,21 +250,26 @@ void ModelDNARateVariation::estimateRatePerSite(cmaple::Tree* tree){
248250

249251
void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) {
250252

251-
// Possibly better to keep this memory allocated and reuse
252-
// since this will be called several times?
253253
RealNumType** C = new RealNumType*[genomeSize];
254254
RealNumType** W = new RealNumType*[genomeSize];
255255
for(int i = 0; i < genomeSize; i++) {
256256
C[i] = new RealNumType[matSize];
257257
W[i] = new RealNumType[num_states_];
258258
for(int j = 0; j < num_states_; j++) {
259-
W[i][j] = 0.0001;
259+
W[i][j] = waitingTimePseudoCount;
260260
for(int k = 0; k < num_states_; k++) {
261261
C[i][row_index[j] + k] = 0;
262262
}
263263
}
264264
}
265265

266+
RealNumType* globalCounts = new RealNumType[matSize];
267+
for(int i = 0; i < num_states_; i++) {
268+
for(int j = 0; j < num_states_; j++) {
269+
globalCounts[row_index[i] + j] = 0;
270+
}
271+
}
272+
266273
std::stack<Index> nodeStack;
267274
const PhyloNode& root = tree->nodes[tree->root_vector_index];
268275
nodeStack.push(root.getNeighborIndex(RIGHT));
@@ -343,14 +350,15 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) {
343350
W[i][stateA] += branchLengthToObservation/2;
344351
W[i][stateB] += branchLengthToObservation/2;
345352
C[i][stateB + row_index[stateA]] += 1;
353+
globalCounts[stateB + row_index[stateA]] += 1;
346354
}
347355
} else {
348356
// Case 2: Last observation was the other side of the root.
349357
// In this case there are two further cases - the mutation happened either side of the root.
350358
// We calculate the relative likelihood of each case and use this to weight waiting times etc.
351359
RealNumType distToRoot = seqP_region->plength_observation2root + blength;
352360
RealNumType distToObserved = seqP_region->plength_observation2node;
353-
updateCountsAndWaitingTimesAcrossRoot(pos, end_pos, stateA, stateB, distToRoot, distToObserved, W, C);
361+
updateCountsAndWaitingTimesAcrossRoot(pos, end_pos, stateA, stateB, distToRoot, distToObserved, W, C, globalCounts);
354362
}
355363
} else if(seqP_region->type <= TYPE_R && seqC_region->type == TYPE_O) {
356364
StateType stateA = seqP_region->type;
@@ -363,6 +371,8 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) {
363371
RealNumType prob = seqC_region->getLH(stateB);
364372
if(stateB != stateA) {
365373
C[end_pos][stateB + row_index[stateA]] += prob;
374+
globalCounts[stateB + row_index[stateA]] += prob;
375+
366376
W[end_pos][stateA] += prob * branchLengthToObservation/2;
367377
W[end_pos][stateB] += prob * branchLengthToObservation/2;
368378
} else {
@@ -375,7 +385,7 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) {
375385
RealNumType distToObserved = seqP_region->plength_observation2node;
376386
for(StateType stateB = 0; stateB < num_states_; stateB++) {
377387
RealNumType prob = seqC_region->getLH(stateB);
378-
updateCountsAndWaitingTimesAcrossRoot(pos, end_pos, stateA, stateB, distToRoot, distToObserved, W, C, prob);
388+
updateCountsAndWaitingTimesAcrossRoot(pos, end_pos, stateA, stateB, distToRoot, distToObserved, W, C, globalCounts, prob);
379389
}
380390
}
381391
} else if(seqP_region->type == TYPE_O && seqC_region->type <= TYPE_R) {
@@ -389,6 +399,8 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) {
389399
RealNumType prob = seqP_region->getLH(stateA);
390400
if(stateB != stateA) {
391401
C[end_pos][stateB + row_index[stateA]] += prob;
402+
globalCounts[stateB + row_index[stateA]] += prob;
403+
392404
W[end_pos][stateA] += prob * branchLengthToObservation/2;
393405
W[end_pos][stateB] += prob * branchLengthToObservation/2;
394406
} else {
@@ -401,7 +413,7 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) {
401413
RealNumType distToObserved = seqP_region->plength_observation2node;
402414
for(StateType stateA = 0; stateA < num_states_; stateA++) {
403415
RealNumType prob = seqP_region->getLH(stateA);
404-
updateCountsAndWaitingTimesAcrossRoot(pos, end_pos, stateA, stateB, distToRoot, distToObserved, W, C, prob);
416+
updateCountsAndWaitingTimesAcrossRoot(pos, end_pos, stateA, stateB, distToRoot, distToObserved, W, C, globalCounts, prob);
405417
}
406418
}
407419
} else if(seqP_region->type == TYPE_O && seqC_region->type == TYPE_O) {
@@ -413,6 +425,8 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) {
413425
RealNumType probB = seqC_region->getLH(stateB);
414426
if(stateB != stateA) {
415427
C[end_pos][stateB + row_index[stateA]] += probA * probB;
428+
globalCounts[stateB + row_index[stateA]] += probA * probB;
429+
416430
W[end_pos][stateA] += probA * probB * branchLengthToObservation/2;
417431
W[end_pos][stateB] += probA * probB * branchLengthToObservation/2;
418432
} else {
@@ -428,7 +442,7 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) {
428442
RealNumType probA = seqP_region->getLH(stateA);
429443
for(StateType stateB = 0; stateB < num_states_; stateB++) {
430444
RealNumType probB = seqC_region->getLH(stateB);
431-
updateCountsAndWaitingTimesAcrossRoot(pos, end_pos, stateA, stateB, distToRoot, distToObserved, W, C, probA * probB);
445+
updateCountsAndWaitingTimesAcrossRoot(pos, end_pos, stateA, stateB, distToRoot, distToObserved, W, C, globalCounts, probA * probB);
432446
}
433447
}
434448
}
@@ -437,26 +451,40 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) {
437451
}
438452
}
439453

440-
RealNumType totalRate = 0;
454+
for(int j = 0; j < num_states_; j++) {
455+
for(int k = 0; k < num_states_; k++) {
456+
(globalCounts[row_index[j] + k] /= genomeSize);
457+
}
458+
}
441459

460+
printMatrix(globalCounts, &std::cout);
461+
// add pseudocount of average rate across genome * waitingTime pseudocount for counts
462+
for(int i = 0; i < genomeSize; i++) {
463+
for(int j = 0; j < num_states_; j++) {
464+
for(int k = 0; k < num_states_; k++) {
465+
C[i][row_index[j] + k] += globalCounts[row_index[j] + k] * waitingTimePseudoCount ;
466+
}
467+
}
468+
}
469+
470+
RealNumType totalRate = 0;
442471
// Update mutation matrices with new rate estimation
443472
for(int i = 0; i < genomeSize; i++) {
444473
RealNumType* Ci = C[i];
445474
RealNumType* Wi = W[i];
475+
StateType refState = tree->aln->ref_seq[static_cast<std::vector<cmaple::StateType>::size_type>(i)];
476+
446477
for(int stateA = 0; stateA < num_states_; stateA++) {
447478
for(int stateB = 0; stateB < num_states_; stateB++) {
448479
if(stateA != stateB) {
449-
RealNumType newRate;
450-
if(Ci[stateB + row_index[stateA]] == 0) {
451-
newRate = 0.001;
452-
} else if (Wi[stateA] <= 0.01) {
453-
newRate = 1.0;
454-
} else {
455-
newRate = Ci[stateB + row_index[stateA]] / Wi[stateA];
456-
newRate = MIN(100.0, MAX(0.0001, newRate));
457-
}
480+
RealNumType newRate = Ci[stateB + row_index[stateA]] / Wi[stateA];
481+
newRate = MIN(250.0, MAX(0.01, newRate));
458482
mutationMatrices[i * matSize + (stateB + row_index[stateA])] = newRate;
459-
totalRate += newRate;
483+
484+
// Approximate total rate by considering rates from reference nucleotide
485+
if(refState == stateA) {
486+
totalRate += newRate;
487+
}
460488
}
461489
}
462490
}
@@ -513,14 +541,15 @@ void ModelDNARateVariation::estimateRatesPerSitePerEntry(cmaple::Tree* tree) {
513541
}
514542
delete[] C;
515543
delete[] W;
544+
delete[] globalCounts;
516545

517546
}
518547

519548
void ModelDNARateVariation::updateCountsAndWaitingTimesAcrossRoot( PositionType start, PositionType end,
520549
StateType parentState, StateType childState,
521550
RealNumType distToRoot, RealNumType distToObserved,
522551
RealNumType** waitingTimes, RealNumType** counts,
523-
RealNumType weight)
552+
RealNumType* globalCounts, RealNumType weight)
524553
{
525554
if(parentState != childState) {
526555
for(int i = start; i <= end; i++) {
@@ -530,6 +559,7 @@ void ModelDNARateVariation::updateCountsAndWaitingTimesAcrossRoot( PositionType
530559
waitingTimes[i][parentState] += weight * relativeRootIsStateParent * distToRoot/2;
531560
waitingTimes[i][childState] += weight * relativeRootIsStateParent * distToRoot/2;
532561
counts[i][childState + row_index[parentState]] += weight * relativeRootIsStateParent;
562+
globalCounts[childState + row_index[parentState]] += weight * relativeRootIsStateParent;
533563

534564
RealNumType relativeRootIsStateChild = 1 - relativeRootIsStateParent;
535565
waitingTimes[i][childState] += weight * relativeRootIsStateChild * distToRoot;

model/model_dna_rate_variation.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ class Tree;
1212
/** Class of DNA evolutionary models with rate variation */
1313
class ModelDNARateVariation : public ModelDNA {
1414
public:
15-
ModelDNARateVariation(const cmaple::ModelBase::SubModel sub_model, PositionType _genomeSize, bool _useSiteRates);
15+
ModelDNARateVariation( const cmaple::ModelBase::SubModel sub_model, PositionType _genomeSize,
16+
bool _useSiteRates, cmaple::RealNumType _wtPseudocount);
1617
virtual ~ModelDNARateVariation();
1718

1819
void estimateRates(cmaple::Tree* tree);
@@ -26,7 +27,6 @@ class ModelDNARateVariation : public ModelDNA {
2627
};
2728

2829
virtual inline const cmaple::RealNumType *const getMutationMatrixRow(StateType row, PositionType i) const override {
29-
assert(i < genomeSize);
3030
return mutationMatrices + (i * matSize) + row_index[row];
3131
};
3232

@@ -83,7 +83,7 @@ class ModelDNARateVariation : public ModelDNA {
8383
StateType parentState, StateType childState,
8484
RealNumType distToRoot, RealNumType distToObserved,
8585
RealNumType** waitingTimes, RealNumType** counts,
86-
RealNumType weight = 1.);
86+
RealNumType* globalCounts, RealNumType weight = 1.);
8787

8888
cmaple::PositionType genomeSize;
8989

@@ -97,5 +97,7 @@ class ModelDNARateVariation : public ModelDNA {
9797
bool useSiteRates = false;
9898
bool ratesEstimated = false;
9999

100+
cmaple::RealNumType waitingTimePseudoCount;
101+
100102
};
101103
}

utils/tools.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,7 @@ cmaple::Params::Params() {
623623
thresh_loglh_optimal_diff_fac = 1.0;
624624
rate_variation = false;
625625
site_specific_rates = false;
626+
wt_pseudocount = 0.1;
626627

627628
// initialize random seed based on current time
628629
struct timeval tv;
@@ -1278,6 +1279,19 @@ void cmaple::parseArg(int argc, char* argv[], Params& params) {
12781279
continue;
12791280
}
12801281

1282+
if (strcmp(argv[cnt], "--waiting-time-pseudocount") == 0) {
1283+
cnt++;
1284+
if (cnt >= argc) {
1285+
outError("Use --waiting-time-pseudocount <pseudocount>");
1286+
}
1287+
try {
1288+
params.wt_pseudocount = convert_real_number(argv[cnt]);
1289+
} catch (std::invalid_argument e) {
1290+
outError(e.what());
1291+
}
1292+
continue;
1293+
}
1294+
12811295
// return invalid option
12821296
string err = "Invalid \"";
12831297
err += argv[cnt];

utils/tools.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,11 @@ class Params {
652652
*/
653653
bool rate_variation;
654654

655+
/**
656+
* Pseudocount used for waiting times when estimating site-specific rate matrices.
657+
*/
658+
RealNumType wt_pseudocount;
659+
655660
/**
656661
* TRUE to allow an independent rate matrix for each genomic site.
657662
*/

0 commit comments

Comments
 (0)