/**
 * Runs a two-state HMM (with emissions of the order given on the command line)
 * in a change-point detection framework to optimize the predicted boundaries,
 * using the BioJava libraries.
 * 
 * @author George Vernikos <gsv@sanger.ac.uk>
 * 
 * For more information on the BioJava project visit: http://www.biojava.org/
*/

/*
LICENSE

This program is free software; you can redistribute it and/or
modify it under the terms of the GNU General Public License
as published by the Free Software Foundation; either version 2
of the License, or (at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program; if not, write to the Free Software
Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.

*/

import java.io.*;
import org.biojava.bio.symbol.*;
import org.biojava.bio.seq.*;
import org.biojava.bio.seq.io.*;
import org.biojava.bio.dp.*;
import org.biojava.bio.*;
import org.biojava.bio.seq.db.*;
import org.biojava.bio.seq.impl.*;
import org.biojava.bio.dist.*;
import org.biojava.utils.*;
import java.util.*;

/**
 * Command-line driver for a two-state ("typical" / "atypical") hidden Markov model.
 * It fits the model to a single DNA sequence with Baum-Welch training and then
 * decodes the sequence with Viterbi to locate a change point.
 *
 * <p>Arguments: {@code sequence.fa order.int flatD.bin trainableTrans.bin duration.int}.
 * The program prints {@code "<transition_point> <viterbi score>"} to standard output.
 * The transition point is the 1-based index of the last Viterbi state-path position
 * whose state name does not start with {@code 'a'}, i.e. is not the "atypical" state.
 * It is 0 when there is no such position.
 *
 * <p>Configuration is passed between {@link #main} and {@link #createModel} through
 * static fields. The class is therefore not safe to use from several threads at once.
 */
class ChangepointRight{

    /** The input DNA sequence; {@link #createModel()} reads it to seed the emission weights. */
    public static SymbolList seqL;
    /** Number of DNA alphabets in the cross-product alphabet the model emits over. */
    public static int order;
    /** 0 = randomize the initial emission distributions; 1 = keep them flat. */
    public static int flatOrRandom;
    /** 0 = start-state transitions are untrainable; 1 = they are trainable. */
    public static int trainOrUntrain;
    /** Scratch distribution used while the transition weights are set. */
    public static Distribution dist;
    /** The typical-to-atypical transition weight is 1/duration. */
    public static int duration;
    /** Not referenced within this class. */
    public static ModelTrainer mt;
    /** 1-based position of the last non-atypical state on the Viterbi path (0 if none). */
    public static int transition_point=0;
    /** Number of Viterbi state-path positions scanned. */
    public static int count=0;

    /** The DNA alphabet that input sequences are parsed with. */
    static FiniteAlphabet DnaAlphabet = DNATools.getDNA();

    /**
     * Reads the first sequence of a FASTA file, trains the model built by
     * {@link #createModel()} on it and prints the change point found by Viterbi.
     *
     * @param args {@code sequence.fa order.int flatD.bin trainableTrans.bin duration.int}
     * @throws Exception if the arguments are invalid, the file cannot be read or
     *         holds no sequence, or training/decoding fails. The exception is
     *         propagated, so the JVM exits with a non-zero status.
     */
    public static void main (String args[]) throws Exception{
        if (args.length != 5) {
            throw new IllegalArgumentException(
                "Use: sequence.fa order.int flatD.bin trainableTrans.bin duration.int");
        }

        File seqFile = new File(args[0]);
        order = Integer.parseInt(args[1]);
        flatOrRandom = Integer.parseInt(args[2]);
        trainOrUntrain = Integer.parseInt(args[3]);
        duration = Integer.parseInt(args[4]);

        if (flatOrRandom != 0 && flatOrRandom != 1) {
            throw new IllegalArgumentException(
                "Use flatD.bin: only binary i.e. 0 or 1: . . 1/0 . .");
        }
        if (trainOrUntrain != 0 && trainOrUntrain != 1) {
            throw new IllegalArgumentException(
                "Use trainableTrans.bin: only binary i.e. 0 or 1: . . . 1/0 .");
        }
        if (order < 1) {
            throw new IllegalArgumentException("Use order.int: must be >= 1, got " + order);
        }
        // duration is the divisor of the typical->atypical weight, so 0 or less
        // would yield infinite or negative "probabilities".
        if (duration < 1) {
            throw new IllegalArgumentException("Use duration.int: must be >= 1, got " + duration);
        }

        Sequence input = readFirstSequence(seqFile);
        seqL = input;

        MarkovModel island = createModel();
        DP dp = DPFactory.DEFAULT.createDP(island);

        // The model emits symbols of the order-N view, so train and decode on that view.
        SymbolList highOrder = SymbolListViews.orderNSymbolList(input, order);
        train(dp, input.getName() + "-o" + order, highOrder);

        SymbolList[] rl = { highOrder };
        StatePath statePath = dp.viterbi(rl, ScoreType.PROBABILITY);

        // Reset so that repeated invocations in the same JVM do not accumulate.
        count = 0;
        transition_point = 0;
        for (int j = 0; j < statePath.length(); j++) {
            // StatePath is 1-indexed.
            char state = statePath.symbolAt(StatePath.STATES, j + 1).getName().charAt(0);
            count++;
            if (state != 'a') {
                transition_point = count;
            }
        }

        System.out.print(transition_point + " " + statePath.getScore());
    }

