/**
 * Title:        ProAlign<p>
 * Description:  <p>
 * Copyright:    Copyright (c) Ari Loytynoja<p>
 * License:      GNU GENERAL PUBLIC LICENSE<p>
 * @see          http://www.gnu.org/copyleft/gpl.html
 * Company:      ULB<p>
 * @author Ari Loytynoja
 * @version 1.0
 */
package proalign;

import java.util.Random;

/**
 * Traces one alignment path back through the dynamic-programming matrices of
 * an {@link AlignmentLoop} and turns it into per-column character
 * probabilities for the parent {@link AlignmentNode}.
 * <p>
 * Every step either follows the best path (ties broken at random) or samples
 * the step according to the stored path probabilities. Which one is used
 * depends on the {@code best} flag passed to {@link #getNode(boolean)}.
 */
class TraceBackPath {

    AlignmentLoop al;
    AlignmentNode an;
    Random r;

    // Distance of the band column j from either band edge: closer than
    // warnLimit sets isBandWarning, closer than stopLimit aborts the traceback
    // with a TraceBackException.
    int warnLimit = 10;
    int stopLimit = 3;
    boolean isBandWarning = false;

    // Traceback state: 0 = match, 1 = gap X, 2 = gap Y.
    int state = 0;
    int aSize;
    // Per column of the last traced path: a pair of cell indices
    // (axis 0, axis 1); -1 marks a gap on that axis.
    int[][] cellPath;
    // Per column of the last traced path: fwd + bwd - fwdEnd of the visited
    // cell (presumably a log posterior); 0 for trailer columns.
    double[] postProb;

    // Normalising sum computed by the most recent trailerCharProb1/2 call.
    double siteSum;
    // Log-probability contribution of the trailer columns of the last path.
    double trailProb;
    double logDelta;
    double logEpsilon;
    double logMinusEpsilon;

    // Number of random tie-breaks made in best mode (accumulated across
    // calls); isUnique becomes false after the first one.
    int sampleTimes = 0;
    boolean isUnique = true;
    // True when following the best path, false when sampling.
    boolean BEST;

    /**
     * Creates a traceback helper for the given node and its alignment matrices.
     * Pre-computes log(delta), log(epsilon) and log(1 - epsilon) from the
     * substitution model reachable through {@code al}.
     */
    TraceBackPath(AlignmentNode an, AlignmentLoop al) {

        this.an = an;
        this.al = al;
        r = new Random();

        logDelta = Math.log(al.pa.sm.delta);
        logEpsilon = Math.log(al.pa.sm.epsilon);
        logMinusEpsilon = Math.log(1 - al.pa.sm.epsilon);

        ProAlign.log("TraceBackPath");
    }

