/*
 * TrainingInfoTest.java
 * 
 * Copyright 2006 Michigan State University Board of Trustees
 *
 * Created on September 19, 2003, 10:23 AM
 */

package edu.msu.cme.rdp.classifier.rrnaclassifier;

import edu.msu.cme.rdp.classifier.ClassificationResult;
import edu.msu.cme.rdp.classifier.Classifier;
import edu.msu.cme.rdp.classifier.GenusWordConditionalProb;
import edu.msu.cme.rdp.classifier.HierarchyTree;
import edu.msu.cme.rdp.classifier.RankAssignment;
import edu.msu.cme.rdp.classifier.TrainingInfo;
import edu.msu.cme.rdp.classifier.utils.ClassifierSequence;
import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import junit.framework.Test;
import junit.framework.TestCase;
import junit.framework.TestSuite;

/**
 * JUnit 3 tests for {@link TrainingInfo}.
 *
 * <p>Each test loads one or more small training fixtures from the classpath
 * (under {@code /test/classifier/}), feeds them to a fresh {@code TrainingInfo}
 * and checks values read back through its accessors. Every reader opened here
 * is closed in a {@code finally} block, and a missing fixture fails the test
 * with a message naming the resource instead of a bare NullPointerException.
 *
 * @author wangqion
 */
public class TrainingInfoTest extends TestCase {

    /** Classpath location of the log word prior fixture. */
    private static final String LOG_WORD_PRIOR_RESOURCE = "/test/classifier/testLogWordPrior.txt";

    /** Classpath location of the genus/word conditional probability fixture. */
    private static final String GENUS_PROB_LIST_RESOURCE = "/test/classifier/testGenus_probList.txt";

    /** Classpath location of the probability index fixture. */
    private static final String PROB_INDEX_RESOURCE = "/test/classifier/testProbIndex.txt";

    /** Classpath location of the training tree fixture. */
    private static final String TRAINING_TREE_RESOURCE = "/test/classifier/test_bergeyTrainingTree.xml";

    /** Classpath location of the query sequences classified in {@link #testCreateClassifier()}. */
    private static final String QUERY_SEQ_RESOURCE = "/test/classifier/testQuerySeq.fasta";

    public TrainingInfoTest(java.lang.String testName) {
        super(testName);
    }

    /** Runs this suite with the JUnit text runner. */
    public static void main(java.lang.String[] args) {
        junit.textui.TestRunner.run(suite());
    }

    /** @return a suite containing every {@code test*} method of this class */
    public static Test suite() {
        return new TestSuite(TrainingInfoTest.class);
    }

    /**
     * Placeholder for a dedicated createTree test. It currently makes no
     * assertions; createTree is exercised by {@link #testCreateClassifier()}.
     */
    public void testCreateTree() {
        System.out.println("testCreateTree");
    }

    /** Test of createLogWordPriorArr and getWordPairPriorDiff methods, of class TrainingInfo. */
    public void testCreateLogWordPriorArr() throws Exception {
        System.out.println("testCreateLogWordPriorArr");
        TrainingInfo train = new TrainingInfo();
        Reader in = openResource(LOG_WORD_PRIOR_RESOURCE);
        try {
            train.createLogWordPriorArr(in);
        } finally {
            in.close();
        }

        float wordPrior = train.getLogWordPrior(10);
        assertEquals(-2.77, wordPrior, 0.1);
        wordPrior = train.getLogWordPrior(55);
        assertEquals(-1.67, wordPrior, 0.1);

        // test getWordPairPriorDiff()
        // NOTE(review): 24573 and 10912 are presumably a word index and the index
        // of its paired (reverse) word -- confirm against TrainingInfo.
        wordPrior = train.getLogWordPrior(24573);
        assertEquals(-2.77, wordPrior, 0.01);
        float revWordPrior = train.getLogWordPrior(10912);
        assertEquals(-1.67, revWordPrior, 0.01);

        float wordPriorDiff = train.getWordPairPriorDiff(24573);
        assertEquals((wordPrior - revWordPrior), wordPriorDiff, 0.01);
        float revWordPriorDiff = train.getWordPairPriorDiff(10912);
        assertEquals((revWordPrior - wordPrior), revWordPriorDiff, 0.01);
        // The two differences must cancel each other out.
        assertEquals(0.0, (wordPriorDiff + revWordPriorDiff), 0.01);
    }

