/**
 * Title:        ProAlign<p>
 * Description:  <p>
 * Copyright:    Copyright (c) Ari Loytynoja<p>
 * License:      GNU GENERAL PUBLIC LICENSE<p>
 * @see          http://www.gnu.org/copyleft/gpl.html
 * Company:      ULB<p>
 * @author Ari Loytynoja
 * @version 1.0
 */
package proalign;

import java.io.File;
import java.io.IOException;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.Calendar;
import java.util.Date;
import javax.swing.JOptionPane;

/**
 * Writes a ProAlign alignment, and optionally its posterior-probability
 * summaries, to disk in Nexus, Fasta (with PIR/NBRF-style ">P1;"/">DL;"
 * headers and a "*" terminator), Phylip or MSF format.
 *
 * There are two entry points:
 * <ul>
 * <li>A batch constructor that builds the alignment rows directly from an
 * {@link AlignmentNode}.</li>
 * <li>A GUI constructor that takes the rows and removal flags from a
 * {@link ResultWindow} and asks the user for a target file.</li>
 * </ul>
 */
public class SaveData {

    /** Columns per interleaved block in Phylip and MSF output. */
    private static final int BLOCK_WIDTH = 50;
    /** Residues per line in Fasta output. */
    private static final int FASTA_LINE = 50;
    /** Columns per interleaved block in Nexus output. */
    private static final int NEXUS_BLOCK = 100;
    /** Probability values per line in the .min/.mean/.all files. */
    private static final int PROBS_PER_LINE = 20;

    /** Per-column flags; true marks an alignment column that is left out of the output. */
    boolean[] sitesRemoved;
    /** Per-sequence flags; true marks a taxon that is left out of the output. */
    boolean[] taxaRemoved;

    String filename;
    /** Viterbi/forward end scores, echoed in the probability-file headers. */
    String resultString;

    String[] sequenceNames;
    /** One aligned row per taxon, parallel to sequenceNames. */
    String[] sequenceData;

    /** Padding used to left-justify names in fixed-width fields. */
    String spaceStr = "                         ";
    int taxaNum;
    int seqLength;

    int numTaxa;
    int numSite;

    AlignmentNode root;

    /** exp() of root.getMinimumInternalPostProbAt(i) for each alignment column i. */
    double[] minProb;

    /**
     * Builds the alignment from the root node and writes it to {@code file}.
     *
     * @param file output path
     * @param type 1 = Nexus (interleaved), 2 = Fasta, 3 = Phylip, 4 = MSF;
     *             any other value writes no alignment file
     * @param root root of the alignment tree
     */
    public SaveData(String file, int type, AlignmentNode root) {

        ProAlign.log("SaveData");

        this.root = root;
        resultString = "vitEnd: "+root.sumViterbiEnd()+", fwdEnd: "+root.sumForwardEnd();

        sequenceNames = root.getTerminalNames();
        for (int i = 0; i < sequenceNames.length; i++) {
            sequenceNames[i] = paddedName(sequenceNames[i], 21);
        }

        // Accumulate rows in StringBuilders; appending to a String once per
        // column made this quadratic in the alignment length.
        StringBuilder[] rows = new StringBuilder[root.getNumChild()];
        for (int i = 0; i < rows.length; i++) {
            rows[i] = new StringBuilder();
        }
        for (int i = 0; i < root.cellPath.length; i++) {
            // Each path cell contributes one column: the characters of child 0's
            // sequences followed by those of child 1.
            // NOTE(review): the -2 offset converts cellPath coordinates into
            // child column indices; the convention is defined in AlignmentNode.
            int h = 0;
            char[] c0 = root.child[0].getCharacterAt(root.cellPath[i][0]-2);
            char[] c1 = root.child[1].getCharacterAt(root.cellPath[i][1]-2);
            for (int j = 0; j < c0.length; j++) {
                rows[h++].append(c0[j]);
            }
            for (int j = 0; j < c1.length; j++) {
                rows[h++].append(c1[j]);
            }
        }
        sequenceData = new String[rows.length];
        for (int i = 0; i < rows.length; i++) {
            sequenceData[i] = rows[i].toString();
        }

        // Nothing is removed in batch mode.
        sitesRemoved = new boolean[sequenceData[0].length()];
        taxaRemoved = new boolean[sequenceNames.length];

        taxaNum = sequenceNames.length;
        seqLength = sequenceData[0].length();

        minProb = computeMinProb(root);

        switch (type) {
        case 1:
            outputNexus(file, 1);
            break;
        case 2:
            outputFasta(file);
            break;
        case 3:
            outputPhylip(file);
            break;
        case 4:
            outputMsf(file);
            break;
        default:
            break;
        }

        if (ProAlign.writeRoot) {
            writeRoot(file);
        }
    }

