File: bblEM2USSRV.cpp

package info (click to toggle)
fastml 3.11-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 5,772 kB
  • sloc: cpp: 48,522; perl: 3,588; ansic: 819; makefile: 386; python: 83; sh: 55
file content (181 lines) | stat: -rwxr-xr-x 5,898 bytes parent folder | download | duplicates (5)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
// 	$Id: bblEM2USSRV.cpp 1944 2007-04-18 12:41:14Z osnatz $	
#include "bblEM2USSRV.h"

// Sets up the USSRV BBL-EM optimizer and immediately runs branch-length
// optimization via compute_bblEM. The likelihood it returns is stored in
// _treeLikelihood.
//   et                  - tree whose branch lengths are optimized (non-const;
//                         presumably _et is a reference member so the caller's
//                         tree is updated in place - TODO confirm in header)
//   sc / baseSc         - sequence data used by the SSRV and base sub-models
//   weights             - optional per-position weights (may be null)
//   maxIterations       - upper bound on EM iterations
//   epsilon             - minimal likelihood gain required to keep iterating
//   tollForPairwiseDist - tolerance forwarded to the per-branch distance search
bblEM2USSRV::bblEM2USSRV(tree& et,
				const sequenceContainer& sc,
				const sequenceContainer& baseSc,
				const ussrvModel& model,
				const Vdouble * weights,
				int maxIterations,
				MDOUBLE epsilon,
				MDOUBLE tollForPairwiseDist)
	: _et(et),
	  _sc(sc),
	  _baseSc(baseSc),
	  _model(model),
	  _weights(weights)
{
	LOG(5,<<"******BBL EM USSRV*********"<<endl<<endl);
	const MDOUBLE optimizedL = compute_bblEM(maxIterations,epsilon,tollForPairwiseDist);
	_treeLikelihood = optimizedL;
}

// Main EM loop for branch-length optimization.
// Each iteration recomputes the up tables and the tree likelihood for the
// current branch lengths (filling _posLike on the way). It stops as soon as the
// likelihood fails to improve by at least `epsilon` over the previous
// iteration; otherwise it performs one EM step (bblEM_it) and repeats.
// If the likelihood ever decreases, the tree from before the last EM step is
// restored and the previous likelihood is returned.
// NOTE(review): after such a revert, _posLike and the up tables still describe
// the rejected tree, not the restored one.
// @@@@ Need to check if we can make it more efficient
MDOUBLE bblEM2USSRV::compute_bblEM(
			int maxIterations,
			MDOUBLE epsilon,
			MDOUBLE tollForPairwiseDist){
	allocatePlace();
	MDOUBLE prevL = VERYSMALL;
	MDOUBLE newL = VERYSMALL;
	tree prevTree = _et;
	for (int iter = 0; iter < maxIterations; ++iter) {
		computeUp();
		// Scores the current tree; also fills _posLike.
		newL = likelihoodComputation2USSRV::getTreeLikelihoodFromUp2(_et,
				_sc,_baseSc,_model,_cupBase,_cupSSRV,_posLike,_weights);
		LOGDO(5,printTime(myLog::LogFile()));
		LOG(5,<<"iteration no "<<iter << " in BBL "<<endl);
		LOG(5,<<"old best  L= "<<prevL<<endl);
		LOG(5,<<"current best  L= "<<newL<<endl);

		if (newL < prevL + epsilon) {
			// Gain below epsilon (or a loss): stop iterating.
			const bool wentDown = (newL < prevL);
			if (wentDown)
				cout<<"******** PROBLEMS IN BBL USSRV*********"<<endl;
			LOG(5,<<"old best  L= "<<prevL<<endl);
			LOG(5,<<"current best  L= "<<newL<<endl);
			if (!wentDown)
				return newL; // keep the current tree as it is
			_et = prevTree; // undo the last EM step
			return prevL;
		}
		prevTree = _et;
		bblEM_it(tollForPairwiseDist);
		prevL = newL;
	}
	// maxIterations exhausted: the tree produced by the last EM step has not
	// been scored yet, so score it before deciding whether to keep it.
	computeUp();
	newL = likelihoodComputation2USSRV::getTreeLikelihoodFromUp2(_et,
			_sc,_baseSc,_model,_cupBase,_cupSSRV,_posLike,_weights);
	if (!(newL < prevL))
		return newL;
	_et = prevTree;
	return prevL;
}


// Sizes all work tables used by the EM iterations for the current tree and
// sequence data:
//  - one count table per tree node for the base model (per rate category) and
//    one for the SSRV model;
//  - the per-position up tables and the down tables of both sub-models.
void bblEM2USSRV::allocatePlace() {
	// Query the node count once; it is also the loop bound below, which avoids
	// comparing an int index against the unsigned vector size (-Wsign-compare).
	const int nodesNum = _et.getNodesNum();
	_computeCountsBaseV.resize(nodesNum); //initiateTablesOfCounts
	_computeCountsSsrvV.resize(nodesNum); //initiateTablesOfCounts

	// Loop-invariant table dimensions.
	const int baseAlphabetSize = _model.getBaseModel().alphabetSize();
	const int ssrvAlphabetSize = _model.getSSRVmodel().alphabetSize();
	const int nCategories = _model.noOfCategor();

	for (int node = 0; node < nodesNum; ++node) {
		_computeCountsBaseV[node].countTableComponentAllocatePlace(baseAlphabetSize,nCategories);
		_computeCountsSsrvV[node].countTableComponentAllocatePlace(ssrvAlphabetSize);
	}
	_cupBase.allocatePlace(_baseSc.seqLen(),nCategories, nodesNum, _baseSc.alphabetSize());
	_cupSSRV.allocatePlace(_sc.seqLen(), nodesNum, _sc.alphabetSize());

	_cdownBase.allocatePlace(nCategories, nodesNum, _baseSc.alphabetSize());
	_cdownSSRV.allocatePlace(nodesNum, _sc.alphabetSize());
}

