Skip to content

Commit d7eb710

Browse files
committed
Optimise mAIC computation: for mixed data, sum the log-likelihoods (lnL) across data type groups and compute a single total mAIC; for non-reversible models, use the artificial (gappy) taxon strategy.
1 parent c54394d commit d7eb710

File tree

3 files changed

+31
-192
lines changed

3 files changed

+31
-192
lines changed

main/phyloanalysis.cpp

Lines changed: 8 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -925,30 +925,16 @@ void reportTree(ofstream &out, Params &params, PhyloTree &tree, double tree_lh,
925925
// mAIC report
926926
if (tree.isSuperTree() && params.partition_type != TOPO_UNLINKED) {
927927
// compute mAIC/mBIC/mAICc if it is a partition model
928-
int ntrees; //mix_df;
929-
double mix_lh;
928+
double mix_lh = tree.getModelFactory()->computeMarginalLh(params.remove_empty_seq);
930929

931-
mix_lh = tree.getModelFactory()->computeMarginalLh(params.remove_empty_seq);
932-
if (mix_lh < 0) {
933-
PhyloSuperTree *stree = (PhyloSuperTree*) &tree;
934-
ntrees = stree->size();
935-
//mix_df = df + ntrees - 1; // Ed Susko: The weights are fixed by the partition length, so there are no extra degrees of freedom
936-
//nsites = tree.getAlnNSite();
937-
938-
double mAIC, mAICc, mBIC;
939-
computeInformationScores(mix_lh, df, ssize, mAIC, mAICc, mBIC);
930+
double mAIC, mAICc, mBIC;
931+
computeInformationScores(mix_lh, df, ssize, mAIC, mAICc, mBIC);
940932

941-
out << endl;
942-
out << "Marginal log-likelihood of the tree: " << mix_lh << endl;
943-
out << "Marginal Akaike information criterion (mAIC) score: " << mAIC << endl;
944-
//out << "Marginal corrected Akaike information criterion (mAICc) score: " << mAICc << endl;
945-
//out << "Marginal Bayesian information criterion (mBIC) score: " << mBIC << endl;
946-
} else {
947-
// mixed data types: compute mAIC per data type group and sum
948-
double mAIC = ((PartitionModel*)tree.getModelFactory())->computeMarginalAIC(params.remove_empty_seq);
949-
out << endl;
950-
out << "Marginal Akaike information criterion (mAIC) score: " << mAIC << " (computed per data type)" << endl;
951-
}
933+
out << endl;
934+
out << "Marginal log-likelihood of the tree: " << mix_lh << endl;
935+
out << "Marginal Akaike information criterion (mAIC) score: " << mAIC << endl;
936+
//out << "Marginal corrected Akaike information criterion (mAICc) score: " << mAICc << endl;
937+
//out << "Marginal Bayesian information criterion (mBIC) score: " << mBIC << endl;
952938
}
953939

954940
if (ssize <= df && main_tree) {

model/partitionmodel.cpp

Lines changed: 19 additions & 158 deletions
Original file line numberDiff line numberDiff line change
@@ -337,25 +337,6 @@ double PartitionModel::targetFunk(double x[]) {
337337
return res;
338338
}
339339

340-
/**
341-
compute the distance from root to a named leaf via DFS
342-
@param node current node
343-
@param dad parent node (NULL for root)
344-
@param target leaf name to find
345-
@return distance from node to the target leaf, or -1.0 if not found
346-
*/
347-
static double rootToLeafDist(Node *node, Node *dad, const string &target) {
348-
if (node->isLeaf() && node->name == target)
349-
return 0.0;
350-
for (auto nei : node->neighbors) {
351-
if (nei->node != dad) {
352-
double d = rootToLeafDist(nei->node, node, target);
353-
if (d >= 0)
354-
return d + nei->length;
355-
}
356-
}
357-
return -1.0;
358-
}
359340

360341
double PartitionModel::computeMarginalLhForPartitions(vector<int> &part_indices, bool remove_empty_seq) {
361342
PhyloSuperTree *tree = (PhyloSuperTree*)site_rate->getTree();
@@ -455,7 +436,8 @@ double PartitionModel::computeMarginalLhForPartitions(vector<int> &part_indices,
455436
log_state_freq[n] = log(state_freq[n]);
456437
}
457438

458-
if (inter_seqs_id.size() > 1) {
439+
if (inter_seqs_id.size() > 1 ||
440+
(inter_seqs_id.size() == 1 && !tree2->getModel()->isReversible())) {
459441
// subset tree1_aln
460442
Alignment *sub_tree1_aln = nullptr;
461443
if (tree1_seqs.size() != inter_seqs.size() || (!remove_empty_seq && tree1_seqs.size() < ntaxa)) {
@@ -499,23 +481,22 @@ double PartitionModel::computeMarginalLhForPartitions(vector<int> &part_indices,
499481
sub_tree2->copyTree(tree2);
500482
}
501483

502-
if (inter_seqs_id.size() == 2 && tree2->getModel()->isReversible()) {
503-
// if only two seqs in the subset
504-
// add a gappy seq two the sub_aln
484+
if ((inter_seqs_id.size() == 2 && tree2->getModel()->isReversible()) ||
485+
(inter_seqs_id.size() == 1 && !tree2->getModel()->isReversible())) {
486+
// too few taxa for likelihood kernel; add a gappy taxon with all-unknown states
505487
string gappy_seq = "gappy_seq";
506488
sub_tree1_aln->addSeqName(gappy_seq);
507489
for (size_t patt = 0; patt < sub_tree1_aln->size(); ++patt) {
508490
sub_tree1_aln->at(patt).push_back(sub_tree1_aln->STATE_UNKNOWN);
509491
}
510492

511-
// add a taxa with 0 branch length to the sub_tree
493+
// add a gappy taxon with 0 branch length to the sub_tree
512494
Node *inter_node = sub_tree2->newNode();
513495
Node *gappy_taxon = sub_tree2->newNode(-1, "gappy_seq");
514-
auto it = inter_seqs.begin();
515-
string inter_seq1 = *it;
516-
Node *node1 = sub_tree2->findLeafName(inter_seq1);
496+
string first_seq = *inter_seqs.begin();
497+
Node *node1 = sub_tree2->findLeafName(first_seq);
517498
Node *node2 = node1->neighbors[0]->node;
518-
double half_branch = node1->findNeighbor(node2)->length/2.0;
499+
double half_branch = node1->neighbors[0]->length / 2.0;
519500

520501
node1->updateNeighbor(node2, inter_node, half_branch);
521502
node2->updateNeighbor(node1, inter_node, half_branch);
@@ -524,9 +505,9 @@ double PartitionModel::computeMarginalLhForPartitions(vector<int> &part_indices,
524505

525506
inter_node->addNeighbor(gappy_taxon, 0);
526507
gappy_taxon->addNeighbor(inter_node, 0);
527-
sub_tree2->branchNum+= 2;
508+
sub_tree2->branchNum += 2;
528509
sub_tree2->leafNum++;
529-
sub_tree2->nodeNum = 2 * sub_tree2->leafNum -2;
510+
sub_tree2->nodeNum = 2 * sub_tree2->leafNum - 2;
530511
}
531512

532513
// link sub_tree2 and sub_tree1_aln
@@ -599,9 +580,9 @@ double PartitionModel::computeMarginalLhForPartitions(vector<int> &part_indices,
599580
sub_tree2->aln = nullptr;
600581
delete sub_tree2;
601582
delete[] ptn_lh_array;
602-
} else if (tree2->getModel()->isReversible() || inter_seqs.size() == 0) {
603-
// case when the intersection of taxon sets is 1 and reversible model
604-
// case when the intersection of taxon sets is 0
583+
} else {
584+
// case when the intersection of taxon sets is 0, or is 1 with a reversible model
585+
// all taxa use state frequencies only
605586
for (int l = 0; l < tree1_nsite; l++) {
606587
double site_lh = 0.0;
607588
Pattern p = tree1_aln->at(tree1_aln->getPatternID(l));
@@ -635,89 +616,6 @@ double PartitionModel::computeMarginalLhForPartitions(vector<int> &part_indices,
635616
}
636617
lh_array[tree1_nsite * k + l] = site_lh;
637618
}
638-
} else {
639-
// intersection has only 1 taxon and non-reversible model
640-
// compute site likelihood from root to the tip:
641-
// L(x) = sum_r pi(r) * P(r -> x | t_root_to_tip)
642-
string inter_seq_name = *inter_seqs.begin();
643-
double dist = rootToLeafDist(tree2->root, NULL, inter_seq_name);
644-
ASSERT(dist >= 0);
645-
646-
// compute transition matrix for root-to-tip distance
647-
double *trans_matrix = new double[n_states * n_states];
648-
tree2->getModel()->computeTransMatrix(dist, trans_matrix);
649-
650-
// precompute log of root-to-tip probabilities: log(sum_r pi(r) * P(r -> x | t))
651-
vector<double> log_root_to_tip_prob(n_states);
652-
for (int x = 0; x < n_states; x++) {
653-
double prob = 0.0;
654-
for (int r = 0; r < n_states; r++) {
655-
prob += state_freq[r] * trans_matrix[r * n_states + x];
656-
}
657-
log_root_to_tip_prob[x] = log(prob);
658-
}
659-
660-
for (int l = 0; l < tree1_nsite; l++) {
661-
double site_lh = 0.0;
662-
Pattern p = tree1_aln->at(tree1_aln->getPatternID(l));
663-
664-
for (string seq_name : tree1_seqs) {
665-
int missing_id = tree1_aln->getSeqID(seq_name);
666-
int char_id = p[missing_id];
667-
if (inter_seqs.find(seq_name) != inter_seqs.end()) {
668-
// intersecting taxon: use root-to-tip probability
669-
if (char_id < n_states) {
670-
site_lh += log_root_to_tip_prob[char_id];
671-
} else {
672-
// ambiguous state: sum over possible states
673-
if (seqtype == SEQ_DNA) {
674-
int cstate = char_id - n_states + 1;
675-
double amb_prob = 0;
676-
for (int m = 0; m < n_states; m++) {
677-
if ((cstate) & (1 << m)) {
678-
amb_prob += exp(log_root_to_tip_prob[m]);
679-
}
680-
}
681-
site_lh += log(amb_prob);
682-
} else if (seqtype == SEQ_PROTEIN) {
683-
if (char_id < 23) {
684-
int cstate = char_id - n_states;
685-
double amb_prob = 0;
686-
amb_prob += exp(log_root_to_tip_prob[ambi_aa[cstate*2]]);
687-
amb_prob += exp(log_root_to_tip_prob[ambi_aa[cstate*2+1]]);
688-
site_lh += log(amb_prob);
689-
}
690-
}
691-
}
692-
} else {
693-
// missing taxon: use state frequencies (same as reversible case)
694-
if (char_id < n_states) {
695-
site_lh += log_state_freq[char_id];
696-
} else {
697-
if (seqtype == SEQ_DNA) {
698-
int cstate = char_id - n_states + 1;
699-
double amb_freq = 0;
700-
for (int m = 0; m < n_states; m++) {
701-
if ((cstate) & (1 << m)) {
702-
amb_freq += state_freq[m];
703-
}
704-
}
705-
site_lh += log(amb_freq);
706-
} else if (seqtype == SEQ_PROTEIN) {
707-
if (char_id < 23) {
708-
int cstate = char_id - n_states;
709-
double amb_freq = 0;
710-
amb_freq += state_freq[ambi_aa[cstate*2]];
711-
amb_freq += state_freq[ambi_aa[cstate*2+1]];
712-
site_lh += log(amb_freq);
713-
}
714-
}
715-
}
716-
}
717-
}
718-
lh_array[tree1_nsite * k + l] = site_lh;
719-
}
720-
delete[] trans_matrix;
721619
}
722620
delete[] state_freq;
723621
}
@@ -754,59 +652,22 @@ double PartitionModel::computeMarginalLh(bool remove_empty_seq) {
754652
PhyloSuperTree *tree = (PhyloSuperTree*)site_rate->getTree();
755653
int ntrees = tree->size();
756654

757-
// all partition sequence type should be same, either DNA or protein or other
758-
SeqType seqtype = tree->at(0)->aln->seq_type;
759-
for (int j = 1; j < ntrees; j++) {
760-
if (tree->at(j)->aln->seq_type != seqtype) {
761-
return 1.0;
762-
}
763-
}
764-
765-
// all partitions have the same type, use them all
766-
vector<int> all_indices(ntrees);
767-
for (int j = 0; j < ntrees; j++) {
768-
all_indices[j] = j;
769-
}
770-
return computeMarginalLhForPartitions(all_indices, remove_empty_seq);
771-
}
772-
773-
double PartitionModel::computeMarginalAIC(bool remove_empty_seq) {
774-
PhyloSuperTree *tree = (PhyloSuperTree*)site_rate->getTree();
775-
int ntrees = tree->size();
776-
777655
// group partition indices by sequence type
778656
map<SeqType, vector<int>> seqtype_groups;
779657
for (int j = 0; j < ntrees; j++) {
780658
SeqType st = tree->at(j)->aln->seq_type;
781659
seqtype_groups[st].push_back(j);
782660
}
783661

784-
double total_maic = 0.0;
785-
662+
// compute marginal log-likelihood per data type group and sum
663+
// weights within each group are partition length / total length of that group
664+
double total_marginal_lh = 0.0;
786665
for (auto &group_pair : seqtype_groups) {
787666
vector<int> &group_indices = group_pair.second;
788-
789-
// compute marginal log-likelihood for this data type group
790-
double group_marginal_lh = computeMarginalLhForPartitions(group_indices, remove_empty_seq);
791-
792-
// compute df for this group: sum of per-partition free parameters
793-
int group_df = 0;
794-
for (int idx : group_indices) {
795-
group_df += tree->at(idx)->getModelFactory()->getNParameters(BRLEN_OPTIMIZE);
796-
}
797-
798-
// compute ssize for this group: sum of per-partition site counts
799-
/*int group_ssize = 0;
800-
for (int idx : group_indices) {
801-
group_ssize += tree->at(idx)->getAlnNSite();
802-
}*/
803-
804-
// mAIC = -2 * marginal_lh + 2 * df
805-
double group_maic = -2.0 * group_marginal_lh + 2.0 * group_df;
806-
total_maic += group_maic;
667+
total_marginal_lh += computeMarginalLhForPartitions(group_indices, remove_empty_seq);
807668
}
808669

809-
return total_maic;
670+
return total_marginal_lh;
810671
}
811672

812673
void PartitionModel::setVariables(double *variables) {

model/partitionmodel.h

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,11 @@ class PartitionModel : public ModelFactory
127127

128128
/**
129129
compute the marginal log-likelihood for mAIC, mAICc, mBIC calculation.
130-
Only works when all partitions have the same sequence type.
130+
Groups partitions by sequence type and computes marginal log-likelihood
131+
per group (weights = partition length / total length within that group),
132+
then sums across all data type groups.
131133
@param remove_empty_seq whether to remove empty sequences during partition model estimation
132-
@return marginal log-likelihood, or 1.0 if partitions have mixed data types
134+
@return marginal log-likelihood summed across data type groups
133135
*/
134136
virtual double computeMarginalLh(bool remove_empty_seq);
135137

@@ -142,16 +144,6 @@ class PartitionModel : public ModelFactory
142144
*/
143145
double computeMarginalLhForPartitions(vector<int> &part_indices, bool remove_empty_seq);
144146

145-
/**
146-
compute the sum of mAIC scores across data type groups.
147-
Groups partitions by sequence type, computes marginal log-likelihood
148-
and mAIC for each group independently, and returns the sum.
149-
Works for both single and mixed data type partition models.
150-
@param remove_empty_seq whether remove empty sequences when partition model estimation
151-
@return sum of mAIC scores across all data type groups
152-
*/
153-
double computeMarginalAIC(bool remove_empty_seq);
154-
155147
/**
156148
rescale the state frequencies
157149
@param sum_one TRUE to make frequencies sum to 1, FALSE to make last entry equal to 1

0 commit comments

Comments
 (0)