/**
 * Title:        ProAlign<p>
 * Description:  <p>
 * Copyright:    Copyright (c) Ari Loytynoja<p>
 * License:      GNU GENERAL PUBLIC LICENSE<p>
 * @see          http://www.gnu.org/copyleft/gpl.html
 * Company:      ULB<p>
 * @author Ari Loytynoja
 * @version 1.0
 */
package proalign;

class PwAlignment {

    /** Largest distance align() reports; also returned for saturated or uncomparable pairs. */
    private static final double MAX_DISTANCE = 5d;

    /** Score marking DP cells on the matrix edges as (practically) unreachable. */
    private static final int SMALL = -100000;

    // Lengths of the sequences of the current alignment run.
    int len1, len2;
    // Score tables for paths ending in: M = residue of seq1 paired with a residue
    // of seq2, X = residue of seq1 against a gap, Y = gap against a residue of seq2.
    // The point* tables hold each cell's traceback pointer (0 = M, 1 = X, 2 = Y).
    int[][] matM, matX, matY, pointM, pointX, pointY;
    // Result of the last run: gapped seq1 and gapped seq2, written last column first.
    String[] rev = new String[2];
    // Current sequences, padded with a leading space so residue k sits at index k.
    String seq1, seq2;

    // Substitution scores, indexed by the residues' positions in alphabet.
    int[][] subst;
    // Score added when a gap is opened / extended.
    int gOpen;
    int gExt;
    String alphabet;
    boolean isDna;

    /**
     * Creates an aligner.
     *
     * @param substTable substitution scores, indexed by positions in {@code alpha}
     * @param gapOpen    score added when a gap is opened
     * @param gapExt     score added when a gap is extended
     * @param alpha      residue alphabet used to index {@code substTable}
     * @param isD        true for DNA (Jukes-Cantor distance), false for protein
     */
    PwAlignment(int[][] substTable, int gapOpen, int gapExt, String alpha, boolean isD) {

        ProAlign.log("PwAlignment");

        subst = substTable;
        gOpen = gapOpen;
        gExt = gapExt;
        alphabet = alpha;
        isDna = isD;
    }

    /**
     * Aligns two sequences and returns the aligned pair.
     *
     * @param s1 first sequence; its characters are looked up in {@code alphabet}
     * @param s2 second sequence; its characters are looked up in {@code alphabet}
     * @return a new two-element array holding gapped s1 and gapped s2, both written
     *         last column first, with '-' marking gaps
     */
    String[] revAligned(String s1, String s2) {

        ProAlign.log("PwAlignment");

        runAlignment(s1, s2);

        // Hand out a copy: rev is reused by every call, so returning it directly
        // would let a later alignment overwrite a result the caller still holds.
        return rev.clone();
    }

    /**
     * Aligns two sequences and estimates their distance from the proportion p of
     * non-identical columns between the first and the last gap-free column
     * (internal gap columns count as differences, terminal gaps are ignored).
     * When ProAlign.correctMultiple is set, p is corrected for multiple hits
     * (Jukes-Cantor for DNA, Kimura's formula for protein). The result is capped
     * at 5, which is also returned when no column pairs two residues
     * (including empty input).
     *
     * @param s1 first sequence
     * @param s2 second sequence
     * @return the estimated distance, in [0, 5] for valid input
     */
    double align(String s1, String s2) {

        runAlignment(s1, s2);

        // Ignore terminal gaps: compare only the span between the first and the
        // last column in which both sequences have a residue.
        int first = gapFreeColumn(0, 1);
        if (first < 0) {
            // Nothing was aligned residue-to-residue (e.g. an empty sequence):
            // there is no evidence of similarity, so report the maximum distance.
            return MAX_DISTANCE;
        }
        int last = gapFreeColumn(rev[0].length() - 1, -1);

        // Count identities within the span; gap columns inside it count as
        // mismatches (the traceback never emits a gap/gap column).
        int same = 0;
        for (int col = first; col <= last; col++) {
            if (rev[0].charAt(col) == rev[1].charAt(col)) {
                same++;
            }
        }
        int all = last - first + 1;

        double p = 1d - (double) same / (double) all;
        return correctedDistance(p);
    }