// Performs one EM step:
//  1. clears the per-node count tables of both sub-models;
//  2. for every alignment position, runs the down pass and adds that
//     position's counts to the tables;
//  3. re-estimates every branch length from the accumulated counts.
void bblEM2USSRV::bblEM_it(MDOUBLE tollForPairwiseDist){
	// size_t index matches the unsigned size() and avoids a signed/unsigned
	// comparison. Both count vectors are sized identically in allocatePlace().
	for (size_t node = 0; node < _computeCountsBaseV.size(); ++node) {
		_computeCountsBaseV[node].zero();
		_computeCountsSsrvV[node].zero();
	}
	for (int pos = 0; pos < _sc.seqLen(); ++pos) {
		computeDown(pos);
		addCounts(pos); // computes the counts and adds them to the tables
	}
	optimizeBranches(tollForPairwiseDist);
}

// Sets the length of every non-root branch from the count tables accumulated
// for that node (base-model and SSRV-model tables). The node's current
// distance to its father is passed to the estimator as well (presumably as
// the starting point of the search - confirm in
// fromCountTableComponentToDistance2USSRV).
// @@@@ need to print the tree
void bblEM2USSRV::optimizeBranches(MDOUBLE tollForPairwiseDist){
	treeIterDownTopConst it(_et);
	for (tree::nodeP node = it.first(); node != it.end(); node = it.next()) {
		if (it->isRoot())
			continue; // the root has no branch to its father
		const int nodeId = node->id();
		fromCountTableComponentToDistance2USSRV distEstimator(
			_computeCountsBaseV[nodeId],
			_computeCountsSsrvV[nodeId],
			_model,
			tollForPairwiseDist,
			node->dis2father());
		distEstimator.computeDistance();
		node->setDisToFather(distEstimator.getDistance());
	}
}

void bblEM2USSRV::computeUp(){
	_pijBase.fillPij(_et,_model.getBaseModel(),0); // 0 is becaues we compute Pij(t) and not its derivations...
	_pijSSRV.fillPij(_et,_model.getSSRVmodel(),0);
	
	computeUpAlg cupAlg;
	for (int pos=0; pos < _sc.seqLen(); ++pos) {
		// compute up for the base model
		for (int categor = 0; categor < _model.noOfCategor(); ++categor) {
			cupAlg.fillComputeUp(_et,_baseSc,pos,_pijBase[categor],_cupBase[pos][categor]);
		}
		// compute up for the ssrv model
		cupAlg.fillComputeUp(_et,_sc,pos,_pijSSRV,_cupSSRV[pos]);
	}
}

void bblEM2USSRV::computeDown(int pos){
	computeDownAlg cdownAlg;
	// compute down for the base model
	for (int categor = 0; categor < _model.noOfCategor(); ++categor) {
		cdownAlg.fillComputeDown(_et,_baseSc,pos,_pijBase[categor],_cdownBase[categor],_cupBase[pos][categor]);		
	}
	// compute down for the ssrv model
	cdownAlg.fillComputeDown(_et,_sc,pos,_pijSSRV,_cdownSSRV,_cupSSRV[pos]);		
}

void bblEM2USSRV::addCounts(int pos){
						
	MDOUBLE weig = (_weights ? (*_weights)[pos] : 1.0);
	if (weig == 0) return;
	treeIterDownTopConst tIt(_et);
	for (tree::nodeP mynode = tIt.first(); mynode != tIt.end(); mynode = tIt.next()) {
		if (!tIt->isRoot()) {
			addCounts(pos,mynode,_posLike[pos],weig);
		}
	}
}

void bblEM2USSRV::addCounts(int pos, tree::nodeP mynode, doubleRep posProb, MDOUBLE weig){

	computeCounts cc;
	int categor;
	// base Model
	for (categor =0; categor< _model.noOfCategor(); ++categor) {
			cc.computeCountsNodeFatherNodeSonHomPos(_baseSc, 
										_pijBase[categor],
										_model.getBaseModel(),
										_cupBase[pos][categor],
										_cdownBase[categor],
										weig,
										posProb,
										mynode,
										_computeCountsBaseV[mynode->id()][categor],
										_model.getCategorProb(categor)*(1-_model.getF()));
	
	}
	// SSRV model
	cc.computeCountsNodeFatherNodeSonHomPos(_sc, 
										_pijSSRV,
										_model.getSSRVmodel(),
										_cupSSRV[pos],
										_cdownSSRV,
										weig,
										posProb,
										mynode,
										_computeCountsSsrvV[mynode->id()],
										_model.getF());
}