    /**
     * Takes the alignment shown in {@code rw}, asks the user for a file name
     * and writes the alignment in the requested format.
     *
     * @param type one of "Nexus", "Fasta", "Phylip", "MSF"
     * @param rw   result window that supplies rows, names and removal flags
     */
    public SaveData(String type, ResultWindow rw) {

        ProAlign.log("SaveData");

        this.root = rw.root;
        resultString = "vitEnd: "+rw.root.sumViterbiEnd()+", fwdEnd: "+rw.root.sumForwardEnd();

        sitesRemoved = rw.seqcont.getRemovedSites();
        taxaRemoved = rw.seqname.getRemovedTaxa();

        sequenceData = rw.seqcont.textArray;
        sequenceNames = rw.seqname.textArray;

        taxaNum = sequenceNames.length;
        seqLength = sequenceData[0].length();

        minProb = computeMinProb(rw.root);

        OpenFileChooser opf = new OpenFileChooser(rw, "Save "+type, false);
        String filepath = opf.openFile();
        // An empty (or null) path means no file was chosen.
        if (filepath == null || filepath.equals("")) {
            return;
        }

        // Store the chosen folder in ProAlign.folderPath and as the first
        // user-settings entry.
        UserSettings user = new UserSettings();
        String[] userdata = user.readSettings();
        String folder = new File(filepath).getParent();
        ProAlign.folderPath = folder;
        userdata[0] = folder;
        user.writeSettings(userdata);

        if (type.equals("Nexus")) {
            Object[] options = {"interleaved",
                                "sequential"};
            int n = JOptionPane.showOptionDialog(rw, "Do you want your alignment "+
                                                 "interleaved or sequential?", "Nexus format",
                                                 JOptionPane.YES_NO_OPTION,
                                                 JOptionPane.PLAIN_MESSAGE,
                                                 null, options, options[0]);
            if (n == JOptionPane.YES_OPTION) {
                outputNexus(filepath, 1);
            } else if (n == JOptionPane.NO_OPTION) {
                outputNexus(filepath, 2);
            }
            // Any other answer (e.g. dialog closed) writes no alignment file.
        } else if (type.equals("Fasta")) {
            outputFasta(filepath);
        } else if (type.equals("Phylip")) {
            outputPhylip(filepath);
        } else if (type.equals("MSF")) {
            outputMsf(filepath);
        }

        if (ProAlign.writeRoot) {
            writeRoot(filepath);
        }
    }

