/**
 * Title:        ProAlign<p>
 * Description:  <p>
 * Copyright:    Copyright (c) Ari Loytynoja<p>
 * License:      GNU GENERAL PUBLIC LICENSE<p>
 * @see          http://www.gnu.org/copyleft/gpl.html
 * Company:      ULB<p>
 * @author Ari Loytynoja
 * @version 1.0
 */
package proalign;

/**
 * Viterbi, forward and backward recursions of the three-state alignment
 * model (M = match, X and Y = the two gap states), evaluated cell by cell
 * over the banded matrices held by an {@link AlignmentLoop}.
 * <p>
 * All scores are in log space. For traceback, the Viterbi steps also record
 * the relative probability of each predecessor state in
 * {@code al.pathM}, {@code al.pathX}, {@code al.pathY} and {@code al.pathEnd}.
 * <p>
 * Band coordinates: the diagonal (M) predecessor of cell (i, j) is
 * (i-1, j), the X predecessor is (i-1, j+1) and the Y predecessor is
 * (i, j-1).
 * <p>
 * NOTE: the methods are order-dependent and share mutable state.
 * {@code getForwardM/X/Y} reuse the emission constants cached by the
 * matching {@code getViterbi*} call on the same cell. {@code getBackwardX/Y}
 * reuse the constants cached by {@code getBackwardM}. An instance must not
 * be shared between threads.
 */
class Viterbi {

    /** Alignment state: band matrices, emission prices and traceback buffers. */
    AlignmentLoop al;

    // Copied from the substitution model; not read by this class itself.
    double[] charFreqs;
    String alphabet;

    // Transition parameters and their log-space forms.
    double delta;            // used for M -> X and M -> Y transitions
    double epsilon;          // used for X -> X and Y -> Y transitions
    double logDelta;
    double logEpsilon;
    double logMinus2Delta;   // log(1 - 2*delta): M -> M
    double logMinusEpsilon;  // log(1 - epsilon): X -> M and Y -> M

    // Log of the summed emission prices of the cell(s) currently being
    // evaluated; cached so the forward/backward step can reuse them.
    double constantM;
    double constantX;
    double constantY;

    TransformLog tl;   // log-space arithmetic helper (sumLogs)

    int aSize;         // alphabet size, taken from the width of seq1

    /**
     * Captures the transition parameters of {@code al}'s substitution model
     * and precomputes their logarithms.
     *
     * @param al alignment state the recursions read from and write to
     */
    Viterbi(AlignmentLoop al) {
        this.al = al;

        this.charFreqs = al.pa.sm.charFreqs;
        this.alphabet = al.pa.sm.alphabet;

        delta = al.pa.sm.delta;
        epsilon = al.pa.sm.epsilon;
        logDelta = Math.log(al.pa.sm.delta);
        logEpsilon = Math.log(al.pa.sm.epsilon);
        logMinus2Delta = Math.log(1 - 2 * al.pa.sm.delta);
        logMinusEpsilon = Math.log(1 - al.pa.sm.epsilon);

        aSize = al.seq1[0].length; // alphabet size

        tl = new TransformLog();
    }

    /**
     * Viterbi score of the M state at (i, j). Also caches {@code constantM}
     * for {@link #getForwardM} and stores the predecessor probabilities:
     * {@code pathM[i][j][0]} = P(from M), {@code pathM[i][j][1]} = P(from X),
     * and 1 - both = P(from Y).
     */
    double getViterbiM(int i, int j) {
        constantM = Math.log(sumOverChars(i, j));

        double fromM = logMinus2Delta + al.vitM[i - 1][j];
        double fromX = logMinusEpsilon + al.vitX[i - 1][j];
        double fromY = logMinusEpsilon + al.vitY[i - 1][j];
        double total = tl.sumLogs(fromM, tl.sumLogs(fromX, fromY));

        al.pathM[i][j][0] = Math.exp(fromM - total);
        al.pathM[i][j][1] = Math.exp(fromX - total);

        return constantM + maxOfThree(fromM, fromX, fromY);
    }

    /**
     * Viterbi score of the X state at (i, j). This is negative infinity on
     * the lower band border, where an X gap is impossible. Otherwise it
     * caches {@code constantX} and stores {@code pathX[i][j]} = P(from M),
     * with 1 - pathX[i][j] = P(from X).
     */
    double getViterbiX(int i, int j) {
        if (j == al.BWIDTH - 1) { // lower border of band, an X-gap impossible
            return Double.NEGATIVE_INFINITY;
        }

        constantX = Math.log(sumOverCharsX(i, j));

        double fromM = logDelta + al.vitM[i - 1][j + 1];
        double fromX = logEpsilon + al.vitX[i - 1][j + 1];
        double total = tl.sumLogs(fromX, fromM);

        al.pathX[i][j] = Math.exp(fromM - total);

        return constantX + maxOfTwoX(fromM, fromX);
    }