    /** Test of createGenusWordProbList method, of class TrainingInfo. */
    public void testCreateGenusWordProbList() throws Exception {
        System.out.println("testCreateGenusWordConditionalProbList");
        TrainingInfo train = new TrainingInfo();
        Reader in = openResource(GENUS_PROB_LIST_RESOURCE);
        try {
            train.createGenusWordProbList(in);
        } finally {
            in.close();
        }

        GenusWordConditionalProb gProb = train.getWordConditionalProbObject(8);
        assertEquals(4, gProb.getGenusIndex());
        assertEquals(-0.15, gProb.getProbability(), 0.1);

        gProb = train.getWordConditionalProbObject(1865);
        assertEquals(0, gProb.getGenusIndex());
        assertEquals(-0.5, gProb.getProbability(), 0.1);
    }

    /** Test of createProbIndexArr method, of class TrainingInfo. */
    public void testCreateProbIndexArr() throws Exception {
        System.out.println("testCreateProbIndexArr");
        TrainingInfo train = new TrainingInfo();
        Reader in = openResource(PROB_INDEX_RESOURCE);
        try {
            train.createProbIndexArr(in);
        } finally {
            in.close();
        }

        int start = train.getStartIndex(100);
        int stop = train.getStopIndex(65535);
        assertEquals(1, start);
        assertEquals(1866, stop);
    }

    /**
     * Test of createClassifier method, of class TrainingInfo. Loads all four
     * training fixtures, checks the genus nodes and root tree, exercises
     * Classifier.addConfidence, and finally classifies the two sequences in
     * the query fixture.
     */
    public void testCreateClassifier() throws Exception {
        System.out.println("testCreateClassifier");
        TrainingInfo train = loadTrainingInfo();

        assertEquals(6, train.getGenusNodeListSize());

        HierarchyTree genusNode = train.getGenusNodebyIndex(3);
        assertEquals(3, genusNode.getGenusIndex());
        assertEquals("Pseudomonas", genusNode.getName());
        int leaveCount = genusNode.getLeaveCount();
        float logLeaveCount = train.getLogLeaveCount(1);
        // NOTE(review): the node comes from genus index 3 but the log count is
        // read at index 1; kept as in the original -- confirm the index semantics.
        assertEquals((double) leaveCount, Math.exp(logLeaveCount) - 1, 0.1);

        HierarchyTree rootTree = train.getRootTree();
        assertEquals("Bacteria", rootTree.getName());
        assertNotNull(rootTree.getSubclasses());

        Classifier aClassifier = train.createClassifier();

        // test addConfidence()
        HashMap<HierarchyTree, RankAssignment> determinedMap = new HashMap<HierarchyTree, RankAssignment>();
        HierarchyTree determinedGenusNode = train.getGenusNodebyIndex(4);
        HierarchyTree enterobacteriales = determinedGenusNode.getParent().getParent();
        HierarchyTree gammaproteobacteria = enterobacteriales.getParent();

        // Seed every ancestor of the determined genus, up to the root, with confidence 0.
        for (HierarchyTree aNode = determinedGenusNode; aNode != null; aNode = aNode.getParent()) {
            determinedMap.put(aNode, new RankAssignment(aNode, 0));
        }

        HierarchyTree bestNode = train.getGenusNodebyIndex(1);  // Vibrio
        aClassifier.addConfidence(bestNode, determinedMap);
        assertEquals(0.0f, determinedMap.get(determinedGenusNode).getConfidence(), 0.0001);  // Enterobacter
        assertEquals(0.0f, determinedMap.get(enterobacteriales).getConfidence(), 0.0001);    // Enterobacteriales
        assertEquals(1.0f, determinedMap.get(gammaproteobacteria).getConfidence(), 0.0001);  // Gammaproteobacteria

        bestNode = train.getGenusNodebyIndex(4);   // Enterobacter
        aClassifier.addConfidence(bestNode, determinedMap);
        assertEquals(1.0f, determinedMap.get(determinedGenusNode).getConfidence(), 0.0001);  // Enterobacter
        assertEquals(1.0f, determinedMap.get(enterobacteriales).getConfidence(), 0.0001);    // Enterobacteriales
        assertEquals(2.0f, determinedMap.get(gammaproteobacteria).getConfidence(), 0.0001);  // Gammaproteobacteria

        aClassifier.addConfidence(train.getGenusNodebyIndex(4), determinedMap);  // Enterobacter
        aClassifier.addConfidence(train.getGenusNodebyIndex(0), determinedMap);  // Clostridium
        assertEquals(2.0f, determinedMap.get(determinedGenusNode).getConfidence(), 0.0001);  // Enterobacter
        assertEquals(2.0f, determinedMap.get(enterobacteriales).getConfidence(), 0.0001);    // Enterobacteriales
        assertEquals(3.0f, determinedMap.get(gammaproteobacteria).getConfidence(), 0.0001);  // Gammaproteobacteria
        assertEquals(4.0f, determinedMap.get(rootTree).getConfidence(), 0.0001);             // Bacteria
        // end test addConfidence()

        // test classify() on the two query sequences, in file order
        BufferedReader infile = new BufferedReader(openResource(QUERY_SEQ_RESOURCE));
        try {
            assertLineage(aClassifier, infile, rootTree, "Proteobacteria", "Rhizobiales", "Rhizobium");
            assertLineage(aClassifier, infile, rootTree, "Firmicutes", "Clostridiales", "Clostridium");
        } finally {
            infile.close();
        }
    }