    /**
     * Writes the full alignment in Nexus format.
     *
     * All taxa and columns are written to the matrix. Removed taxa and
     * columns are hidden with PAUP "Delete"/"Exclude" commands instead, and
     * minProb is written as an assumptions-block weight vector.
     *
     * @param file output path
     * @param type 1 = interleaved matrix, anything else = sequential
     */
    public void outputNexus(String file, int type) {
        try {
            OutFile out = new OutFile(file);

            // Intro and data block.
            out.println("#NEXUS");
            String date = Calendar.getInstance().getTime().toString();
            out.println("[ ProAlign - "+date+" ]\n");
            out.println("Begin data;");
            out.println("     Dimensions ntax=" + taxaNum +
                        " nchar=" + seqLength + ";");
            if (ProAlign.isDna) {
                out.print("     Format datatype=nucleotide");
            } else {
                out.print("     Format datatype=protein");
            }
            if (type == 1) {
                out.print(" interleave");
            }
            out.println(" missing=-;");

            out.println("     Matrix");
            if (type == 1) {
                writeNexusInterleaved(out);
            } else {
                for (int i = 0; i < taxaNum; i++) {
                    out.print(paddedName(sequenceNames[i], 21) + "  ");
                    out.println(sequenceData[i]);
                }
            }
            out.println("     ;\nEnd;\n");

            // PAUP block with the removed taxa and columns.
            out.println("Begin paup;");
            out.println(nexusDeleteCommand());
            out.println(nexusExcludeCommand());
            out.println("End;\n");

            // Assumptions block: one weight per column. Interleaved output
            // wraps the vector after every 15 values.
            out.println("Begin assumptions;");
            StringBuilder weights = new StringBuilder("     wtset ProAlignWeights (VECTOR) = \n       ");
            for (int i = 0; i < minProb.length; i++) {
                weights.append(formatProb(minProb[i])).append(' ');
                if (type == 1 && (i + 1) % 15 == 0) {
                    weights.append("\n       ");
                }
            }
            weights.append(';');
            out.println(weights.toString());
            out.println("End;\n");

            out.close();
        } catch (IOException e) {
            reportWriteError(file, e);
        }
    }

    /**
     * Writes the retained taxa and columns in Fasta format, 50 residues per
     * line, each sequence terminated by "*". It then writes any probability
     * side files enabled in ProAlign.
     *
     * @param file output path
     */
    public void outputFasta(String file) {
        int[] kept = keptSites();
        try {
            OutFile out = new OutFile(file);
            String tag = ProAlign.isDna ? ">DL;" : ">P1;";

            for (int i = 0; i < taxaNum; i++) {
                if (taxaRemoved[i]) {
                    continue;
                }
                // The "\n" leaves an empty line after the header (description line).
                out.println(tag + sequenceNames[i].trim()+"\n");
                StringBuilder line = new StringBuilder();
                for (int c = 0; c < kept.length; c++) {
                    line.append(sequenceData[i].charAt(kept[c]));
                    if (line.length() == FASTA_LINE) {
                        out.println(line.toString());
                        line.setLength(0);
                    }
                }
                out.println(line.append('*').toString());
            }

            out.close();
        } catch (IOException e) {
            reportWriteError(file, e);
        }

        writeProbabilityFiles(file);
    }

    /**
     * Writes the retained taxa and columns in interleaved Phylip format.
     *
     * Names are cut or padded to 10 characters. The first block carries the
     * names and later blocks are indented with blanks. Any enabled
     * probability side files are written afterwards.
     *
     * @param file output path
     */
    public void outputPhylip(String file) {
        int[] kept = keptSites();
        numTaxa = countKeptTaxa();
        numSite = kept.length;

        try {
            OutFile out = new OutFile(file);
            out.println(" "+numTaxa+" "+numSite);
            writeInterleavedBlocks(out, kept, 10, false);
            out.close();
        } catch (IOException e) {
            reportWriteError(file, e);
        }

        writeProbabilityFiles(file);
    }

    /**
     * Writes the retained taxa and columns in a minimal MSF layout.
     *
     * Output consists of a PileUp header, one "Name:" line per taxon, "//",
     * then interleaved blocks with the 21-character name repeated on every
     * row. Any enabled probability side files are written afterwards.
     * NOTE(review): no Len/Check/Weight fields are written; strict MSF
     * readers may reject the file.
     *
     * @param file output path
     */
    public void outputMsf(String file) {
        int[] kept = keptSites();
        numSite = kept.length;

        try {
            OutFile out = new OutFile(file);

            out.println("PileUp\n");
            out.print("   MSF: "+numSite);
            if (ProAlign.isDna) {
                out.println("   Type: N\n");
            } else {
                out.println("   Type: P\n");
            }
            for (int i = 0; i < taxaNum; i++) {
                if (!taxaRemoved[i]) {
                    out.println(" Name: " + paddedName(sequenceNames[i].trim(), 21) + "   ");
                }
            }
            out.println("\n//\n");

            writeInterleavedBlocks(out, kept, 21, true);
            out.close();
        } catch (IOException e) {
            reportWriteError(file, e);
        }

        writeProbabilityFiles(file);
    }