    /**
     * Aligns two sequences and measures how far their ends overhang each other.
     *
     * @param s1 first sequence
     * @param s2 second sequence
     * @return two-element array: [0] = residues of s1 minus residues of s2 before
     *         the first gap-free column (start of the alignment), [1] = the same
     *         difference after the last gap-free column (end of the alignment);
     *         both are 0 if no column pairs two residues
     */
    int[] trailing(String s1, String s2) {

        runAlignment(s1, s2);

        int[] trail = new int[2];
        // rev is stored last column first: walking it backwards visits the start
        // of the alignment, walking it forwards visits the end.
        trail[0] = overhang(rev[0].length() - 1, -1);
        trail[1] = overhang(0, 1);
        return trail;
    }

    /**
     * Fills the affine-gap DP tables for seq1/seq2 and traces the best global
     * alignment back into rev (last column first). Unless
     * ProAlign.penalizeTerminal is set, gaps along the first and last row and
     * column are free.
     */
    void pwAlignment() {

        // Alphabet positions of seq2's residues, looked up once rather than once
        // per DP cell. An unknown character yields -1 and makes the substitution
        // lookup below fail, as before.
        int[] codes2 = new int[len2 + 1];
        for (int j = 1; j <= len2; j++) {
            codes2[j] = alphabet.indexOf(seq2.charAt(j));
        }

        // Fill the tables row by row; cell (0,0) is the empty-prefix origin.
        for (int i = 0; i <= len1; i++) {
            int code1 = (i > 0) ? alphabet.indexOf(seq1.charAt(i)) : -1;
            for (int j = 0; j <= len2; j++) {
                if (i > 0 && j > 0) {
                    fillMatch(i, j, subst[code1][codes2[j]]);
                }
                if (i > 0) {
                    fillGapInSeq2(i, j);
                }
                if (j > 0) {
                    fillGapInSeq1(i, j);
                }
            }
        }

        // Start the traceback from the best-scoring state of the final cell.
        int point = bestOf(matM[len1][len2], matX[len1][len2], matY[len1][len2]);

        // Trace back; columns are appended last-to-first.
        StringBuilder top = new StringBuilder(len1 + len2);
        StringBuilder bottom = new StringBuilder(len1 + len2);

        int i = len1;
        int j = len2;
        while (i > 0 || j > 0) {
            if (point == 0) {
                top.append(seq1.charAt(i));
                bottom.append(seq2.charAt(j));
                point = pointM[i][j];
                i--;
                j--;
            } else if (point == 1) {
                top.append(seq1.charAt(i));
                bottom.append('-');
                point = pointX[i][j];
                i--;
            } else if (point == 2) {
                top.append('-');
                bottom.append(seq2.charAt(j));
                point = pointY[i][j];
                j--;
            } else {
                System.out.println("wrong pointer!");
                break;
            }
        }

        rev[0] = top.toString();
        rev[1] = bottom.toString();
    }

    /**
     * Allocates fresh score and pointer tables for the current lengths. All cells
     * of row 0 and column 0 except the origin start at SMALL; pwAlignment then
     * fills in the gap states along those edges.
     */
    void initializeMatrices() {

        matM = new int[len1 + 1][len2 + 1];
        matX = new int[len1 + 1][len2 + 1];
        matY = new int[len1 + 1][len2 + 1];
        pointM = new int[len1 + 1][len2 + 1];
        pointX = new int[len1 + 1][len2 + 1];
        pointY = new int[len1 + 1][len2 + 1];

        for (int row = 1; row <= len1; row++) {
            matM[row][0] = SMALL;
            matX[row][0] = SMALL;
            matY[row][0] = SMALL;
        }
        for (int col = 1; col <= len2; col++) {
            matM[0][col] = SMALL;
            matX[0][col] = SMALL;
            matY[0][col] = SMALL;
        }
    }

    /** Stores the padded sequences, resets the tables and runs the DP plus traceback. */
    private void runAlignment(String s1, String s2) {
        len1 = s1.length();
        len2 = s2.length();
        // The leading space makes residue k sit at index k, matching the DP's
        // row/column numbering where index 0 is the empty prefix.
        seq1 = " " + s1;
        seq2 = " " + s2;

        initializeMatrices();
        pwAlignment();
    }

    /** Returns the pointer (0 = M, 1 = X, 2 = Y) of the largest score; ties prefer M, then X. */
    private static int bestOf(int m, int x, int y) {
        if (m >= x && m >= y) {
            return 0;
        }
        return (x >= y) ? 1 : 2;
    }

