/**
 * Title:        ProAlign<p>
 * Description:  <p>
 * Copyright:    Copyright (c) Ari Loytynoja<p>
 * License:      GNU GENERAL PUBLIC LICENSE<p>
 * @see          http://www.gnu.org/copyleft/gpl.html
 * Company:      ULB<p>
 * @author Ari Loytynoja
 * @version 1.0
 */
package proalign;

/**
 * Banded pairwise alignment of the two child profiles of an {@code AlignmentNode}.
 * <p>
 * {@link #align} fills the Viterbi ({@code vit*}), forward ({@code fwd*}) and
 * backward ({@code bwd*}) matrices for the match (M), X-gap (X) and Y-gap (Y)
 * states, over a band of width {@code BWIDTH} around the diagonal. The per-cell
 * recursions are delegated to {@code Viterbi}; this class allocates the
 * matrices, computes the emission prices and drives the iteration order.
 * Per the original notes, the Viterbi pass also marks the traceback, and the
 * forward/backward values serve posterior probabilities; that work is not
 * visible in this class.
 * <p>
 * Coordinates: row {@code i} (1 .. seq1.length+1) and band column {@code j}
 * map to the real seq2 coordinate {@code js = i + j - MIDDLE} (1 .. seq2.length+1).
 * The profile columns used for row {@code i} and coordinate {@code js} are
 * {@code seq1[i-2]} and {@code seq2[js-2]}. Cell {@code (1, MIDDLE)} is the start cell.
 */
class AlignmentLoop {

    /** Process exit status used when the matrices cannot be allocated. */
    private static final int EXIT_FAILURE = 1;

    ProAlign pa;        // owning program; supplies the substitution model (sm) and the GUI window (rw)
    AlignmentLoop al;   // always this object (assigned in the constructor)
    AlignmentNode an;   // node being aligned; read for hasTrailers/start0/start1

    // Emission prices per cell and node character, indexed [row][band][character];
    // only the first aSize characters are written here.
    double[][][] price, priceX, priceY;
    // Viterbi values of the M, X and Y states, indexed [row][band].
    double[][] vitM, vitX, vitY;

    // Traceback markers, indexed [row][band] (pathM holds two entries per cell);
    // -1 is written for states that cannot occur on the first row/column.
    double[][][] pathM;
    double[][] pathX, pathY;

    // Forward and backward values, indexed [row][band]. Unreached cells hold
    // -infinity and the debug check combines them with TransformLog.sumLogs,
    // so these appear to be log-space values.
    double[][] fwdM, fwdX, fwdY;
    double[][] bwdM, bwdX, bwdY;

    double vitEnd;      // result of Viterbi.getViterbiEnd for the end cell
    double fwdEnd;      // NOTE(review): not assigned in this class; presumably set by Viterbi
    double[] pathEnd;   // two-entry end traceback; allocated here, presumably filled by Viterbi

    double[][] seq1, seq2;  // profiles, indexed [position][character]
    float dist1, dist2;     // branch lengths passed to the substitution model

    int aSize;  // alphabet size (seq1[0].length); index aSize-1 is used as the gap character

    int BWIDTH; // band width, taken from ProAlign.bandWidth
    int MIDDLE; // band column of the main diagonal (js == i), BWIDTH/2 + 1

    int endPoint = 0; // band column of the last cell filled by the forward sweep

    /**
     * Creates the alignment loop for node {@code an}; the band width is read
     * from {@code ProAlign.bandWidth}.
     */
    AlignmentLoop(ProAlign pa, AlignmentNode an) {
        this.pa = pa;
        this.an = an;
        al = this;

        BWIDTH = ProAlign.bandWidth;
        MIDDLE = BWIDTH / 2 + 1;
    }