    /**
     * Traces one path back from the end of the alignment matrices and returns
     * its normalised character probabilities. There is one row per alignment
     * column, in start-to-end order.
     * <p>
     * As side effects, {@link #cellPath} and {@link #postProb} are rebuilt for
     * the returned path. When trailing sequence was cut from the node,
     * {@link #trailProb} is recomputed and the trailer columns are added
     * around the traced core.
     *
     * @param best true to follow the best path (ties broken at random),
     *             false to sample the path from its probabilities
     * @return per-column character probabilities, each row summing to one
     * @throws TraceBackException if the path comes within {@code stopLimit}
     *         of a band edge
     */
    double[][] getNode(boolean best) throws TraceBackException {

        this.BEST = best;
        this.aSize = al.aSize;

        // Choose the end state exactly once; getEndPath() also sets 'state'.
        // The trailer penalty below must refer to this same choice. Calling
        // getEndPath() again would re-sample it (or re-break a tie) and
        // clobber 'state'.
        char endPath = getEndPath();

        int i = al.pathM.length - 1; // traceback starting point (i,j)
        int j = al.endPoint;

        // Buffers for the path, filled from the end backwards.
        double[][] cProb = new double[i + j][aSize]; // character probabilities
        double[] pProb = new double[i + j];          // posterior values
        int[][] cell = new int[i + j][2];            // cell indices
        int ppc = 0;                                 // columns collected so far

        // Walk backwards through the matrices until cell (1, MIDDLE) is reached.
        while (true) {

            if (j < stopLimit || j > ProAlign.bandWidth - stopLimit) {
                throw new TraceBackException("traceback path comes too close to the band edge.");
            } else if (j < warnLimit || j > ProAlign.bandWidth - warnLimit) {
                isBandWarning = true;
            }

            if (i == 1 && j == al.MIDDLE) { // currently(?) stops at middle
                break;
            }

            if (state == 0) { // in match state
                char path = getPathM(i, j);
                for (int l = 0; l < aSize; l++) {
                    cProb[ppc][l] = al.price[i][j][l];
                }
                cell[ppc][0] = i;
                cell[ppc][1] = j - al.MIDDLE + i;
                pProb[ppc++] = (al.fwdM[i][j] + al.bwdM[i][j]) - al.fwdEnd;
                i--;

                if (path == 'X') {
                    state = 1;
                } else if (path == 'Y') {
                    state = 2;
                }
            } else if (state == 1) { // in gap-X state
                char path = getPathX(i, j);
                for (int l = 0; l < aSize; l++) {
                    cProb[ppc][l] = al.priceX[i][j][l];
                }
                cell[ppc][0] = i;
                cell[ppc][1] = -1;
                pProb[ppc++] = (al.fwdX[i][j] + al.bwdX[i][j]) - al.fwdEnd;
                // NOTE(review): j appears to be band-relative
                // (cell index = j - MIDDLE + i), hence j++ when only i steps back.
                i--;
                j++;

                if (path == 'm') {
                    state = 0;
                }
            } else if (state == 2) { // in gap-Y state
                char path = getPathY(i, j);
                for (int l = 0; l < aSize; l++) {
                    cProb[ppc][l] = al.priceY[i][j][l];
                }
                cell[ppc][0] = -1;
                cell[ppc][1] = j - al.MIDDLE + i;
                pProb[ppc++] = (al.fwdY[i][j] + al.bwdY[i][j]) - al.fwdEnd;
                j--;

                if (path == 'm') {
                    state = 0;
                }
            }
        }

        if (!an.hasTrailers) { // no trailing sequences removed
            double[][] charProb = new double[ppc][aSize];
            cellPath = new int[ppc][2];
            postProb = new double[ppc];
            copyPathReversed(cProb, cell, pProb, ppc, charProb, 0, 0, 0);
            return charProb;
        }

        // Trailing sequence(s) were cut out: add them back around the core path.
        int start = Math.max(an.start0, an.start1);
        int end = Math.max(an.end0, an.end1);
        int length = ppc + start + end;
        trailProb = 0d;
        double[][] charProb = new double[length][aSize];
        cellPath = new int[length][2];
        postProb = new double[length];

        if (start > 0) {
            boolean leadFromChild0 = an.start0 > an.start1;
            addTrailerColumns(charProb, 0, 0, start, leadFromChild0);
            trailProb = trailProb + logMinusEpsilon - logEpsilon;
            // Core positions on the axis that owns the leading trailer move
            // right by 'start'.
            copyPathReversed(cProb, cell, pProb, ppc, charProb, start,
                             leadFromChild0 ? 0 : 1, start);
        } else {
            copyPathReversed(cProb, cell, pProb, ppc, charProb, 0, 0, 0);
        }

        if (end > 0) {
            boolean tailFromChild0 = an.end0 > an.end1;
            int childLength = an.child[tailFromChild0 ? 0 : 1].charProb.length;
            addTrailerColumns(charProb, ppc + start, childLength - end, end, tailFromChild0);
            // Use the end state chosen for this traceback, not a fresh draw.
            if (endPath == 'M') {
                trailProb = trailProb + logDelta - logEpsilon;
            }
        }
        return charProb;
    }