    /**
     * Writes the per-column minimum posterior probabilities (minProb) as
     * plain text, 20 values per line.
     *
     * The header reports the average of minProb ("averMin") and the average
     * column-mean probability ("averAll"). Both averages are taken over the
     * same columns as the listed values.
     *
     * @param file output path
     * @param all  true to include every column, false to skip removed columns
     */
    public void outputWeights(String file, boolean all) {
        double min = averageMinProb(all);
        double mean = averageMeanProb(all);

        try {
            OutFile out = new OutFile(file);
            out.print("# ProAlign: minimum posterior probability.\n");
            out.print("# "+resultString+", averMin: "+min+", averAll: "+mean+".\n");
            int written = 0;
            for (int i = 0; i < minProb.length; i++) {
                if (all || !sitesRemoved[i]) {
                    written = printProb(out, minProb[i], written);
                }
            }
            out.print("\n");
            out.close();
        } catch (IOException e) {
            reportWriteError(file, e);
        }
    }

    /**
     * Writes, for each column, the mean over internal nodes of the posterior
     * probability (log values of -Infinity are ignored). Values are plain
     * text, 20 per line.
     *
     * The header averages are taken over the same columns as the listed
     * values.
     *
     * @param file output path
     * @param all  true to include every column, false to skip removed columns
     */
    public void outputMean(String file, boolean all) {
        double min = averageMinProb(all);
        double mean = averageMeanProb(all);

        try {
            OutFile out = new OutFile(file);
            out.print("# ProAlign: mean posterior probability.\n");
            out.print("# "+resultString+", averMin: "+min+", averAll: "+mean+".\n");
            int written = 0;
            for (int i = 0; i < seqLength; i++) {
                if (all || !sitesRemoved[i]) {
                    written = printProb(out, columnMeanProb(i), written);
                }
            }
            out.print("\n");
            out.close();
        } catch (IOException e) {
            reportWriteError(file, e);
        }
    }

    /**
     * Writes the posterior probability of every internal node at every
     * column. Each node gets its own "# name." section, with values as plain
     * text, 20 per line.
     *
     * @param file output path
     * @param all  true to include every column, false to skip removed columns
     */
    public void outputAll(String file, boolean all) {
        try {
            OutFile out = new OutFile(file);
            out.print("# ProAlign: all posterior probability.\n");
            out.print("# "+resultString+".\n");
            String[] nodeNames = root.getInternalNames();
            for (int j = 0; j < root.getNumChild()-1; j++) {
                out.print("# "+nodeNames[j]+".\n");
                int written = 0;
                for (int i = 0; i < seqLength; i++) {
                    if (all || !sitesRemoved[i]) {
                        double[] column = root.getInternalPostProbAt(i);
                        written = printProb(out, Math.exp(column[j]), written);
                    }
                }
                out.print("\n");
            }
            out.print("\n");
            out.close();
        } catch (IOException e) {
            reportWriteError(file, e);
        }
    }

    /**
     * Writes root.charProb to {@code file + ".root"}. Each line has a 1-based
     * row number followed by that row's values.
     *
     * @param file base output path
     */
    public void writeRoot(String file) {
        try {
            OutFile out = new OutFile(file+".root");

            for (int i = 0; i < root.charProb.length; i++) {
                out.print(i+1);
                for (int j = 0; j < root.charProb[i].length; j++) {
                    out.print(" "+root.charProb[i][j]);
                }
                out.println();
            }

            out.close();
        } catch (IOException e) {
            reportWriteError(file+".root", e);
        }
    }