    /** Fills M(i,j): the best state at (i-1,j-1) plus the substitution score. Needs i,j > 0. */
    private void fillMatch(int i, int j, int match) {
        int from = bestOf(matM[i - 1][j - 1], matX[i - 1][j - 1], matY[i - 1][j - 1]);
        int prev;
        if (from == 0) {
            prev = matM[i - 1][j - 1];
        } else if (from == 1) {
            prev = matX[i - 1][j - 1];
        } else {
            prev = matY[i - 1][j - 1];
        }
        matM[i][j] = prev + match;
        pointM[i][j] = from;
    }

    /** Fills X(i,j): residue i of seq1 against a gap in seq2. Needs i > 0. */
    private void fillGapInSeq2(int i, int j) {
        boolean freeEnds = !ProAlign.penalizeTerminal;
        if (freeEnds && j == 0) {
            // Leading gap: extend at no cost.
            matX[i][j] = matX[i - 1][j];
            pointX[i][j] = 1;
        } else if (freeEnds && j == len2) {
            // Trailing gap: open or extend at no cost.
            if (matM[i - 1][j] >= matX[i - 1][j]) {
                matX[i][j] = matM[i - 1][j];
                pointX[i][j] = 0;
            } else {
                matX[i][j] = matX[i - 1][j];
                pointX[i][j] = 1;
            }
        } else if (matM[i - 1][j] + gOpen >= matX[i - 1][j] + gExt) {
            matX[i][j] = matM[i - 1][j] + gOpen;
            pointX[i][j] = 0;
        } else {
            matX[i][j] = matX[i - 1][j] + gExt;
            pointX[i][j] = 1;
        }
    }

    /** Fills Y(i,j): a gap in seq1 against residue j of seq2. Needs j > 0. */
    private void fillGapInSeq1(int i, int j) {
        boolean freeEnds = !ProAlign.penalizeTerminal;
        if (freeEnds && i == 0) {
            // Leading gap: extend at no cost.
            matY[i][j] = matY[i][j - 1];
            pointY[i][j] = 2;
        } else if (freeEnds && i == len1) {
            // Trailing gap: open or extend at no cost.
            if (matM[i][j - 1] >= matY[i][j - 1]) {
                matY[i][j] = matM[i][j - 1];
                pointY[i][j] = 0;
            } else {
                matY[i][j] = matY[i][j - 1];
                pointY[i][j] = 2;
            }
        } else if (matM[i][j - 1] + gOpen >= matY[i][j - 1] + gExt) {
            matY[i][j] = matM[i][j - 1] + gOpen;
            pointY[i][j] = 0;
        } else {
            matY[i][j] = matY[i][j - 1] + gExt;
            pointY[i][j] = 2;
        }
    }

    /**
     * Scans rev from column {@code start} in direction {@code step} and returns
     * the first column where neither row has a gap, or -1 if there is none.
     */
    private int gapFreeColumn(int start, int step) {
        for (int col = start; col >= 0 && col < rev[0].length(); col += step) {
            if (rev[0].charAt(col) != '-' && rev[1].charAt(col) != '-') {
                return col;
            }
        }
        return -1;
    }

    /**
     * Scans rev from column {@code start} in direction {@code step} up to the
     * first gap-free column and returns (#residues of seq1) - (#residues of seq2)
     * seen before it; returns 0 if no gap-free column is reached.
     */
    private int overhang(int start, int step) {
        int diff = 0;
        for (int col = start; col >= 0 && col < rev[0].length(); col += step) {
            boolean res1 = rev[0].charAt(col) != '-';
            boolean res2 = rev[1].charAt(col) != '-';
            if (res1 && res2) {
                return diff;
            }
            if (res1) {
                diff++;
            }
            if (res2) {
                diff--;
            }
        }
        return 0;
    }

    /**
     * Converts the observed proportion of differences p into a distance, capped
     * at MAX_DISTANCE. p above 0.75 (DNA) or 0.85 (protein) is treated as
     * saturated. These thresholds sit at, or just below, the points where the
     * logarithms' arguments reach 0: 0.75 and about 0.854.
     */
    private double correctedDistance(double p) {
        double d;
        if (p > (isDna ? 0.75d : 0.85d)) {
            d = MAX_DISTANCE;
        } else if (!ProAlign.correctMultiple) {
            d = p;
        } else if (isDna) {
            // Jukes-Cantor: d = -3/4 ln(1 - 4/3 p)
            d = -0.75d * Math.log(1d - 4d / 3d * p);
        } else {
            // Kimura's protein approximation: d = -ln(1 - p - 0.2 p^2)
            d = -1d * Math.log(1 - p - 0.2d * p * p);
        }
        return Math.min(d, MAX_DISTANCE);
    }
}