    /**
     * Copies the {@code ppc} traced columns, which are stored end-first, into
     * {@code charProb}, {@link #cellPath} and {@link #postProb} in
     * start-to-end order, beginning at index {@code offset}.
     * <p>
     * Each column's character probabilities are normalised to sum to one.
     * Positive cell indices on axis {@code shiftAxis} are increased by
     * {@code shift}; gap markers and non-positive indices are copied unchanged.
     */
    private void copyPathReversed(double[][] cProb, int[][] cell, double[] pProb, int ppc,
                                  double[][] charProb, int offset, int shiftAxis, int shift) {
        for (int k = 0; k < ppc; k++) {
            int src = ppc - k - 1;
            int dst = k + offset;

            double sum = 0d;
            for (int l = 0; l < aSize; l++) {
                sum += cProb[src][l];
            }
            for (int l = 0; l < aSize; l++) {
                charProb[dst][l] = cProb[src][l] / sum;
            }

            cellPath[dst][0] = cell[src][0];
            cellPath[dst][1] = cell[src][1];
            if (cell[src][shiftAxis] > 0) {
                cellPath[dst][shiftAxis] += shift;
            }
            postProb[dst] = pProb[src];
        }
    }

    /**
     * Writes {@code count} trailer columns that come from one child only. The
     * other axis is marked as a gap (-1).
     * <p>
     * Output starts at index {@code dst} and uses child positions
     * {@code from, from+1, ...}. Each column's position is recorded as
     * position + 2, as in the original indexing. Each column adds
     * log(siteSum) + logEpsilon to {@link #trailProb}.
     */
    private void addTrailerColumns(double[][] charProb, int dst, int from, int count,
                                   boolean fromChild0) {
        for (int k = 0; k < count; k++) {
            int p = from + k;
            int out = dst + k;
            if (fromChild0) {
                charProb[out] = trailerCharProb1(p);
                cellPath[out][0] = p + 2;
                cellPath[out][1] = -1;
            } else {
                charProb[out] = trailerCharProb2(p);
                cellPath[out][0] = -1;
                cellPath[out][1] = p + 2;
            }
            postProb[out] = 0d;
            trailProb += Math.log(siteSum) + logEpsilon; // siteSum set by trailerCharProb1/2
        }
    }

    /**
     * Chooses the state in which the traceback starts, using
     * {@code al.pathEnd} (match, gap X; gap Y gets the remainder). Sets
     * {@link #state} accordingly: 0 for match, 1 for gap X, 2 for gap Y.
     * <p>
     * Gap states are returned as 'x'/'y' in best mode and as 'X'/'Y' when
     * sampling. This asymmetry is kept for compatibility.
     *
     * @return 'M', 'x'/'X' or 'y'/'Y'
     */
    char getEndPath() {
        double pm = al.pathEnd[0];
        double px = al.pathEnd[1];
        double py = 1d - pm - px;

        char path = chooseOfThree(pm, px, py, 'M', BEST ? 'x' : 'X', BEST ? 'y' : 'Y');

        // Always assign the state: a stale gap state from an earlier traceback
        // must not survive a match end.
        if (path == 'M') {
            state = 0;
        } else if (Character.toUpperCase(path) == 'X') {
            state = 1;
        } else {
            state = 2;
        }
        return path;
    }

    /**
     * Chooses the predecessor of match cell (i,j) from {@code al.pathM[i][j]}
     * (match, gap X; gap Y gets the remainder).
     *
     * @return 'M', 'X' or 'Y'
     */
    char getPathM(int i, int j) {
        double pm = al.pathM[i][j][0];
        double px = al.pathM[i][j][1];
        return chooseOfThree(pm, px, 1d - pm - px, 'M', 'X', 'Y');
    }

    /**
     * Chooses the predecessor of gap-X cell (i,j) from {@code al.pathX[i][j]},
     * the probability of coming from a match.
     *
     * @return 'm' (from match) or 'x' (stay in gap X)
     */
    char getPathX(int i, int j) {
        return chooseOfTwo(al.pathX[i][j], 'x');
    }

    /**
     * Chooses the predecessor of gap-Y cell (i,j) from {@code al.pathY[i][j]},
     * the probability of coming from a match.
     *
     * @return 'm' (from match) or 'y' (stay in gap Y)
     */
    char getPathY(int i, int j) {
        return chooseOfTwo(al.pathY[i][j], 'y');
    }