    /**
     * Aligns two profiles. The Viterbi and forward matrices are filled in one
     * top-to-bottom sweep, then the backward matrices in a bottom-to-top sweep.
     *
     * @param s1 first profile, indexed [position][character]
     * @param s2 second profile, indexed [position][character]
     * @param d1 branch length of the first child, passed to sm.setBranchLength
     * @param d2 branch length of the second child, passed to sm.setBranchLength
     * @throws Exception an OutOfMemoryException if the matrices cannot be
     *         allocated and ProAlign.exitInError is not set; exceptions from the
     *         substitution model or the Viterbi recursions propagate unchanged
     */
    void align(double[][] s1, double[][] s2, float d1, float d2) throws Exception {
        seq1 = s1;
        seq2 = s2;
        dist1 = d1;
        dist2 = d2;

        pa.sm.setBranchLength(dist1, dist2);

        aSize = seq1[0].length; // alphabet size

        ProAlign.log("AlignmentLoop");
        ProAlign.log(" seq1.length = " + seq1.length + ", seq2.length = " + seq2.length);

        try {
            initialiseMatrices();
        } catch (OutOfMemoryError e) {
            // Only genuine allocation failures are reported as out of memory;
            // other Errors are not memory problems and must not be relabelled.
            handleOutOfMemory();
        }

        Viterbi v = new Viterbi(this);
        int lastRow = vitM.length - 1;

        // Forward sweep: Viterbi and forward values. Per the original notes,
        // the traceback is marked for the same cells.
        for (int i = 1; i <= lastRow; i++) {        // rows: seq1
            for (int j = 0; j < BWIDTH; j++) {      // band columns: seq2
                int js = j - MIDDLE + i;            // real seq2 coordinate
                if (js < 1 || js > seq2.length + 1) {
                    continue;                       // outside the real matrix
                }
                if (i == 1 && j == MIDDLE) {
                    continue;                       // start cell, seeded by initialiseMatrices
                }

                if (js == 1) {
                    // First column: seq2 has not started, only X-gaps are possible.
                    sumSubstPriceX(i, j);
                    pathM[i][j][0] = -1d;
                    pathM[i][j][1] = -1d;
                    vitX[i][j] = v.getViterbiX(i, j);
                    pathY[i][j] = -1d;
                    fwdX[i][j] = v.getForwardX(i, j);
                } else if (i == 1) {
                    // First row: seq1 has not started, only Y-gaps are possible.
                    sumSubstPriceY(i, j);
                    pathM[i][j][0] = -1d;
                    pathM[i][j][1] = -1d;
                    pathX[i][j] = -1d;
                    vitY[i][j] = v.getViterbiY(i, j);
                    fwdY[i][j] = v.getForwardY(i, j);
                } else {
                    sumSubstPrice(i, j);
                    vitM[i][j] = v.getViterbiM(i, j); // 'Match'
                    vitX[i][j] = v.getViterbiX(i, j); // 'X indel'
                    vitY[i][j] = v.getViterbiY(i, j); // 'Y indel'
                    fwdM[i][j] = v.getForwardM(i, j);
                    fwdX[i][j] = v.getForwardX(i, j);
                    fwdY[i][j] = v.getForwardY(i, j);
                }
                // After the sweep this holds the column of the last filled cell.
                endPoint = j;
            }
        }
        vitEnd = v.getViterbiEnd(endPoint);

        // Only seeds three cells of matrices allocated above; nothing is allocated.
        initialiseBwdMatrices(endPoint);

        // Backward sweep over the same cells in reverse order.
        for (int i = lastRow; i > 0; i--) {
            for (int j = BWIDTH - 1; j >= 0; j--) {
                int js = j - MIDDLE + i;
                if (js < 1 || js > seq2.length + 1) {
                    continue;                       // outside the real matrix
                }
                if (i == lastRow && j == endPoint) {
                    continue;                       // end cell, seeded by initialiseBwdMatrices
                }
                bwdM[i][j] = v.getBackwardM(i, j);
                bwdX[i][j] = v.getBackwardX(i, j);
                bwdY[i][j] = v.getBackwardY(i, j);
            }
        }

        if (ProAlign.DEBUG) {
            // Debug check: compare the total backward value at the start cell with fwdEnd.
            TransformLog tl = new TransformLog();
            double bwdEnd = tl.sumLogs(bwdM[1][MIDDLE],
                                       tl.sumLogs(bwdX[1][MIDDLE], bwdY[1][MIDDLE]));
            double diff = bwdEnd - fwdEnd;
            ProAlign.log.println("AlignmentLoop: vEnd: " + vitEnd + " fEnd: " + fwdEnd
                                 + " bEnd: " + bwdEnd + " bwd-fwd: " + diff);
        }
    }

