import unittest
import numpy
import os
import io

from rdkit.six.moves import cPickle

from rdkit import RDConfig,RDRandom
from rdkit.ML.InfoTheory import rdInfoTheory as rdit
from rdkit import DataStructs

def feq(a, b, tol=1e-4):
    """Return True when *a* and *b* differ by strictly less than *tol*."""
    difference = abs(a - b)
    return difference < tol

class TestCase(unittest.TestCase):
    """Regression tests for the rdInfoTheory wrappers: the entropy/gain
    functions and the InfoBitRanker class."""

    def _assertNear(self, actual, expected):
        """Assert *actual* is within feq's default tolerance of *expected*,
        reporting both values on failure."""
        self.assertTrue(feq(actual, expected),
                        '%r is not within tolerance of %r' % (actual, expected))

    def test0GainFuns(self):
        # (class counts, expected entropy); e.g. three equal classes -> log2(3).
        entropyCases = [
            ([9, 5], 0.9403),
            ([9, 9], 1.0000),
            ([5, 5], 1.0000),
            ([5, 0], 0.0000),
            ([5, 5, 5], 1.5850),
            ([2, 5, 5], 1.4834),
        ]
        for counts, expected in entropyCases:
            self._assertNear(rdit.InfoEntropy(numpy.array(counts)), expected)

        self._assertNear(rdit.ChiSquare(numpy.array([[6, 2], [3, 3]])), 0.9333)

        # (contingency table, expected information gain). The all-zero table
        # and the table with an empty column are expected to yield zero gain.
        gainCases = [
            ([[6, 2], [3, 3]], 0.0481),
            ([[1, 1], [2, 1]], 0.0200),
            ([[2, 0], [1, 2]], 0.4200),
            ([[0, 0], [0, 0]], 0.0000),
            ([[1, 0], [1, 0]], 0.0000),
        ]
        for table, expected in gainCases:
            self._assertNear(rdit.InfoGain(numpy.array(table)), expected)

    def test1ranker(self):
        """A ranker created without an explicit InfoType must rank exactly like
        an ENTROPY ranker; BIASENTROPY results are then sanity-checked."""
        nbits = 100
        ninst = 100
        dm = 50
        nact = 10
        nc = 2
        rn = rdit.InfoBitRanker(nbits, nc, rdit.InfoType.ENTROPY)
        fps = []
        na = 0
        ni = 0
        for _ in range(ninst):
            v = DataStructs.SparseBitVect(nbits)
            # dm draws with replacement, so fewer than dm distinct bits may be set.
            for _bit in range(dm):
                v.SetBit(RDRandom.randrange(0, nbits))
            # Drawn after the bits (same RNG order as before): label 1 with
            # probability roughly nact/ninst, otherwise 0.
            act = 1 if RDRandom.randrange(0, ninst) < nact else 0
            if act:
                na += 1
            else:
                ni += 1
            rn.AccumulateVotes(v, act)
            fps.append((v, act))

        res = rn.GetTopN(50)

        rn2 = rdit.InfoBitRanker(nbits, nc)
        for fp, act in fps:
            rn2.AccumulateVotes(fp, act)
        res2 = rn2.GetTopN(50)
        self.assertTrue((res == res2).all())

        # No SetBiasList call is made, so the ranker's default bias applies.
        rn3 = rdit.InfoBitRanker(nbits, nc, rdit.InfoType.BIASENTROPY)
        for fp, act in fps:
            rn3.AccumulateVotes(fp, act)
        res3 = rn3.GetTopN(50)
        # NOTE(review): column 2 is divided by the class-1 total (na) and
        # column 3 by the class-0 total (ni). If columns 2/3 are the per-class
        # counts for classes 0/1, this pairing looks swapped and makes the
        # check nearly vacuous -- confirm against InfoBitRanker.GetTopN.
        for row in range(50):
            fan = res3[row, 2] / na
            fin = res3[row, 3] / ni
            self.assertGreater(fan, fin)

    def test2ranker(self):
        """With a bit mask, GetTopN ranks only the unmasked bits and refuses
        to return more bits than the mask allows."""
        nbits = 100
        ninst = 100
        dm = 50
        nact = 10
        nc = 2
        RDRandom.seed(23)
        rn = rdit.InfoBitRanker(nbits, nc, rdit.InfoType.ENTROPY)
        rn.SetMaskBits([63, 70, 15, 25, 10])
        for _ in range(ninst):
            v = DataStructs.SparseBitVect(nbits)
            for _bit in range(dm):
                v.SetBit(RDRandom.randrange(0, nbits))
            act = 1 if RDRandom.randrange(0, ninst) < nact else 0
            rn.AccumulateVotes(v, act)
        res = rn.GetTopN(5)
        self.assertEqual(sorted(int(x[0]) for x in res), [10, 15, 25, 63, 70])
        # Only 5 bits are available, so requesting 10 must raise.
        with self.assertRaises(Exception):
            rn.GetTopN(10)

    def test3Issue140(self):
        """GetTopN must succeed on a tiny two-bit, two-class data set (Issue 140)."""
        nbits = 2
        # Each row is [bit0, bit1, activity].
        examples = [[0, 0, 0], [1, 1, 0], [0, 0, 1], [1, 1, 1]]
        rn = rdit.InfoBitRanker(nbits, 2, rdit.InfoType.ENTROPY)
        for example in examples:
            bits, act = example[:-1], example[-1]
            bv = DataStructs.ExplicitBitVect(nbits)
            for idx, val in enumerate(bits):
                bv[idx] = val
            rn.AccumulateVotes(bv, act)
        # Call directly: any exception now fails the test with its traceback
        # instead of being swallowed into a bare "res is None" failure.
        res = rn.GetTopN(1)
        self.assertIsNotNone(res)

    def test4Issue237(self):
        """Regression for Issue 237: BIASENTROPY ranking of this pickled data
        set used to dump core."""
        pklName = os.path.join(RDConfig.RDBaseDir, 'Code', 'ML', 'InfoTheory', 'Wrap',
                               'testData', 'Issue237.pkl')
        # The pickle is read as text and '\r\n' is normalized so that a checkout
        # with Windows line endings still loads.
        with open(pklName, 'r') as inTF:
            buf = inTF.read().replace('\r\n', '\n').encode('utf-8')
        # encoding='bytes' makes Python 2 str objects in the pickle load as bytes.
        with io.BytesIO(buf) as inF:
            examples, avail, _bias, nB, nPoss = cPickle.load(inF, encoding='bytes')
        ranker = rdit.InfoBitRanker(nB, nPoss, rdit.InfoType.BIASENTROPY)
        ranker.SetMaskBits(avail)
        for ex in examples:
            ranker.AccumulateVotes(ex[1], ex[-1])
        # this dumps core on linux if the bug isn't fixed:
        v = ranker.GetTopN(1)
        self.assertEqual(int(v[0][0]), 12)
                          
if __name__ == '__main__':
    # Run this module's tests when it is executed directly as a script.
    unittest.main()
                       