    /**
     * Picks one of three options with probabilities pm, px, py.
     * <p>
     * In best mode the largest wins, and ties are broken at random. On a
     * three-way tie only x or y can be chosen. Otherwise a single uniform
     * draw is compared against the cumulative probabilities.
     */
    private char chooseOfThree(double pm, double px, double py, char m, char x, char y) {
        if (!BEST) {
            double rdn = r.nextDouble();
            if (rdn < pm) {
                return m;
            }
            return rdn < (pm + px) ? x : y;
        }
        if (pm > px && pm > py) {
            return m;
        }
        if (pm > py && pm == px) {
            return breakTie(m, x);
        }
        if (pm > px && pm == py) {
            return breakTie(m, y);
        }
        if (px > py) {
            return x;
        }
        if (px == py) {
            return breakTie(x, y);
        }
        return y;
    }

    /**
     * Picks 'm' with probability pm, otherwise {@code stay}. In best mode the
     * more probable one wins and ties are broken at random.
     */
    private char chooseOfTwo(double pm, char stay) {
        if (!BEST) {
            return r.nextDouble() < pm ? 'm' : stay;
        }
        double pStay = 1d - pm;
        if (pm > pStay) {
            return 'm';
        }
        if (pm == pStay) {
            return breakTie('m', stay);
        }
        return stay;
    }

    /**
     * Returns one of two equally good choices at random and records that the
     * best path is not unique.
     */
    private char breakTie(char first, char second) {
        char chosen = r.nextBoolean() ? first : second;
        isUnique = false;
        sampleTimes++;
        return chosen;
    }

    /**
     * Character probabilities for a trailer column present only in child 0.
     * <p>
     * For each character k: sum over l of child0.charProb[i][l] *
     * substProb1[k][l], times substProb2[k][aSize-1] (index aSize-1 is
     * presumably the gap symbol) and charFreqs[k], then normalised to sum to
     * one. The normalising sum is left in {@link #siteSum}.
     *
     * @param i column of child 0's charProb to use
     * @return normalised probabilities over the aSize characters
     */
    double[] trailerCharProb1(int i) {

        double[] cProb = new double[aSize];
        siteSum = 0d;
        for (int k = 0; k < aSize; k++) {
            for (int l = 0; l < aSize; l++) {
                cProb[k] += an.child[0].charProb[i][l] * al.pa.sm.substProb1[k][l];
            }
            cProb[k] = cProb[k] * al.pa.sm.substProb2[k][aSize - 1] * al.pa.sm.charFreqs[k];
            siteSum += cProb[k];
        }
        for (int k = 0; k < aSize; k++) {
            cProb[k] = cProb[k] / siteSum;
        }
        return cProb;
    }

    /**
     * Character probabilities for a trailer column present only in child 1.
     * <p>
     * Mirror of {@link #trailerCharProb1(int)}: uses child 1, substProb2 for
     * the observed branch and substProb1[k][aSize-1] for the other. The
     * normalising sum is left in {@link #siteSum}.
     *
     * @param i column of child 1's charProb to use
     * @return normalised probabilities over the aSize characters
     */
    double[] trailerCharProb2(int i) {

        double[] cProb = new double[aSize];
        siteSum = 0d;
        for (int k = 0; k < aSize; k++) {
            for (int l = 0; l < aSize; l++) {
                cProb[k] += an.child[1].charProb[i][l] * al.pa.sm.substProb2[k][l];
            }
            cProb[k] = cProb[k] * al.pa.sm.substProb1[k][aSize - 1] * al.pa.sm.charFreqs[k];
            siteSum += cProb[k];
        }
        for (int k = 0; k < aSize; k++) {
            cProb[k] = cProb[k] / siteSum;
        }
        return cProb;
    }
}

/**
 * Checked exception signalling that a traceback could not be completed, for
 * example because the path came too close to the band edge.
 */
class TraceBackException extends Exception {

    private static final long serialVersionUID = 1L;

    /** Creates an exception with no detail message. */
    public TraceBackException() {}

    /**
     * Creates an exception with the given detail message.
     *
     * @param msg description of why the traceback failed
     */
    public TraceBackException(String msg) {
        super(msg);
    }

    /**
     * Creates an exception with a detail message and an underlying cause, so
     * that wrapping code can preserve the original failure.
     *
     * @param msg   description of why the traceback failed
     * @param cause the underlying exception (may be null)
     */
    public TraceBackException(String msg, Throwable cause) {
        super(msg, cause);
    }
}












