from __future__ import print_function
import unittest
import os
import io

from rdkit.six.moves import cPickle as pickle

from rdkit import Chem
from rdkit.Chem import rdPartialCharges
from rdkit import RDConfig


def feq(v1, v2, tol2=1e-4):
  """Approximate equality: True when |v1 - v2| does not exceed tol2 (inclusive)."""
  difference = abs(v1 - v2)
  return difference <= tol2


class TestCase(unittest.TestCase):
  """Regression tests for the Gasteiger partial-charge wrapper (rdPartialCharges)."""

  @staticmethod
  def _dataFile(fileName):
    """Return the path of *fileName* inside the PartialCharges/Wrap/test_data directory."""
    return os.path.join(RDConfig.RDBaseDir, 'Code', 'GraphMol', 'PartialCharges', 'Wrap',
                        'test_data', fileName)

  def test0HalgrenSet(self):
    """Charges computed for halgren.smi must reproduce halgren_out.txt line by line."""
    smiSup = Chem.SmilesMolSupplier(self._dataFile('halgren.smi'), delimiter='\t')

    # parse the reference output file
    with open(self._dataFile('halgren_out.txt'), 'r') as infil:
      lines = infil.readlines()

    tab = Chem.GetPeriodicTable()

    # Rebuild the reference text format: a "Molecule: <name>" header line
    # followed by one "<atom index> <element symbol> <charge>" line per atom.
    olst = []
    for mol in smiSup:
      self.assertIsNotNone(mol)
      rdPartialCharges.ComputeGasteigerCharges(mol)
      olst.append("Molecule: " + mol.GetProp("_Name"))
      for i, at in enumerate(mol.GetAtoms()):
        en = tab.GetElementSymbol(at.GetAtomicNum())
        chg = float(at.GetProp("_GasteigerCharge"))
        olst.append("%i %s %6.4f" % (i, en, chg))

    # assertEqual (rather than assertTrue(a == b)) reports both values on mismatch
    for lineNo, line in enumerate(lines):
      self.assertEqual(line.strip(), olst[lineNo])

  def test1PPDataset(self):
    """Charges for every SMILES in the PP regression set must match the pickled references."""
    with open(self._dataFile('PP_descrs_regress.2.csv'), 'r') as infil:
      lines = infil.readlines()

    # The reference pickle is read as text, CRLF line endings are normalised
    # to LF, and it is then re-encoded to bytes for unpickling. NOTE(review):
    # this presumably relies on it being an ASCII (protocol 0) pickle;
    # confirm before changing the read mode.
    with open(self._dataFile('PP_combi_charges.pkl'), 'r') as cchtFile:
      buf = cchtFile.read().replace('\r\n', '\n').encode('utf-8')
    with io.BytesIO(buf) as cchFile:
      combiCharges = pickle.load(cchFile)

    for lin in lines:
      # skip comment lines and blank lines (no SMILES to check)
      if lin.startswith('#') or not lin.strip():
        continue
      # the SMILES is the first comma-separated column
      smi = lin.strip().split(',')[0]
      rdmol = Chem.MolFromSmiles(smi)
      self.assertIsNotNone(rdmol, 'could not parse SMILES: %s' % smi)
      rdPartialCharges.ComputeGasteigerCharges(rdmol)

      # report every mismatching atom before failing, to ease debugging
      failed = False
      for ai, atom in enumerate(rdmol.GetAtoms()):
        rdch = float(atom.GetProp('_GasteigerCharge'))
        if not feq(rdch, combiCharges[smi][ai], 1.e-2):
          failed = True
          print(smi, ai, rdch, combiCharges[smi][ai])
      if failed:
        rdmol.Debug()
      self.assertFalse(failed)

  def test2Params(self):
    """Issue187: adding a separate counter-ion fragment must not change the other atoms' charges."""
    m1 = Chem.MolFromSmiles('C(=O)[O-]')
    rdPartialCharges.ComputeGasteigerCharges(m1)

    m2 = Chem.MolFromSmiles('C(=O)[O-].[Na+]')
    rdPartialCharges.ComputeGasteigerCharges(m2)

    # compare only the atoms shared with m1 (the formate fragment)
    for i in range(m1.GetNumAtoms()):
      c1 = float(m1.GetAtomWithIdx(i).GetProp('_GasteigerCharge'))
      c2 = float(m2.GetAtomWithIdx(i).GetProp('_GasteigerCharge'))
      self.assertTrue(feq(c1, c2, 1e-4))

  def test3Params(self):
    """Issue187: ComputeGasteigerCharges(m2, 12, 1) on the Na+ salt must raise."""
    m2 = Chem.MolFromSmiles('C(=O)[O-].[Na+]')
    # NOTE(review): the positional args are presumably the iteration count and a
    # throw-on-parameter-failure flag; confirm against the wrapper signature.
    # assertRaisesRegexp was removed in Python 3.12; assertRaises is equivalent
    # here because the original pattern ("") matched any message.
    with self.assertRaises(Exception):
      rdPartialCharges.ComputeGasteigerCharges(m2, 12, 1)

  def testGithubIssue20(self):
    """Github #20: charges for a boron-containing molecule (CB(O)O) match reference values."""
    m1 = Chem.MolFromSmiles('CB(O)O')
    rdPartialCharges.ComputeGasteigerCharges(m1)
    chgs = [-0.030, 0.448, -0.427, -0.427]
    for i in range(m1.GetNumAtoms()):
      c1 = float(m1.GetAtomWithIdx(i).GetProp('_GasteigerCharge'))
      self.assertAlmostEqual(c1, chgs[i], 3)

  def testGithubIssue577(self):
    """Github #577: stored charges must parse with float() even under a comma-decimal locale."""
    import locale
    m1 = Chem.MolFromSmiles('CCO')
    # remember the current numeric locale so it is restored exactly afterwards
    oldLocale = locale.setlocale(locale.LC_NUMERIC)
    try:
      locale.setlocale(locale.LC_NUMERIC, "de_DE")
    except locale.Error:
      # report a skip rather than a silent pass when the locale is missing
      self.skipTest('locale "de_DE" is not available')
    try:
      rdPartialCharges.ComputeGasteigerCharges(m1)
      for at in m1.GetAtoms():
        # float() is locale-independent and raises ValueError on e.g. "0,123"
        float(at.GetProp('_GasteigerCharge'))
    finally:
      locale.setlocale(locale.LC_NUMERIC, oldLocale)
    # and once more with the original locale back in place
    rdPartialCharges.ComputeGasteigerCharges(m1)
    for at in m1.GetAtoms():
      float(at.GetProp('_GasteigerCharge'))


if __name__ == '__main__':
  # Run every test case in this module when executed directly as a script.
  unittest.main()