    /**
     * Rounds {@code val} half-up to at most {@code prec} decimal places and
     * returns it in plain (non-scientific) notation.
     *
     * Values that already have {@code prec} or fewer decimals are returned
     * unpadded: 0.5 gives "0.5", 0.12345 gives "0.123", 0.9996 gives "1.000"
     * and 1.0E-4 gives "0.000". NaN and infinities are returned as
     * Double.toString renders them.
     *
     * @param val  value to round
     * @param prec maximum number of decimal places
     * @return the rounded value as a string
     */
    public String roundDoubleToString(double val, int prec) {
        if (Double.isNaN(val) || Double.isInfinite(val)) {
            return String.valueOf(val);
        }
        // Start from Double.toString's digits, as the former string-based
        // version did, but let BigDecimal handle exponent forms like "1.0E-4",
        // decimal carries and the sign.
        BigDecimal exact = new BigDecimal(Double.toString(val));
        if (exact.scale() > prec) {
            exact = exact.setScale(prec, RoundingMode.HALF_UP);
        }
        return exact.toPlainString();
    }

    // ---------------------------------------------------------------------
    // Helpers
    // ---------------------------------------------------------------------

    // Returns exp(node.getMinimumInternalPostProbAt(i)) for every path column.
    private static double[] computeMinProb(AlignmentNode node) {
        double[] probs = new double[node.cellPath.length];
        for (int i = 0; i < probs.length; i++) {
            probs[i] = Math.exp(node.getMinimumInternalPostProbAt(i));
        }
        return probs;
    }

    // Left-justifies name in a field of exactly width characters (truncating if longer).
    private String paddedName(String name, int width) {
        return (name + spaceStr).substring(0, width);
    }

    // Indices of the columns in [0, seqLength) that are not removed, in order.
    private int[] keptSites() {
        int n = 0;
        for (int j = 0; j < seqLength; j++) {
            if (!sitesRemoved[j]) {
                n++;
            }
        }
        int[] kept = new int[n];
        int k = 0;
        for (int j = 0; j < seqLength; j++) {
            if (!sitesRemoved[j]) {
                kept[k++] = j;
            }
        }
        return kept;
    }

    // Number of taxa that are not removed.
    private int countKeptTaxa() {
        int n = 0;
        for (int i = 0; i < taxaNum; i++) {
            if (!taxaRemoved[i]) {
                n++;
            }
        }
        return n;
    }

    // Writes the Nexus interleaved matrix: blocks of NEXUS_BLOCK columns, a
    // space after every 20 characters and a blank line after each block.
    // At least one block is written, even for an empty alignment.
    private void writeNexusInterleaved(OutFile out) throws IOException {
        int start = 0;
        do {
            int end = Math.min(start + NEXUS_BLOCK, seqLength);
            for (int i = 0; i < taxaNum; i++) {
                StringBuilder row = new StringBuilder(paddedName(sequenceNames[i], 21)).append("  ");
                for (int j = start; j < end; j++) {
                    row.append(sequenceData[i].charAt(j));
                    if ((j - start + 1) % 20 == 0) {
                        row.append(' ');
                    }
                }
                out.println(row.toString());
            }
            out.println("");
            start += NEXUS_BLOCK;
        } while (start < seqLength);
    }

    // PAUP "Delete" command listing the (trimmed) names of removed taxa.
    private String nexusDeleteCommand() {
        StringBuilder cmd = new StringBuilder("     Delete");
        for (int i = 0; i < taxaNum; i++) {
            if (taxaRemoved[i]) {
                cmd.append(' ').append(paddedName(sequenceNames[i], 21).trim());
            }
        }
        return cmd.append(';').toString();
    }

    // PAUP "Exclude" command listing removed columns as 1-based numbers,
    // collapsing consecutive runs into "a-b" ranges.
    private String nexusExcludeCommand() {
        StringBuilder cmd = new StringBuilder("     Exclude");
        int n = 0;
        while (n < sitesRemoved.length) {
            if (!sitesRemoved[n]) {
                n++;
                continue;
            }
            int start = n;
            while (n + 1 < sitesRemoved.length && sitesRemoved[n + 1]) {
                n++;
            }
            cmd.append(' ').append(start + 1);
            if (n != start) {
                cmd.append('-').append(n + 1);
            }
            n++;
        }
        return cmd.append(';').toString();
    }