    /**
     * Reads the first record of a FASTA file as a DNA sequence and always closes the file.
     *
     * @param seqFile the FASTA file
     * @return the first sequence in the file
     * @throws IllegalArgumentException if the file contains no sequence
     * @throws Exception if the file cannot be opened or parsed
     */
    private static Sequence readFirstSequence(File seqFile) throws Exception {
        SymbolTokenization rParser = DnaAlphabet.getTokenization("token");
        SequenceBuilderFactory sbFact =
            new FastaDescriptionLineParser.Factory(SimpleSequenceBuilder.FACTORY);

        InputStream in = new FileInputStream(seqFile);
        try {
            SequenceIterator seqI = new StreamReader(in, new FastaFormat(), rParser, sbFact);
            if (!seqI.hasNext()) {
                throw new IllegalArgumentException("No sequence found in " + seqFile);
            }
            return seqI.nextSequence();
        } finally {
            in.close();
        }
    }

    /**
     * Runs Baum-Welch training of {@code dp}'s model on a single sequence. Training
     * stops once the score changes by less than 0.001 between cycles.
     *
     * @param dp the dynamic-programming wrapper around the model to train
     * @param name name given to the training sequence
     * @param symbols the symbols to train on (the order-N view of the input)
     * @throws Exception if building the training set or training fails
     */
    private static void train(DP dp, String name, SymbolList symbols) throws Exception {
        SequenceDB seqs = new HashSequenceDB();
        seqs.addSequence(new SimpleSequence(symbols, null, name, Annotation.EMPTY_ANNOTATION));

        TrainingAlgorithm ta = new BaumWelchTrainer(dp);
        ta.train(seqs, 0.01, new StoppingCriteria() {
            public boolean isTrainingComplete(TrainingAlgorithm trainer) {
                return Math.abs(trainer.getLastScore() - trainer.getCurrentScore()) < 0.001;
            }
        });
    }

