/*
 *                    BioJava development code
 *
 * This code may be freely distributed and modified under the
 * terms of the GNU Lesser General Public Licence.  This should
 * be distributed with the code.  If you do not have a copy,
 * see:
 *
 *      http://www.gnu.org/copyleft/lesser.html
 *
 * Copyright for this code is held jointly by the individual
 * authors.  These should be listed in @author doc comments.
 *
 * For more information on the BioJava project and its aims,
 * or to join the biojava-l mailing list, visit the home page
 * at:
 *
 *      http://www.biojava.org/
 *
 */

package org.biojava.bio.program.hmmer;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.util.Iterator;
import java.util.StringTokenizer;

import org.biojava.bio.dist.Distribution;
import org.biojava.bio.dist.DistributionFactory;
import org.biojava.bio.dp.DotState;
import org.biojava.bio.dp.EmissionState;
import org.biojava.bio.dp.State;
import org.biojava.bio.seq.ProteinTools;
import org.biojava.bio.seq.io.SymbolTokenization;
import org.biojava.bio.symbol.Alphabet;
import org.biojava.bio.symbol.FiniteAlphabet;
import org.biojava.bio.symbol.Symbol;



/** A class for parsing HMMER hidden Markov models from HMM_ls files generated
 * by HMMER training.  Note that this class is still experimental.
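 * <p>
 * A minimal usage sketch (the file name below is only an illustrative placeholder
 * for a HMMER model file such as an HMM_ls file):
 * <pre>
 *   File modelFile = new File("globins.hmm_ls");                        // hypothetical path
 *   HmmerProfileHMM core = HmmerProfileParser.parse(modelFile);         // core model only
 *   FullHmmerProfileHMM full = HmmerProfileParser.parseFull(modelFile); // core model plus J, C, N loop states
 * </pre>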
 * @author Lachlan Coin
 */
public class HmmerProfileParser{

    /** The protein alphabet over which parsed models are defined. */
    protected Alphabet alph = ProteinTools.getAlphabet();

    /** Returns a profile HMM representing the core HMMER model.
     * @param inputfile the file containing the profile HMM data, as output by HMMER - e.g. HMM_ls
     */
    public static HmmerProfileHMM parse(File inputfile){
	HmmerProfileParser hmmP = new HmmerProfileParser(inputfile.toString());
	hmmP.parseModel(inputfile);
	hmmP.setProfileHMM();
	return hmmP.getModel();
    }

    /** Returns the full Markov model, including the core model plus the J, C and N loop states.
     *  @param inputfile the file which contains the Profile HMM data, as output by HMMER - e.g. HMM_ls
     */    
    public static FullHmmerProfileHMM parseFull(File inputfile){
	HmmerProfileParser hmmP = new HmmerProfileParser(inputfile.toString());
	hmmP.parseModel(inputfile);
	hmmP.setProfileHMM();
	hmmP.initialiseFullProfileHMM();
	hmmP.setFullProfileHMM();
	return hmmP.getFullModel();
    }

   
    /** The name given to the parsed model; the static factory methods pass in the input file name. */
    protected String domain1;


    protected HmmerProfileParser(String domain){
	this.domain1 = domain;
    }
    
    protected HmmerProfileHMM initialiseProfileHMM(int len){
        try{
            DistributionFactory matchFactory = DistributionFactory.DEFAULT;
            DistributionFactory insertFactory = DistributionFactory.DEFAULT;
            return new HmmerProfileHMM(alph, len, matchFactory, insertFactory, domain1);
        }
        catch(Throwable t){
            t.printStackTrace();
            return null;
        }
    }

    /** The intermediate parsed data and the HMMs built from it. */
    protected HmmerModel hmm;

    /** Tolerance used when checking that a distribution sums to 1. */
    private static final double sumCheckThreshold = 0.001;

    public HmmerProfileHMM getModel(){
	return hmm.hmm;
    }

    
    FullHmmerProfileHMM getFullModel(){
	return hmm.hmm_full;
    }
    
    void initialiseFullProfileHMM(){
	hmm.initialiseFullProfileHMM();
    }
   

    public void setProfileHMM(){
        hmm.setProfileHMM();
    }

    void setFullProfileHMM(){
        hmm.setFullProfileHMM();
    }
    

     public void parseModel(File inputFile){
	System.out.println("Parsing model "+inputFile);
	try{
	    BufferedReader in = new BufferedReader(new FileReader(inputFile));
	    boolean inModel=false;
	    int seq_pos=1;
	    int rel_pos=0;
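            // Within the model section each node occupies three consecutive lines
            // (match emissions, insert emissions, transitions); rel_pos cycles
            // through 0..2 over these and seq_pos counts the current node.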
            String s;
	    while((s = in.readLine())!= null){
		if(s.startsWith("//")) break;
		if(!inModel){		
		    if(s.startsWith("LENG")) {
			int[] a = parseString(s.substring(5),1);
			hmm = new HmmerModel(a[0]);
		    }
		    else if(s.startsWith("NULE"))
			hmm.setNullEmissions(s.substring(5));
		    else if(s.startsWith("NULT"))
			hmm.setNullTransitions(s.substring(5));
		    else if(s.startsWith("XT"))
			hmm.setSpecialTransitions(s.substring(5));
		    else if(s.startsWith("HMM ")){
			inModel=true;
			hmm.setAlphList(s.substring(7));
			in.readLine();
			hmm.setBeginTransition(in.readLine());
		    }
		}
		else{
		    if(rel_pos==0){
			hmm.setEmissions(s.substring(7), seq_pos);
		    }
		    else if(rel_pos==1 && seq_pos==1){
			hmm.setInsertEmissions(s.substring(7));
		    }
		    else if(rel_pos==2){
			hmm.setTransitions(s.substring(7), seq_pos);
		    }
		    rel_pos++;
		    if(rel_pos==3){
			rel_pos=0;
			seq_pos++;
		    }
		}		   
	    }
	    in.close();
	}
	catch(Throwable t){
          t.printStackTrace();
	}
    }


    



    static int[] parseString(String s, int len){
	String[] s1 = parseStringA(s,len);
	int[] s2 = new int[len];
	for(int i=0; i<s1.length; i++){
	    if(s1[i].indexOf("*")!= -1) s2[i] = Integer.MIN_VALUE;
	    else s2[i] = Integer.parseInt(s1[i]);
	}
	return s2;
    }

    static String[] parseStringA(String s, int len){
	String[] s2 = new String[len];
	StringTokenizer st = new StringTokenizer(s);
	int i=0;
	while(st.hasMoreTokens() && i<len){
	    s2[i] = st.nextToken(); 
	    i++;
	}
	return s2;
    }
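    /* Illustrative behaviour of the two helpers above: scores are whitespace
     * separated, and the token "*" (meaning "impossible" in HMMER save files)
     * is mapped to Integer.MIN_VALUE.  For example
     *   parseString("  -38   *   120", 3)  returns  { -38, Integer.MIN_VALUE, 120 }
     */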


    /** An intermediate class to store the parsed data. */
    class HmmerModel{
        /** The null model emission scores, one per amino acid symbol. */
        int[] nullEmissions;
        /** The null model transition scores. */
        int[] nullTransitions;

        /** The match emission scores for each model column. */
        int[][] emissions;

        /** The insert state emission scores. */
        int[] insertEmissions;

        /** The transition scores for each node; index 0 holds the begin line. */
        int[][] transitions;

        /** The begin transition scores (currently unused; the begin line is stored in transitions[0]). */
        int[] beginTransition;

        /** The eight special (XT) transition scores for the N, C and J loop states. */
        int[] specialTransitions;

        /** The symbols corresponding to the columns of the emission lines. */
        Symbol[] alphList;

        /** The core profile HMM built from the parsed scores. */
        HmmerProfileHMM hmm;
        /** The full model, wrapping the core HMM with the N, C and J loop states. */
        FullHmmerProfileHMM hmm_full;

        HmmerModel(int length){
            System.out.println("Constructing base model");
            nullEmissions = new int[20];          // one null emission score per amino acid
            nullTransitions = new int[2];
            emissions = new int[length][20];      // match emission scores for each column
            insertEmissions = new int[20];
            specialTransitions = new int[8];      // the eight XT scores
            transitions = new int[length+1][9];   // nine transition scores per node; index 0 holds the begin line
            alphList = new Symbol[21];            // the 20 amino acids plus one extra symbol
            hmm = initialiseProfileHMM(emissions.length);
        }

	void setAlphList(String s){
	    try{
		String[] list = parseStringA(s,20);
		SymbolTokenization tokenizer = alph.getTokenization("token");
		for(int i=0; i<list.length;i++){
		    alphList[i] = tokenizer.parseToken(list[i]);
		}
		alphList[list.length] = tokenizer.parseToken("U");
	    }
	    catch(Throwable t){
		t.printStackTrace();
	    }
	}

	void setEmissions(String s, int pos){
	    emissions[pos-1] = parseString(s,20);
	}

	void setNullEmissions(String s){
	    nullEmissions = parseString(s,20);

	}

	void setInsertEmissions(String s){
	    insertEmissions = parseString(s,20);
	}

	void setNullTransitions(String s){
	    nullTransitions = parseString(s,2);
	}

	void setTransitions(String s, int pos){
	   transitions[pos] =  parseString(s,9);
	}
	
	void setBeginTransition(String s){
	    transitions[0]= parseString(s,3);
	}

	void setSpecialTransitions(String s){
	    specialTransitions = parseString(s,8);
	}



	
	
	void initialiseFullProfileHMM(){
	    try{
		hmm_full = new FullHmmerProfileHMM(hmm);
	    }
	    catch(Throwable t){
		t.printStackTrace();
	    }
	}
       

	private void  validateDistributionSum(Distribution dist) throws Exception{
	    Iterator iter = ((FiniteAlphabet)dist.getAlphabet()).iterator();
	    double sum=0.0;
	    while(iter.hasNext()){
		Symbol to = (Symbol) iter.next();
		sum += dist.getWeight(to);
		//System.out.println(to.getName()+" "+dist.getWeight(to));
	    }
	    //System.out.println("//");
	    validateSum(sum);
	}

	private void  addProbability(Distribution dist, State state, double prob) throws Exception{
	    if(Double.isNaN(dist.getWeight(state)))
		dist.setWeight(state,0);
	    Iterator iter = ((FiniteAlphabet)dist.getAlphabet()).iterator();
	    while(iter.hasNext()){
		Symbol to = (Symbol) iter.next();
		double currentP = dist.getWeight(to);
		dist.setWeight(to,currentP*(1-prob));
	    }
	    dist.setWeight(state, dist.getWeight(state)+prob);
	    validateDistributionSum(dist);
	}
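        /* Worked example of addProbability (weights shown are illustrative only):
         * if dist currently holds { A: 0.4, B: 0.6, state: 0.0 }, then
         * addProbability(dist, state, 0.5) rescales every weight by (1 - 0.5) and
         * adds 0.5 to 'state', giving { A: 0.2, B: 0.3, state: 0.5 }, which still
         * sums to 1.
         */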

	private void checkTransitionSum() throws Exception{
	    for (int i=0; i<=hmm.columns();i++){
		validateDistributionSum(hmm.getWeights(hmm.getMatch(i)));
		if(i>0 && i<hmm.columns()){
		    validateDistributionSum(hmm.getWeights(hmm.getInsert(i)));
		    validateDistributionSum(hmm.getWeights(hmm.getDelete(i)));
		}
	    }
	}

	private void validateSum(double sum) throws Exception{
	    if(Math.abs(sum-1.0)>sumCheckThreshold)
		throw new Exception("Distribution does not sum to 1.  Sums to "+sum);
	}

        /** Allows the HMM to match partial hits by creating transitions from the
         *  begin/end (magical) state into each match state and from each match state
         *  to the end state.  (This method only creates the transitions; see
         *  modifyProbabilities for the reweighting.)
         */
	void addProfileHMMTransitions() throws Exception{
	    for(int i=1; i<=hmm.columns(); i++){
		if(i>1){
		    hmm.createTransition(hmm.magicalState(), hmm.getMatch(i));
		}
		if(i<hmm.columns()){
		    hmm.createTransition(hmm.getMatch(i), hmm.magicalState());
		}
	    }
	}


        /** Modifies the HMM to allow partial hits by moving half of the probability
         *  at each point onto the transitions from the begin state into each match
         *  state and from each match state to the end state.
         */
	void modifyProbabilities() throws Exception{
	    for(int i=1; i<=hmm.columns(); i++){
		if(i>1){
		    addProbability(hmm.getWeights(hmm.magicalState()), hmm.getMatch(i), 0.5);
		}
		if(i<hmm.columns()){
		    addProbability(hmm.getWeights(hmm.getMatch(i)), hmm.magicalState(), 0.5);
		}
	    }
	}
	    

        /** Sets the transitions from the begin state into each match state and from
         *  each match state to the end state, using the last two scores of each
         *  node's transition line.
         */
        void setBeginEnd() throws Exception{
	    Distribution dist = hmm.getWeights(hmm.magicalState());
	    for(int i=1; i<=hmm.columns(); i++){
		EmissionState match = hmm.getMatch(i);
		Distribution match_dist = hmm.getWeights(match);
		dist.setWeight(match, convertToProb(transitions[i][7]));
		match_dist.setWeight(hmm.magicalState(),convertToProb(transitions[i][8]));
	    }
	}
	
        /** Wires up the full model's N, C and J loop states using the special (XT) transition scores. */
        void setFullProfileHMM(){
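            // specialTransitions is assumed to hold the eight XT scores in the
            // HMMER2 order NB, NN, EC, EJ, CT, CC, JB, JJ; they are mapped below
            // onto transitions involving the N, C and J states.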
	    try{
		Distribution dist = hmm_full.getWeights(hmm_full.magicalState());
		dist.setWeight(hmm_full.nState(), 1.0);

		dist = hmm_full.getWeights(hmm_full.nState());
		dist.setWeight(hmm_full.hmm(), convertToProb(specialTransitions[0]));
		dist.setWeight(hmm_full.nState(), convertToProb(specialTransitions[1]));

		dist = hmm_full.getWeights(hmm_full.hmm());
		dist.setWeight(hmm_full.cState(), convertToProb(specialTransitions[2]));
		dist.setWeight(hmm_full.jState(), convertToProb(specialTransitions[3]));

		dist = hmm_full.getWeights(hmm_full.cState());
		dist.setWeight(hmm_full.magicalState(), convertToProb(specialTransitions[4]));
		dist.setWeight(hmm_full.cState(), convertToProb(specialTransitions[5]));

		dist = hmm_full.getWeights(hmm_full.jState());
		dist.setWeight(hmm_full.hmm(), convertToProb(specialTransitions[6]));
		dist.setWeight(hmm_full.jState(), convertToProb(specialTransitions[7]));
	    }
	    catch(Throwable t){
		t.printStackTrace();
	    }
	}


        /** Builds the transition and emission distributions of the core profile HMM from the parsed scores. */
        void setProfileHMM(){
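            // Each transitions[i] row (i >= 1) is assumed to follow the HMMER2 node
            // transition order m->m, m->i, m->d, i->m, i->i, d->m, d->d, b->m, m->e;
            // the first seven are used here, the last two in setBeginEnd().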
	    try{
	    for(int i=0; i<=hmm.columns(); i++){
		EmissionState match = hmm.getMatch(i);
		Distribution dist = hmm.getWeights(match);
		if(i<hmm.columns()){
		    dist.setWeight(hmm.getMatch(i+1), 
				   convertToProb(transitions[i][0]));
		    if(i>=1){
		    	dist.setWeight(hmm.getInsert(i), 
		    		       convertToProb(transitions[i][1]));
		    }
		    dist.setWeight(hmm.getDelete(i+1), 
				   convertToProb(transitions[i][2]));
		}
		else{
		    dist.setWeight(hmm.magicalState(),1.0);
		}
	    }
	    for(int i=1; i<hmm.columns(); i++){
	    	EmissionState insert = hmm.getInsert(i);
	    	Distribution dist = hmm.getWeights(insert);
	    	dist.setWeight(hmm.getMatch(i+1), 
	    		       convertToProb(transitions[i][3]));
	    	dist.setWeight(insert, 
	    		       convertToProb(transitions[i][4]));
	    }
	    for(int i=1; i<hmm.columns(); i++){
		DotState delete = hmm.getDelete(i);
		Distribution dist = hmm.getWeights(delete);
		dist.setWeight(hmm.getMatch(i+1), 
			       convertToProb(transitions[i][5]));
		dist.setWeight(hmm.getDelete(i+1), 
			       convertToProb(transitions[i][6]));
	    }
	    setBeginEnd();
	    checkTransitionSum();
	    // setting emission probabilities

	    Distribution insertEmission = hmm.getInsert(1).getDistribution();
	    Distribution nullModel = DistributionFactory.DEFAULT.createDistribution(alph);
            for (int j=0; j<alphList.length; j++){
                double prob;
                double null_prob;
                if(j<alphList.length-1){
                    null_prob = convertToProb(nullEmissions[j],0.05);
                    prob = convertToProb(insertEmissions[j], null_prob);
                }
                else{
                    prob = 0.0;
                    null_prob = 0.0;
                }
                insertEmission.setWeight(alphList[j],prob);
                nullModel.setWeight(alphList[j],null_prob);
            }
	    insertEmission.setNullModel(nullModel);
	    validateDistributionSum(insertEmission);
	    //System.out.println("NULL MODEL!!!!!");
	    validateDistributionSum(nullModel);

	    for(int i=1; i<=hmm.columns(); i++){
		Distribution matchEmission = hmm.getMatch(i).getDistribution();
		if(i>1 && i<hmm.columns())
		    hmm.getInsert(i).setDistribution(insertEmission);
                for (int j=0; j<alphList.length; j++){
                    double prob;
                    if(j<alphList.length-1){
                        double null_prob = convertToProb(nullEmissions[j],0.05);
                        prob = convertToProb(emissions[i-1][j],null_prob);
                    }
                    else prob = 0.0;
                    matchEmission.setWeight(alphList[j],prob);
                }
		validateDistributionSum(matchEmission);
		matchEmission.setNullModel(nullModel);
		//System.out.println("NULL MODEL!!!!!");
		validateDistributionSum(matchEmission.getNullModel());
	    }
	    //modifyProbabilities();
	    }
	     catch(Throwable t){
		t.printStackTrace();
	    }
	}	    

        /** Converts a HMMER integer score into a probability, assuming scores are
         *  log-odds in units of 1/1000 bit with an implicit null probability of 1.
         */
        private double convertToProb(int score){
            double result = 0.0;
            if(score != Integer.MIN_VALUE){   // MIN_VALUE marks an impossible ("*") entry
                result = Math.pow(2.0, ((double) score)/1000);
            }
            return result;
        }

        /** Converts a HMMER integer score into a probability relative to the given
         *  null probability: nullprob * 2^(score/1000).
         */
        private double convertToProb(int score, double nullprob){
            double result = 0.0;
            if(score != Integer.MIN_VALUE){   // MIN_VALUE marks an impossible ("*") entry
                result = nullprob*Math.pow(2.0, ((double) score)/1000);
            }
            return result;
        }
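        /* Worked examples of the conversion above, assuming the HMMER2 convention of
         * integer log-odds scores in units of 1/1000 bit relative to the null model:
         *   score = 0,     nullprob = 0.05  ->  0.05 * 2^(0/1000)     = 0.05
         *   score = -1000, nullprob = 0.05  ->  0.05 * 2^(-1000/1000) = 0.025
         *   score = "*" (stored as Integer.MIN_VALUE)                 -> probability 0.0
         */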


    }
}



	