    /**
     * Reports that the matrices could not be allocated and stops the run; does
     * not return normally. In the GUI (ProAlign.isResultWindow) an error dialog
     * is shown and the JVM exits. Otherwise the error is logged and printed,
     * and then the JVM exits if ProAlign.exitInError is set, or an
     * OutOfMemoryException is thrown.
     */
    private void handleOutOfMemory() throws Exception {
        // Free whatever was allocated before the failure so that the dialog,
        // messages and exception below have memory to work with.
        releaseMatrices();

        if (ProAlign.isResultWindow) {
            String text = "\n        Out Of Memory Error!\n  Please, increase JVM memory.\n";
            OpenDialog od = new OpenDialog(pa.rw);
            od.showDialog("Error!", text);
            System.exit(EXIT_FAILURE);   // non-zero: this is a fatal error, not success
        } else {
            ProAlign.log.println("AlignmentLoop: Out Of Memory Error. Increase JVM memory.");
            System.out.println("Out Of Memory Error.\nPlease, increase JVM memory.");
            if (ProAlign.exitInError) {
                System.exit(EXIT_FAILURE);
            } else {
                throw new OutOfMemoryException("Out Of Memory Error. Increase JVM memory");
            }
        }
    }

    /** Drops the references to all matrices so their memory can be reclaimed. */
    private void releaseMatrices() {
        price = priceX = priceY = null;
        vitM = vitX = vitY = null;
        pathM = null;
        pathX = pathY = null;
        fwdM = fwdX = fwdY = null;
        bwdM = bwdX = bwdY = null;
        pathEnd = null;
    }

    /**
     * Allocates all matrices for the current seq1, aSize and BWIDTH and seeds
     * the start cell (1, MIDDLE). Every Viterbi, forward and backward cell
     * starts at -infinity.
     */
    void initialiseMatrices() {
        int rows = seq1.length + 2;  // rows 0 .. seq1.length+1

        // [row][band width][alphabet]
        price = new double[rows][BWIDTH][aSize + 1];
        priceX = new double[rows][BWIDTH][aSize + 1];
        // NOTE(review): one row larger than price/priceX; presumably read one
        // row ahead elsewhere (e.g. by Viterbi) -- verify before shrinking.
        priceY = new double[rows + 1][BWIDTH][aSize + 1];

        vitM = new double[rows][BWIDTH];
        vitX = new double[rows][BWIDTH];
        vitY = new double[rows][BWIDTH];

        pathM = new double[rows][BWIDTH][2];
        pathX = new double[rows][BWIDTH];
        pathY = new double[rows][BWIDTH];

        fwdM = new double[rows][BWIDTH];
        fwdX = new double[rows][BWIDTH];
        fwdY = new double[rows][BWIDTH];

        // One extra row; initialiseBwdMatrices seeds row bwdM.length-2 == rows-1.
        bwdM = new double[rows + 1][BWIDTH];
        bwdX = new double[rows + 1][BWIDTH];
        bwdY = new double[rows + 1][BWIDTH];

        pathEnd = new double[2];

        fillNegativeInfinity(vitM);
        fillNegativeInfinity(vitX);
        fillNegativeInfinity(vitY);
        fillNegativeInfinity(fwdM);
        fillNegativeInfinity(fwdX);
        fillNegativeInfinity(fwdY);
        fillNegativeInfinity(bwdM);
        fillNegativeInfinity(bwdX);
        fillNegativeInfinity(bwdY);

        // Start cell. Without trailers every state may start here. With
        // trailers only one gap state may, chosen by comparing an.start0 with
        // an.start1.
        boolean startX = true;
        boolean startY = true;
        if (an.hasTrailers) {
            boolean start0After1 = an.start0 > an.start1;
            startX = start0After1;
            startY = !start0After1;
        }
        double startValueX = startX ? 0d : Double.NEGATIVE_INFINITY;
        double startValueY = startY ? 0d : Double.NEGATIVE_INFINITY;

        vitM[1][MIDDLE] = 0d;
        vitX[1][MIDDLE] = startValueX;
        vitY[1][MIDDLE] = startValueY;

        fwdM[1][MIDDLE] = 0d;
        fwdX[1][MIDDLE] = startValueX;
        fwdY[1][MIDDLE] = startValueY;
    }

