File: AbbvGapsHmmExpectationEvaluator.java

package info (click to toggle)
libsecondstring-java 0.1~dfsg-2
  • links: PTS, VCS
  • area: main
  • in suites: bookworm, forky, sid, trixie
  • size: 764 kB
  • sloc: java: 9,592; xml: 114; makefile: 6
file content (75 lines) | stat: -rw-r--r-- 2,117 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
package com.wcohen.ss.abbvGapsHmm;

import java.util.List;

import com.wcohen.ss.abbvGapsHmm.AbbvGapsHMM.*;

/**
 * Accumulates expected transition and emission counts for a single acronym.
 *
 * <p>For every edge reported through {@link #updateLegalOutgoingEdges}, the edge weight
 * {@code alpha(prev) * emission * transition * beta(curr)} is divided by the total
 * likelihood. That quotient is added to the counter of the edge's emission and to the
 * counter of its transition. The counter lists are supplied by the caller and updated in place.
 *
 * <p>NOTE(review): the edge enumeration itself is inherited from
 * {@link AbbvGapsHmmForwardEvaluator}. It presumably calls the hook below once per legal
 * edge; confirm in the superclass.
 *
 * @author Dana Movshovitz-Attias
 */
public class AbbvGapsHmmExpectationEvaluator
		extends
			AbbvGapsHmmForwardEvaluator {
	
	/** Expected transition counts, indexed by {@code Transitions.ordinal()}; updated in place. */
	protected List<Double> _transitionCounters;
	/** Expected emission counts, indexed by {@code Emissions.ordinal()}; updated in place. */
	protected List<Double> _emissionCounters;
	
	/** Forward matrix, read as {@code at(sIndex, lIndex, state.ordinal())}. */
	protected Matrix3D _alpha;
	/** Backward matrix, read with the same index layout as {@link #_alpha}. */
	protected Matrix3D _beta;

	/**
	 * Creates an evaluator for the given model.
	 *
	 * @param abbvGapsHMM the model, passed unchanged to the superclass constructor
	 */
	public AbbvGapsHmmExpectationEvaluator(AbbvGapsHMM abbvGapsHMM) {
		super(abbvGapsHMM);
	}
	
	/**
	 * Adds the expected counts contributed by {@code acronym} to the supplied counter lists.
	 *
	 * <p>The given lists and matrices are stored by reference, not copied, and the
	 * inherited {@code evaluate(acronym)} is then run.
	 *
	 * @param acronym            the example to evaluate
	 * @param transitionCounters transition counters to add into, indexed by transition ordinal
	 * @param emissionCounters   emission counters to add into, indexed by emission ordinal
	 * @param transitionParams   current transition parameters, indexed by transition ordinal
	 * @param emissionParams     current emission parameters, indexed by emission ordinal
	 * @param alpha              forward matrix for this acronym
	 * @param beta               backward matrix for this acronym
	 */
	public void expectationEvaluate(
			Acronym acronym, 
			List<Double> transitionCounters, List<Double> emissionCounters,
			List<Double> transitionParams, List<Double> emissionParams,
			Matrix3D alpha, Matrix3D beta){
		_transitionCounters = transitionCounters;
		_emissionCounters = emissionCounters;
		_transitionParams = transitionParams;
		_emissionParams = emissionParams;
		_alpha = alpha;
		_beta = beta;
		
		super.evaluate(acronym);
	}
	
	/**
	 * Returns the transition counter list from the most recent {@link #expectationEvaluate} call.
	 *
	 * @return the caller-supplied list itself (not a copy); {@code null} before the first call
	 */
	public List<Double> getTransitionCounters(){
		return _transitionCounters;
	}
	
	/**
	 * Returns the emission counter list from the most recent {@link #expectationEvaluate} call.
	 *
	 * @return the caller-supplied list itself (not a copy); {@code null} before the first call
	 */
	public List<Double> getEmissionCounters(){
		return _emissionCounters;
	}
	
	/**
	 * Adds the normalized probability of one edge to that edge's emission and transition counters.
	 *
	 * <p>The normalizer is the entry of {@code _alpha} at the last index of every dimension.
	 * NOTE(review): this assumes the final cell holds the total likelihood of the acronym,
	 * i.e. that the end state has the highest ordinal. Confirm against the model definition.
	 *
	 * @param currS      first index of the edge's destination cell
	 * @param currL      second index of the edge's destination cell
	 * @param currState  destination state
	 * @param prevS      first index of the edge's source cell
	 * @param prevL      second index of the edge's source cell
	 * @param prevState  source state
	 * @param transition transition taken along the edge
	 * @param emission   emission produced along the edge
	 */
	protected void updateLegalOutgoingEdges(
			int currS, int currL, States currState,
			int prevS, int prevL, States prevState,
			Transitions transition, Emissions emission
	){
		double totalProb = _alpha.at(_alpha.dimension1()-1, _alpha.dimension2()-1, _alpha.dimension3()-1);
		// If the total likelihood is 0 (impossible example, or underflow on long inputs), every
		// numerator is 0 as well and the quotient would be 0/0 = NaN. One NaN added to a counter
		// poisons it permanently, so such an example contributes nothing instead.
		// Written as !(x > 0) so that a NaN total is skipped too.
		if (!(totalProb > 0.0)) {
			return;
		}
		double edgeProb = _alpha.at(prevS, prevL, prevState.ordinal())
				* _emissionParams.get(emission.ordinal())
				* _transitionParams.get(transition.ordinal())
				* _beta.at(currS, currL, currState.ordinal());
		double currProb = edgeProb / totalProb;
		increaseCounter(emission, currProb);
		increaseCounter(transition, currProb);
	}

	/**
	 * Adds {@code addition} to the emission counter at {@code emission.ordinal()}.
	 *
	 * @param emission the emission whose counter is incremented
	 * @param addition the amount to add
	 */
	protected void increaseCounter(Emissions emission, double addition){
		int index = emission.ordinal();
		_emissionCounters.set(index, _emissionCounters.get(index) + addition);
	}
	
	/**
	 * Adds {@code addition} to the transition counter at {@code transition.ordinal()}.
	 *
	 * @param transition the transition whose counter is incremented
	 * @param addition   the amount to add
	 */
	protected void increaseCounter(Transitions transition, double addition){
		int index = transition.ordinal();
		_transitionCounters.set(index, _transitionCounters.get(index) + addition);
	}

}