    /**
     * Viterbi score of the Y state at (i, j). This is negative infinity on
     * the upper band border, where a Y gap is impossible. Otherwise it
     * caches {@code constantY} and stores {@code pathY[i][j]} = P(from M),
     * with 1 - pathY[i][j] = P(from Y).
     */
    double getViterbiY(int i, int j) {
        if (j == 0) { // upper border of band, a Y-gap impossible
            return Double.NEGATIVE_INFINITY;
        }

        constantY = Math.log(sumOverCharsY(i, j));

        double fromM = logDelta + al.vitM[i][j - 1];
        double fromY = logEpsilon + al.vitY[i][j - 1];
        double total = tl.sumLogs(fromY, fromM);

        al.pathY[i][j] = Math.exp(fromM - total);

        return constantY + maxOfTwoY(fromM, fromY);
    }

    /**
     * Closes the recursions at column {@code endPoint} of the last row.
     * {@code endPoint} is not necessarily {@code MIDDLE}.
     * <p>
     * Sets {@code al.fwdEnd} to the forward end score and {@code al.pathEnd}
     * to the ending-state probabilities: pathEnd[0] = P(end in M),
     * pathEnd[1] = P(end in X), and 1 - both = P(end in Y). When
     * {@code al.an.hasTrailers} is set, only M and X (end0 > end1) or only M
     * and Y (otherwise) are allowed as ending states.
     *
     * @return the Viterbi end score
     */
    double getViterbiEnd(int endPoint) {
        double pm = al.vitM[al.vitM.length - 1][endPoint];
        double px = al.vitX[al.vitX.length - 1][endPoint];
        double py = al.vitY[al.vitY.length - 1][endPoint];

        double fm = al.fwdM[al.fwdM.length - 1][endPoint];
        double fx = al.fwdX[al.fwdX.length - 1][endPoint];
        double fy = al.fwdY[al.fwdY.length - 1][endPoint];

        if (!al.an.hasTrailers) {
            double total = tl.sumLogs(pm, tl.sumLogs(px, py));
            al.pathEnd[0] = Math.exp(pm - total);
            al.pathEnd[1] = Math.exp(px - total);
            al.fwdEnd = tl.sumLogs(fm, tl.sumLogs(fx, fy));
        } else if (al.an.end0 > al.an.end1) {
            // Only M or X may end the path; P(end in Y) = 1 - p0 - p1 = 0.
            double total = tl.sumLogs(pm, px);
            al.pathEnd[0] = Math.exp(pm - total);
            al.pathEnd[1] = Math.exp(px - total);
            al.fwdEnd = tl.sumLogs(fm, fx);
        } else {
            // Only M or Y may end the path. pathEnd holds probabilities, so an
            // impossible X ending is 0 (previously -infinity, which made
            // 1 - pathEnd[0] - pathEnd[1] evaluate to +infinity).
            double total = tl.sumLogs(pm, py);
            al.pathEnd[0] = Math.exp(pm - total);
            al.pathEnd[1] = 0d;
            al.fwdEnd = tl.sumLogs(fm, fy);
        }

        // NOTE(review): the maximum is taken over all three states even when
        // the trailer branches above exclude X or Y as an ending state;
        // confirm this is intended.
        return maxOfThree(pm, px, py);
    }

    /**
     * Forward score of the M state at (i, j). Reuses {@code constantM}
     * cached by {@link #getViterbiM} for the same cell.
     */
    double getForwardM(int i, int j) {
        return constantM + tl.sumLogs(logMinus2Delta + al.fwdM[i - 1][j],
                logMinusEpsilon + tl.sumLogs(al.fwdX[i - 1][j], al.fwdY[i - 1][j]));
    }

    /**
     * Forward score of the X state at (i, j). This is negative infinity on
     * the lower band border. It normally reuses {@code constantX} cached by
     * {@link #getViterbiX}, but computes it itself when
     * {@code j - MIDDLE + i == 1}. The original author notes that only X gaps
     * are possible there and the constant is not yet set.
     */
    double getForwardX(int i, int j) {
        if (j == al.BWIDTH - 1) { // lower border of band, an X-gap impossible
            return Double.NEGATIVE_INFINITY;
        }

        if (j - al.MIDDLE + i == 1) { // only X-gaps possible, constant not set
            constantX = Math.log(sumOverCharsX(i, j));
        }

        return constantX + tl.sumLogs(logDelta + al.fwdM[i - 1][j + 1],
                logEpsilon + al.fwdX[i - 1][j + 1]);
    }

    /**
     * Forward score of the Y state at (i, j). This is negative infinity on
     * the upper band border. It normally reuses {@code constantY} cached by
     * {@link #getViterbiY}, but computes it itself when {@code i == 1}. The
     * original author notes that only Y gaps are possible there and the
     * constant is not yet set.
     */
    double getForwardY(int i, int j) {
        if (j == 0) { // upper border of band, a Y-gap impossible
            return Double.NEGATIVE_INFINITY;
        }

        if (i == 1) { // only Y-gaps possible, constant not set
            constantY = Math.log(sumOverCharsY(i, j));
        }

        return constantY + tl.sumLogs(logDelta + al.fwdM[i][j - 1],
                logEpsilon + al.fwdY[i][j - 1]);
    }