    /** Sets every cell of {@code m} to negative infinity, row by row. */
    private static void fillNegativeInfinity(double[][] m) {
        for (int r = 0; r < m.length; r++) {
            double[] row = m[r];
            for (int c = 0; c < row.length; c++) {
                row[c] = Double.NEGATIVE_INFINITY;
            }
        }
    }

    /**
     * Seeds the backward sweep by setting the end cell to 0 in all three states.
     * The end cell is at row bwdM.length-2 (= seq1.length+1, the last row of
     * the forward sweep) and band column {@code endPoint}.
     */
    void initialiseBwdMatrices(int endPoint) {
        int endRow = bwdM.length - 2;
        bwdM[endRow][endPoint] = 0d;
        bwdX[endRow][endPoint] = 0d;
        bwdY[endRow][endPoint] = 0d;
    }

    /**
     * Computes the match, X-gap and Y-gap prices of cell (i, j) for every node
     * character k, with js = i + j - MIDDLE and gap = aSize-1:
     * <pre>
     *   child1 = sum_l seq1[i-2][l]  * substProb1[k][l]
     *   child2 = sum_l seq2[js-2][l] * substProb2[k][l]
     *   price  = child1 * child2 * charFreqs[k]
     *   priceX = child1 * substProb2[k][gap] * charFreqs[k]  (child 2 is a gap for sure)
     *   priceY = child2 * substProb1[k][gap] * charFreqs[k]  (child 1 is a gap for sure)
     * </pre>
     * Each child sum is computed once and shared by the match and gap prices.
     * Results are assigned, not added, so recomputing a cell does not double-count.
     */
    void sumSubstPrice(int i, int j) {
        int js = i + j - MIDDLE;  // convert from "band" value to real value

        for (int k = 0; k < aSize; k++) {
            double child1 = child1Sum(i, k);
            double child2 = child2Sum(js, k);
            price[i][j][k] = child1 * child2 * pa.sm.charFreqs[k];
            priceX[i][j][k] = child1 * pa.sm.substProb2[k][aSize - 1] * pa.sm.charFreqs[k];
            priceY[i][j][k] = child2 * pa.sm.substProb1[k][aSize - 1] * pa.sm.charFreqs[k];
        }
    }

    /** Computes only the X-gap prices of cell (i, j); see {@link #sumSubstPrice}. */
    void sumSubstPriceX(int i, int j) {
        for (int k = 0; k < aSize; k++) {
            priceX[i][j][k] = child1Sum(i, k) * pa.sm.substProb2[k][aSize - 1] * pa.sm.charFreqs[k];
        }
    }

    /** Computes only the Y-gap prices of cell (i, j); see {@link #sumSubstPrice}. */
    void sumSubstPriceY(int i, int j) {
        int js = i + j - MIDDLE;  // convert from "band" value to real value

        for (int k = 0; k < aSize; k++) {
            priceY[i][j][k] = child2Sum(js, k) * pa.sm.substProb1[k][aSize - 1] * pa.sm.charFreqs[k];
        }
    }

    /** Returns sum over l &lt; aSize of seq1[i-2][l] * substProb1[k][l] (child 1, row i). */
    private double child1Sum(int i, int k) {
        double[] chars = seq1[i - 2];  // correction [i-2] for the "positive" matrix
        double sum = 0d;
        for (int l = 0; l < aSize; l++) {
            sum += chars[l] * pa.sm.substProb1[k][l];
        }
        return sum;
    }

    /** Returns sum over l &lt; aSize of seq2[js-2][l] * substProb2[k][l] (child 2, coordinate js). */
    private double child2Sum(int js, int k) {
        double[] chars = seq2[js - 2];  // correction [js-2] for the "positive" matrix
        double sum = 0d;
        for (int l = 0; l < aSize; l++) {
            sum += chars[l] * pa.sm.substProb2[k][l];
        }
        return sum;
    }
}