    // Writes retained taxa/columns in blocks of BLOCK_WIDTH columns separated
    // by blank lines (at least one block). Each row is labelled with the
    // trimmed name padded to nameWidth plus three spaces. If repeatNames is
    // false, blocks after the first get a blank label of the same width.
    private void writeInterleavedBlocks(OutFile out, int[] kept, int nameWidth,
                                        boolean repeatNames) throws IOException {
        int blocks = Math.max(1, (kept.length + BLOCK_WIDTH - 1) / BLOCK_WIDTH);
        String blankLabel = paddedName("", nameWidth) + "   ";
        for (int b = 0; b < blocks; b++) {
            if (b > 0) {
                out.println("");
            }
            for (int i = 0; i < taxaNum; i++) {
                if (taxaRemoved[i]) {
                    continue;
                }
                String label = (b == 0 || repeatNames)
                    ? paddedName(sequenceNames[i].trim(), nameWidth) + "   "
                    : blankLabel;
                out.println(label + formatBlock(sequenceData[i], kept, b * BLOCK_WIDTH));
            }
        }
    }

    // Characters of seq at kept[start .. start+BLOCK_WIDTH), with a space
    // after every 10th character except at the end of a full block.
    private static String formatBlock(String seq, int[] kept, int start) {
        StringBuilder sb = new StringBuilder();
        int end = Math.min(start + BLOCK_WIDTH, kept.length);
        for (int c = start; c < end; c++) {
            sb.append(seq.charAt(kept[c]));
            int inBlock = c - start + 1;
            if (inBlock % 10 == 0 && inBlock < BLOCK_WIDTH) {
                sb.append(' ');
            }
        }
        return sb.toString();
    }

    // Writes the .min/.mean/.all side files enabled in ProAlign (retained columns only).
    private void writeProbabilityFiles(String file) {
        if (ProAlign.writeMin) {
            outputWeights(file+".min", false);
        }
        if (ProAlign.writeMean) {
            outputMean(file+".mean", false);
        }
        if (ProAlign.writeAll) {
            outputAll(file+".all", false);
        }
    }

    // Mean over internal nodes of exp(log posterior) at one column, ignoring
    // -Infinity entries (NaN if every entry is -Infinity).
    private double columnMeanProb(int site) {
        double[] column = root.getInternalPostProbAt(site);
        double sum = 0d;
        int cn = 0;
        for (int j = 0; j < column.length; j++) {
            if (column[j] != Double.NEGATIVE_INFINITY) {
                sum += Math.exp(column[j]);
                cn++;
            }
        }
        return sum / cn;
    }

    // Average of minProb over all columns (all == true) or over retained columns.
    private double averageMinProb(boolean all) {
        double total = 0d;
        int n = 0;
        for (int i = 0; i < minProb.length; i++) {
            if (all || !sitesRemoved[i]) {
                total += minProb[i];
                n++;
            }
        }
        return total / n;
    }

    // Average of columnMeanProb over all columns (all == true) or over retained columns.
    private double averageMeanProb(boolean all) {
        double total = 0d;
        int n = 0;
        for (int i = 0; i < seqLength; i++) {
            if (all || !sitesRemoved[i]) {
                total += columnMeanProb(i);
                n++;
            }
        }
        return total / n;
    }

    // Probability rounded to 3 decimals, padded or cut to exactly 5 characters.
    private String formatProb(double p) {
        return (roundDoubleToString(p, 3) + "     ").substring(0, 5);
    }

    // Prints one formatted probability plus a space, breaking the line after
    // every PROBS_PER_LINE values; returns the updated count written.
    private int printProb(OutFile out, double p, int written) throws IOException {
        out.print(formatProb(p) + " ");
        written++;
        if (written % PROBS_PER_LINE == 0) {
            out.print("\n");
        }
        return written;
    }

    // Reports a failed write instead of silently dropping it.
    private void reportWriteError(String file, IOException e) {
        ProAlign.log("SaveData: could not write " + file + ": " + e);
        e.printStackTrace();
    }
}
