    /**
     * Backward score of the M state at (i, j).
     * <p>
     * It first computes and caches the emission constants of the successor
     * cells:
     * <ul>
     * <li>{@code constantM} for the M successor at (i+1, j);</li>
     * <li>{@code constantX} for the X successor at (i+1, j-1);</li>
     * <li>{@code constantY} for the Y successor at (i, j+1).</li>
     * </ul>
     * A constant is negative infinity where that successor cannot exist.
     * {@link #getBackwardX} and {@link #getBackwardY} reuse these constants,
     * so this method must be called first for each cell.
     */
    double getBackwardM(int i, int j) {
        if (i == al.vitM.length - 1) {
            // last row: only a Y successor
            constantM = Double.NEGATIVE_INFINITY;
            constantX = Double.NEGATIVE_INFINITY;
            constantY = Math.log(sumOverCharsY(i, j + 1));
        } else if (j - al.MIDDLE + i == al.seq2.length + 1) {
            // lower border of the real matrix: only an X successor
            constantM = Double.NEGATIVE_INFINITY;
            constantX = Math.log(sumOverCharsX(i + 1, j - 1));
            constantY = Double.NEGATIVE_INFINITY;
        } else if (j == 0) {
            // upper band border: no X successor
            constantM = Math.log(sumOverChars(i + 1, j));
            constantX = Double.NEGATIVE_INFINITY;
            constantY = Math.log(sumOverCharsY(i, j + 1));
        } else if (j == al.BWIDTH - 1) {
            // lower band border: no Y successor
            constantM = Math.log(sumOverChars(i + 1, j));
            constantX = Math.log(sumOverCharsX(i + 1, j - 1));
            constantY = Double.NEGATIVE_INFINITY;
        } else {
            constantM = Math.log(sumOverChars(i + 1, j));
            constantX = Math.log(sumOverCharsX(i + 1, j - 1));
            constantY = Math.log(sumOverCharsY(i, j + 1));
        }

        double toM = logMinus2Delta + constantM + al.bwdM[i + 1][j];

        if (j == 0) { // upper border of the band; an X-gap impossible
            return tl.sumLogs(toM, logDelta + constantY + al.bwdY[i][j + 1]);
        }

        if (j == al.BWIDTH - 1) { // lower border of the band; a Y-gap impossible
            return tl.sumLogs(toM, logDelta + constantX + al.bwdX[i + 1][j - 1]);
        }

        return tl.sumLogs(toM,
                logDelta + tl.sumLogs(constantX + al.bwdX[i + 1][j - 1],
                        constantY + al.bwdY[i][j + 1]));
    }

    /**
     * Backward score of the X state at (i, j). Reuses the constants cached
     * by {@link #getBackwardM} for the same cell.
     */
    double getBackwardX(int i, int j) {
        double toM = logMinusEpsilon + constantM + al.bwdM[i + 1][j];

        if (j == 0) { // upper border of the band; an X-gap impossible
            return toM;
        }

        return tl.sumLogs(toM, logEpsilon + constantX + al.bwdX[i + 1][j - 1]);
    }

    /**
     * Backward score of the Y state at (i, j). Reuses the constants cached
     * by {@link #getBackwardM} for the same cell.
     */
    double getBackwardY(int i, int j) {
        double toM = logMinusEpsilon + constantM + al.bwdM[i + 1][j];

        if (j == al.BWIDTH - 1) { // lower border of the band; a Y-gap impossible
            return toM;
        }

        return tl.sumLogs(toM, logEpsilon + constantY + al.bwdY[i][j + 1]);
    }

    /** Sums the match-state prices over the first {@code aSize} characters at (i, j). */
    double sumOverChars(int i, int j) {
        double sum = 0d;
        for (int k = 0; k < aSize; k++) {
            sum += al.price[i][j][k];
        }
        return sum;
    }

    /** Sums the X-state prices over the first {@code aSize} characters at (i, j). */
    double sumOverCharsX(int i, int j) {
        double sum = 0d;
        for (int k = 0; k < aSize; k++) {
            sum += al.priceX[i][j][k];
        }
        return sum;
    }

    /** Sums the Y-state prices over the first {@code aSize} characters at (i, j). */
    double sumOverCharsY(int i, int j) {
        double sum = 0d;
        for (int k = 0; k < aSize; k++) {
            sum += al.priceY[i][j][k];
        }
        return sum;
    }

    /** Larger of the M and X candidates; returns {@code x} on ties or NaN. */
    double maxOfTwoX(double m, double x) {
        return larger(m, x);
    }

    /** Larger of the M and Y candidates; returns {@code y} on ties or NaN. */
    double maxOfTwoY(double m, double y) {
        return larger(m, y);
    }

    /**
     * Largest of the three candidates. {@code m} wins only when it is
     * strictly greater than both others; otherwise the larger of x and y is
     * returned, with y winning ties.
     */
    double maxOfThree(double m, double x, double y) {
        if (m > x && m > y) {
            return m;
        }
        return x > y ? x : y;
    }

    /** Returns {@code a} if strictly greater than {@code b}, otherwise {@code b}. */
    private static double larger(double a, double b) {
        return a > b ? a : b;
    }
}