    /**
     * Opens a classpath resource as a character reader.
     * Fails the running test with a descriptive message when the resource is
     * absent, since getResourceAsStream signals that by returning null.
     *
     * @param path absolute classpath location of the resource
     * @return a reader over the resource; the caller must close it
     */
    private static Reader openResource(String path) {
        InputStream stream = TrainingInfoTest.class.getResourceAsStream(path);
        assertNotNull("Missing test resource: " + path, stream);
        return new InputStreamReader(stream);
    }

    /**
     * Builds a TrainingInfo loaded with the genus word probability list, the
     * training tree, the probability index and the log word priors, in that
     * order. Each reader is closed as soon as its file has been loaded.
     *
     * @return the populated TrainingInfo
     */
    private static TrainingInfo loadTrainingInfo() throws Exception {
        TrainingInfo train = new TrainingInfo();

        Reader in = openResource(GENUS_PROB_LIST_RESOURCE);
        try {
            train.createGenusWordProbList(in);
        } finally {
            in.close();
        }

        in = openResource(TRAINING_TREE_RESOURCE);
        try {
            train.createTree(in);
        } finally {
            in.close();
        }

        in = openResource(PROB_INDEX_RESOURCE);
        try {
            train.createProbIndexArr(in);
        } finally {
            in.close();
        }

        in = openResource(LOG_WORD_PRIOR_RESOURCE);
        try {
            train.createLogWordPriorArr(in);
        } finally {
            in.close();
        }
        return train;
    }

    /**
     * Reads the next two lines from the query file (one line that is skipped,
     * then the sequence line), classifies the upper-cased sequence and checks
     * selected entries of the returned assignment list.
     *
     * @param classifier the classifier under test
     * @param fasta reader positioned at the start of the next record
     * @param root node whose name the first assignment must carry; its confidence must be about 1.0
     * @param secondName expected name of the 2nd assignment
     * @param fourthName expected name of the 4th assignment
     * @param sixthName expected name of the 6th assignment
     */
    private void assertLineage(Classifier classifier, BufferedReader fasta, HierarchyTree root,
            String secondName, String fourthName, String sixthName) throws Exception {
        fasta.readLine();  // skip the line preceding the sequence
        String sequence = fasta.readLine();
        assertNotNull("Query file ended before the expected sequence line", sequence);

        ClassifierSequence pSeq = new ClassifierSequence("name", "title", sequence.toUpperCase());
        ClassificationResult result = classifier.classify(pSeq);

        Iterator<?> it = result.getAssignments().iterator();
        RankAssignment classResult = (RankAssignment) it.next();
        assertEquals(root.getName(), classResult.getBestClass().getName());
        assertEquals(1.0, classResult.getConfidence(), 0.1);
        classResult = (RankAssignment) it.next();
        assertEquals(secondName, classResult.getBestClass().getName());
        it.next();  // 3rd assignment is not checked
        classResult = (RankAssignment) it.next();
        assertEquals(fourthName, classResult.getBestClass().getName());
        it.next();  // 5th assignment is not checked
        classResult = (RankAssignment) it.next();
        assertEquals(sixthName, classResult.getBestClass().getName());
    }

    /**
     * Debugging aid: prints the name and confidence of every assignment in the
     * result to standard error. Not called by any test at present.
     */
    private void displayResult(ClassificationResult result) {
        List<?> assignments = result.getAssignments();
        for (Object entry : assignments) {
            RankAssignment classResult = (RankAssignment) entry;
            System.err.print("\n" + classResult.getBestClass().getName() + "  " + classResult.getConfidence());
        }
    }

}
