# $Id$
#
# Copyright (C) 2007,2008 Greg Landrum
#
#  @@ All Rights Reserved @@
#
import os, sys
import io
import unittest
from rdkit.six.moves import cPickle
from rdkit import RDConfig
from rdkit import DataStructs as ds


def feq(v1, v2, tol=1e-4):
  """Report whether two numbers agree to within an absolute tolerance.

  Returns True when the absolute difference between v1 and v2 is strictly
  smaller than tol, and False otherwise.
  """
  gap = abs(v1 - v2)
  return gap < tol


class TestCase(unittest.TestCase):
  """Tests for the Python wrappers of IntSparseIntVect and LongSparseIntVect."""

  def setUp(self):
    """No shared fixtures are required; every test builds its own vectors."""
    pass

  def _loadReferencePickle(self, relPath):
    """Unpickle a reference object stored below RDConfig.RDBaseDir.

    relPath is the path of the pickle file relative to RDConfig.RDBaseDir.

    The file is read in text mode and CRLF line endings are converted to LF
    before the contents are encoded and unpickled.
    NOTE(review): presumably this tolerates checkouts that rewrite line
    endings in the stored pickles - confirm before switching to binary mode.
    """
    # The with-statement closes the file; no explicit close() is needed.
    with open(os.path.join(RDConfig.RDBaseDir, relPath), 'r') as tf:
      buf = tf.read().replace('\r\n', '\n').encode('utf-8')
    with io.BytesIO(buf) as f:
      return cPickle.load(f)

  def _randomVects(self, nVs=6, sz=10, nToSet=5):
    """Return a list of nVs IntSparseIntVect objects of length sz.

    Each vector receives nToSet assignments of a random value in [1, 10] at a
    random index in [0, sz - 1]. Indices may repeat, so a vector can end up
    with fewer than nToSet nonzero elements.
    """
    import random
    vects = []
    for _ in range(nVs):
      v = ds.IntSparseIntVect(sz)
      for _ in range(nToSet):
        v[random.randint(0, sz - 1)] = random.randint(1, 10)
      vects.append(v)
    return vects

  def test1Int(self):
    """Basic IntSparseIntVect behaviour.

    Checks out-of-range indexing, item assignment, GetLength(), ==, !=, the
    |= and | operators against an empty vector, and GetNonzeroElements().
    """
    v1 = ds.IntSparseIntVect(5)
    self.assertRaises(IndexError, lambda: v1[5])
    v1[0] = 1
    v1[2] = 2
    v1[3] = 3
    self.assertEqual(v1, v1)
    self.assertEqual(v1.GetLength(), 5)

    v2 = ds.IntSparseIntVect(5)
    self.assertNotEqual(v1, v2)
    v2 |= v1
    self.assertEqual(v2, v1)

    v3 = v2 | v1
    self.assertEqual(v3, v1)

    onVs = v1.GetNonzeroElements()
    self.assertEqual(onVs, {0: 1, 2: 2, 3: 3})

  def test2Long(self):
    """Basic LongSparseIntVect behaviour with a length of 1 << 42.

    Mirrors test1Int, but uses an index (1 << 35) that does not fit in 32 bits.
    """
    length = 1 << 42
    v1 = ds.LongSparseIntVect(length)
    self.assertRaises(IndexError, lambda: v1[length])
    v1[0] = 1
    v1[2] = 2
    v1[1 << 35] = 3
    self.assertEqual(v1, v1)
    self.assertEqual(v1.GetLength(), length)

    v2 = ds.LongSparseIntVect(length)
    self.assertNotEqual(v1, v2)
    v2 |= v1
    self.assertEqual(v2, v1)

    v3 = v2 | v1
    self.assertEqual(v3, v1)

    onVs = v1.GetNonzeroElements()
    self.assertEqual(onVs, {0: 1, 2: 2, 1 << 35: 3})

  def test3Pickle1(self):
    """Serialization round trips for LongSparseIntVect.

    Covers a pickle dumps/loads round trip, construction from ToBinary()
    output, and loading the stored reference pickle lsiv.pkl.
    """
    length = 1 << 42
    v1 = ds.LongSparseIntVect(length)
    self.assertRaises(IndexError, lambda: v1[length + 1])
    v1[0] = 1
    v1[2] = 2
    v1[1 << 35] = 3
    self.assertEqual(v1, v1)

    v2 = cPickle.loads(cPickle.dumps(v1))
    self.assertEqual(v2, v1)

    v3 = ds.LongSparseIntVect(v2.ToBinary())
    self.assertEqual(v2, v3)
    self.assertEqual(v1, v3)

    # Original note on how the reference file was written:
    #   cPickle.dump(v1,file('lsiv.pkl','wb+'))
    v3 = self._loadReferencePickle('Code/DataStructs/Wrap/testData/lsiv.pkl')
    self.assertEqual(v3, v1)

  def test3Pickle2(self):
    """Serialization round trips for IntSparseIntVect.

    Covers a pickle dumps/loads round trip, construction from ToBinary()
    output, and loading the stored reference pickle isiv.pkl.
    """
    length = 1 << 21
    v1 = ds.IntSparseIntVect(length)
    self.assertRaises(IndexError, lambda: v1[length + 1])
    v1[0] = 1
    v1[2] = 2
    v1[1 << 12] = 3
    self.assertEqual(v1, v1)

    v2 = cPickle.loads(cPickle.dumps(v1))
    self.assertEqual(v2, v1)

    v3 = ds.IntSparseIntVect(v2.ToBinary())
    self.assertEqual(v2, v3)
    self.assertEqual(v1, v3)

    # Original note on how the reference file was written:
    #   cPickle.dump(v1,file('isiv.pkl','wb+'))
    v3 = self._loadReferencePickle('Code/DataStructs/Wrap/testData/isiv.pkl')
    self.assertEqual(v3, v1)

  def test4Update(self):
    """UpdateFromSequence() must match setting each index to its occurrence count.

    The sequence holds 0 once, 2 twice and 3 three times, which must equal a
    vector with v[0] = 1, v[2] = 2 and v[3] = 3.
    """
    v1 = ds.IntSparseIntVect(5)
    self.assertRaises(IndexError, lambda: v1[6])
    v1[0] = 1
    v1[2] = 2
    v1[3] = 3
    self.assertEqual(v1, v1)

    v2 = ds.IntSparseIntVect(5)
    v2.UpdateFromSequence((0, 2, 3, 3, 2, 3))
    self.assertEqual(v1, v2)

  def test5Dice(self):
    """DiceSimilarity() is 1.0 for identical vectors and symmetric otherwise.

    The second pair of vectors has an expected similarity of 18/26 in both
    argument orders.
    """
    v1 = ds.IntSparseIntVect(5)
    v1[4] = 4
    v1[0] = 2
    v1[3] = 1
    self.assertTrue(feq(ds.DiceSimilarity(v1, v1), 1.0))

    v1 = ds.IntSparseIntVect(5)
    v1[0] = 2
    v1[2] = 1
    v1[3] = 4
    v1[4] = 6
    v2 = ds.IntSparseIntVect(5)
    v2[1] = 2
    v2[2] = 3
    v2[3] = 4
    v2[4] = 4
    expected = 18.0 / 26.
    self.assertTrue(feq(ds.DiceSimilarity(v1, v2), expected))
    self.assertTrue(feq(ds.DiceSimilarity(v2, v1), expected))

  def test6BulkDice(self):
    """BulkDiceSimilarity() must agree with pairwise DiceSimilarity() calls."""
    vs = self._randomVects()
    probe, others = vs[0], vs[1:]

    baseDs = [ds.DiceSimilarity(probe, other) for other in others]
    bulkDs = ds.BulkDiceSimilarity(probe, others)
    # zip() stops at the shorter input, so pin the length explicitly.
    self.assertEqual(len(bulkDs), len(baseDs))
    for base, bulk in zip(baseDs, bulkDs):
      self.assertTrue(feq(base, bulk))

  def test6BulkTversky(self):
    """BulkTverskySimilarity() against the pairwise, Dice and Tanimoto values.

    With both weights at 0.5 the bulk Tversky values must match pairwise
    TverskySimilarity() and DiceSimilarity(); with both weights at 1.0 they
    must match TanimotoSimilarity() and BulkTanimotoSimilarity().
    """
    vs = self._randomVects()
    probe, others = vs[0], vs[1:]

    baseDs = [ds.TverskySimilarity(probe, other, .5, .5) for other in others]
    bulkDs = ds.BulkTverskySimilarity(probe, others, 0.5, 0.5)
    diceDs = [ds.DiceSimilarity(probe, other) for other in others]
    # zip() stops at the shorter input, so pin the length explicitly.
    self.assertEqual(len(bulkDs), len(baseDs))
    for base, bulk, dice in zip(baseDs, bulkDs, diceDs):
      self.assertTrue(feq(base, bulk))
      self.assertTrue(feq(base, dice))

    bulkDs = ds.BulkTverskySimilarity(probe, others, 1.0, 1.0)
    taniDs = [ds.TanimotoSimilarity(probe, other) for other in others]
    self.assertEqual(len(bulkDs), len(taniDs))
    for bulk, tani in zip(bulkDs, taniDs):
      self.assertTrue(feq(bulk, tani))

    taniDs = ds.BulkTanimotoSimilarity(probe, others)
    self.assertEqual(len(bulkDs), len(taniDs))
    for bulk, tani in zip(bulkDs, taniDs):
      self.assertTrue(feq(bulk, tani))


# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
  unittest.main()