    /**
     * Builds the two-state model. It has a "typical" and an "atypical" emission
     * state over the cross product of {@link #order} DNA alphabets.
     *
     * <p>Reads {@link #seqL}, {@link #order}, {@link #flatOrRandom},
     * {@link #trainOrUntrain} and {@link #duration}, and overwrites {@link #dist}.
     *
     * @return the configured model
     * @throws AssertionFailure if any BioJava call used to build the model fails
     */
    public static MarkovModel createModel() {
        List l = Collections.nCopies(order, DNATools.getDNA());
        Alphabet alpha = AlphabetManager.getCrossProductAlphabet(l);

        int[] advance = { 1 };
        Distribution typicalD = createEmissionDistribution(alpha);
        Distribution atypicalD = createEmissionDistribution(alpha);

        EmissionState typicalS = new SimpleEmissionState("typical", Annotation.EMPTY_ANNOTATION, advance, typicalD);
        EmissionState atypicalS = new SimpleEmissionState("atypical", Annotation.EMPTY_ANNOTATION, advance, atypicalD);

        SimpleMarkovModel island = new SimpleMarkovModel(1, alpha, "Island");

        try {
            island.addState(typicalS);
            island.addState(atypicalS);
        } catch (Exception e) {
            throw new AssertionFailure("Can't add states to model", e);
        }

        // Fully connect start/end, typical and atypical.
        try {
            island.createTransition(island.magicalState(), typicalS);
            island.createTransition(island.magicalState(), atypicalS);
            island.createTransition(typicalS, island.magicalState());
            island.createTransition(atypicalS, island.magicalState());
            island.createTransition(typicalS, atypicalS);
            island.createTransition(atypicalS, typicalS);
            island.createTransition(typicalS, typicalS);
            island.createTransition(atypicalS, atypicalS);
        } catch (Exception e) {
            throw new AssertionFailure("Can't create transitions", e);
        }

        // Emission weights: each distinct symbol seen in the input gets 0.25 in both
        // states. Symbols that never occur in the input are not set here.
        try {
            SymbolList highOrderSeq = SymbolListViews.orderNSymbolList(seqL, order);
            Set seen = new HashSet();
            for (Iterator i = highOrderSeq.iterator(); i.hasNext(); ) {
                Symbol sym = (Symbol) i.next();
                if (seen.add(sym)) {
                    atypicalD.setWeight(sym, 0.25);
                    typicalD.setWeight(sym, 0.25);
                }
            }

            if (flatOrRandom == 0) {
                DistributionTools.randomizeDistribution(atypicalD);
                DistributionTools.randomizeDistribution(typicalD);
            }
        } catch (Exception e) {
            throw new AssertionFailure("Can't set emission probabilities", e);
        }

        try {
            // Start state: always begin in "typical". This row is trainable only
            // when trainOrUntrain == 1.
            if (trainOrUntrain == 0) {
                dist = new UntrainableDistribution(island.transitionsFrom(island.magicalState()));
            } else {
                dist = island.getWeights(island.magicalState());
            }
            dist.setWeight(typicalS, 1.0);
            dist.setWeight(atypicalS, 0.0);
            island.setWeights(island.magicalState(), dist);

            // "typical" (always trainable): switch to "atypical" with probability
            // 1/duration, otherwise stay. Never end from "typical".
            // Computed in double so that the row sums to 1 at double precision.
            double toAtypical = 1.0 / duration;
            dist = island.getWeights(typicalS);
            dist.setWeight(atypicalS, toAtypical);
            dist.setWeight(typicalS, 1.0 - toAtypical);
            dist.setWeight(island.magicalState(), 0.0);
            island.setWeights(typicalS, dist);

            // "atypical" (always untrainable): once entered it (almost) never
            // returns to "typical".
            dist = new UntrainableDistribution(island.transitionsFrom(atypicalS));
            dist.setWeight(typicalS, 0.0000000000000000000000000000001);
            dist.setWeight(atypicalS, 0.9999);
            // The original author notes that 0.0001 here produced NaNs.
            // NOTE(review): in double precision this literal rounds to the same
            // value as 0.0001 — confirm what actually fixed the NaNs.
            dist.setWeight(island.magicalState(), 0.0000999999999999999999999999999);
            island.setWeights(atypicalS, dist);
        } catch (Exception e) {
            throw new AssertionFailure("Can't set transition probabilities", e);
        }

        return island;
    }

    /**
     * Creates an empty emission distribution over {@code alpha}. An order-N
     * distribution is used when {@link #order} is greater than 1.
     *
     * @throws AssertionFailure if the distribution cannot be created
     */
    private static Distribution createEmissionDistribution(Alphabet alpha) {
        try {
            if (order > 1) {
                return OrderNDistributionFactory.DEFAULT.createDistribution(alpha);
            }
            return DistributionFactory.DEFAULT.createDistribution(alpha);
        } catch (Exception e) {
            throw new AssertionFailure("Can't create distributions", e);
        }
    }

}

