#!/usr/bin/env python

import sys, os, tempfile
import unittest

import gi
gi.require_version('v_sim', '3.8')
from gi.repository import GLib, v_sim

import signals

class TestLoader(unittest.TestCase):
  def setUp(self):
    super(TestLoader, self).setUp()
    self.addTypeEqualityFunc(float, self.fuzzyFloat)

  def fuzzyFloat(self, a, b, msg = None):
    if abs(b-a) > 1e-8:
      raise self.failureException(msg if msg is not None else "%g != %g (d = %g)" % (a, b, abs(a-b)))
        
  def _loader(self, l, path, expectation = {}):
    data = v_sim.DataAtomic.new(path, l)
    try:
      ok = data.load(0, None)
    except:
      ok = False
    self.assertEqual(ok, expectation["success"])

    if ok:
      for (k, v) in expectation.items():
        if k.startswith("data."):
          #print eval(k), v
          self.assertEqual(eval(k), v)

  def _tempFile(self, string):
    f = tempfile.NamedTemporaryFile()
    f.write(string)
    f.flush()
    return f

  def test_ascii_file(self):
    self._loader(v_sim.DataLoader.ascii_getStatic(),
                 os.path.join(os.path.dirname(sys.argv[0]),
                                "../../examples/demo.ascii"),
                   {"success": True,
                    "data.getNNodes()": 172,
                    "data.getNElements(True)": 2,
                    "data.containsElement(v_sim.Element.retrieveFromName(\"H\")[0])": False,
                    "data.containsElement(v_sim.Element.lookup(\"Ni\"))": True,
                    "data.containsElement(v_sim.Element.lookup(\"Au\"))": True,
                    "data.getBox().getUnit()": v_sim.Units.ANGSTROEM,
                    "data.getBox().getBoundary()": v_sim.BoxBoundaries.PERIODIC,
                    "data.getBox().getPeriodicity()": [True, True, True],
                   })
  def test_ascii_wrong_file(self):
    self._loader(v_sim.DataLoader.ascii_getStatic(),
                 os.path.join(os.path.dirname(sys.argv[0]),
                                "../../examples/demo.xyz"),
                   {"success": False})
  def test_ascii_unit_bohr(self):
    with self._tempFile("""test
10 0 10
0  0 10
#keyword: bohr
0 0 0 Si
""") as f:
      self._loader(v_sim.DataLoader.ascii_getStatic(),
                   f.name, {"success": True,
                            "data.getBox().getUnit()": v_sim.Units.BOHR})
  def test_ascii_unit_angstroem(self):
    with self._tempFile("""test
10 0 10
0  0 10
#keyword: Angstroem
0 0 0 Si
""") as f:
      self._loader(v_sim.DataLoader.ascii_getStatic(),
                   f.name, {"success": True,
                            "data.getBox().getUnit()": v_sim.Units.ANGSTROEM})
  def test_ascii_unit_undefined(self):
    with self._tempFile("""test
10 0 10
0  0 10
0 0 0 Si
""") as f:
      self._loader(v_sim.DataLoader.ascii_getStatic(),
                   f.name, {"success": True,
                            "data.getBox().getUnit()": v_sim.Units.UNDEFINED})
  def test_ascii_box_periodic(self):
    with self._tempFile("""test
10 0 10
0  0 10
#keyword: periodic
0 0 0 Si
""") as f:
      self._loader(v_sim.DataLoader.ascii_getStatic(),
                   f.name, {"success": True,
                            "data.getBox().getBoundary()": v_sim.BoxBoundaries.PERIODIC})
    with self._tempFile("""test
10 0 10
0  0 10
0 0 0 Si
""") as f:
      self._loader(v_sim.DataLoader.ascii_getStatic(),
                   f.name, {"success": True,
                            "data.getBox().getBoundary()": v_sim.BoxBoundaries.PERIODIC})
  def test_ascii_box_free(self):
    with self._tempFile("""test
10 0 10
0  0 10
#keyword: freeBC
0 0 0 Si
""") as f:
      self._loader(v_sim.DataLoader.ascii_getStatic(),
                   f.name, {"success": True,
                            "data.getBox().getBoundary()": v_sim.BoxBoundaries.FREE})
  def test_ascii_box_surface(self):
    with self._tempFile("""test
10 0 10
0  0 10
#keyword: surface
0 0 0 Si
""") as f:
      self._loader(v_sim.DataLoader.ascii_getStatic(),
                   f.name, {"success": True,
                            "data.getBox().getBoundary()": v_sim.BoxBoundaries.SURFACE_ZX})
  def test_ascii_box(self):
    with self._tempFile("""test
1 2 3
4 5 6
0 0 0 Si
""") as f:
      self._loader(v_sim.DataLoader.ascii_getStatic(),
                   f.name, {"success": True,
                            "data.getBox().getGeometry(v_sim.BoxVector.DXX)": 1.,
                            "data.getBox().getGeometry(v_sim.BoxVector.DYX)": 2.,
                            "data.getBox().getGeometry(v_sim.BoxVector.DYY)": 3.,
                            "data.getBox().getGeometry(v_sim.BoxVector.DZX)": 4.,
                            "data.getBox().getGeometry(v_sim.BoxVector.DZY)": 5.,
                            "data.getBox().getGeometry(v_sim.BoxVector.DZZ)": 6.})
  def test_ascii_box_angdeg(self):
    with self._tempFile("""test
1 2 3
90 90 90
#keyword: angdeg
0 0 0 Si
""") as f:
      self._loader(v_sim.DataLoader.ascii_getStatic(),
                   f.name, {"success": True,
                            "data.getBox().getGeometry(v_sim.BoxVector.DXX)": 1.,
                            "data.getBox().getGeometry(v_sim.BoxVector.DYX)": 0.,
                            "data.getBox().getGeometry(v_sim.BoxVector.DYY)": 2.,
                            "data.getBox().getGeometry(v_sim.BoxVector.DZX)": 0.,
                            "data.getBox().getGeometry(v_sim.BoxVector.DZY)": 0.,
                            "data.getBox().getGeometry(v_sim.BoxVector.DZZ)": 3.})
  def test_ascii_reduced(self):
    with self._tempFile("""test
1 0 2
0 0 3
#keyword: reduced
0.5 0.75 0.333333333333 Si
""") as f:
      self._loader(v_sim.DataLoader.ascii_getStatic(),
                   f.name, {"success": True,
                            "data.getNodeCoordinates(data.getFromId(0), False)": (0.5, 1.5, 1.)})
  def test_ascii_coord(self):
    with self._tempFile("""test
1 0 2
0 0 3
1 2 3 Si
""") as f:
      self._loader(v_sim.DataLoader.ascii_getStatic(),
                   f.name, {"success": True,
                            "data.getNodeCoordinates(data.getFromId(0), False)": (1., 2., 3.)})
  def test_ascii_coord_user(self):
    with self._tempFile("""test
1 0 2
0 0 3
#keyword: freeBC
1 2 3 Si
""") as f:
      self._loader(v_sim.DataLoader.ascii_getStatic(),
                   f.name, {"success": True,
                            "data.getNodeCoordinates(data.getFromId(0), True)": (1., 2., 3.),
                            "data.getNodeCoordinates(data.getFromId(0), False)": (0., 0., 0.)})
  def test_ascii_props(self):
    with self._tempFile("""test
1 0 2
0 0 3
1 2 3 Si hello
0 0 0 C {"IGSpin": -1}
""") as f:
      self._loader(v_sim.DataLoader.ascii_getStatic(),
                   f.name, {"success": True,
                            "data.getNodeLabelAt(data.getFromId(0))": "hello",
                            "data.getNodeProperties(\"IGSpin\").getAt(data.getFromId(1))[1]": -1})
  def test_ascii_energy(self):
    with self._tempFile("""test
1 0 2
0 0 3
#metaData: totalEnergy = 10.
1 2 3 Si
#
""") as f:
      self._loader(v_sim.DataLoader.ascii_getStatic(),
                   f.name, {"success": True,
                            "data.get_property(\"totalEnergy\")": 10.})
  def test_ascii_energy_ht(self):
    with self._tempFile("""test
1 0 2
0 0 3
#metaData: totalEnergy = 1. Ht
1 2 3 Si
#
""") as f:
      self._loader(v_sim.DataLoader.ascii_getStatic(),
                   f.name, {"success": True,
                            "data.get_property(\"totalEnergy\")": 27.21138386})
  def test_ascii_forces(self):
    with self._tempFile("""test
1 0 2
0 0 3
1 2 3 Si
0 0 0 Si
#metaData: forces = [\
# 1, 2, 3, \
# 4, 5, 6]
#
""") as f:
      self._loader(v_sim.DataLoader.ascii_getStatic(),
                   f.name, {"success": True,
                            "data.getForces(False).getAt(data.getFromId(0))": [1, 2, 3],
                            "data.getForces(False).getAt(data.getFromId(1))": [4, 5, 6]})
  def test_ascii_forces_none(self):
    with self._tempFile("""test
1 0 2
0 0 3
1 2 3 Si
0 0 0 Si
""") as f:
      self._loader(v_sim.DataLoader.ascii_getStatic(),
                   f.name, {"success": True,
                            "data.getForces(False)": None,
                            "data.getForces(True).getAt(data.getFromId(0))": [0, 0, 0]})
    
  def test_xyz_file(self):
    self._loader(v_sim.DataLoader.xyz_getStatic(),
                 os.path.join(os.path.dirname(sys.argv[0]),
                                "../../examples/demo.xyz"),
                   {"success": True,
                    "data.getNNodes()": 950,
                    "data.getNElements(True)": 2,
                    "data.containsElement(v_sim.Element.retrieveFromName(\"H\")[0])": False,
                    "data.containsElement(v_sim.Element.lookup(\"Ni\"))": True,
                    "data.containsElement(v_sim.Element.lookup(\"C\"))": True,
                    "data.getBox().getUnit()": v_sim.Units.UNDEFINED,
                    "data.getBox().getBoundary()": v_sim.BoxBoundaries.FREE,
                    "data.getBox().getPeriodicity()": [False, False, False],
                   })
  def test_xyz_wrong_file(self):
    self._loader(v_sim.DataLoader.xyz_getStatic(),
                 os.path.join(os.path.dirname(sys.argv[0]),
                                "../../examples/demo.ascii"),
                   {"success": False})
  def test_xyz_unit_bohr(self):
    with self._tempFile("""1 bohr

Si 0 0 0
""") as f:
      self._loader(v_sim.DataLoader.xyz_getStatic(),
                   f.name, {"success": True,
                            "data.getBox().getUnit()": v_sim.Units.BOHR})
  def test_xyz_unit_angstroem(self):
    with self._tempFile("""1 angstroem

Si 0 0 0
""") as f:
      self._loader(v_sim.DataLoader.xyz_getStatic(),
                   f.name, {"success": True,
                            "data.getBox().getUnit()": v_sim.Units.ANGSTROEM})
  def test_xyz_unit_undefined(self):
    with self._tempFile("""1

Si 0 0 0
""") as f:
      self._loader(v_sim.DataLoader.xyz_getStatic(),
                   f.name, {"success": True,
                            "data.getBox().getUnit()": v_sim.Units.UNDEFINED})
  def test_xyz_box_periodic(self):
    with self._tempFile("""1
periodic 10 10 10 # coming from a surface file
Si 0 0 0
""") as f:
      self._loader(v_sim.DataLoader.xyz_getStatic(),
                   f.name, {"success": True,
                            "data.getBox().getBoundary()": v_sim.BoxBoundaries.PERIODIC})
  def test_xyz_box_free(self):
    with self._tempFile("""1

Si 0 0 0
""") as f:
      self._loader(v_sim.DataLoader.xyz_getStatic(),
                   f.name, {"success": True,
                            "data.getBox().getBoundary()": v_sim.BoxBoundaries.FREE})
  def test_xyz_box_surface(self):
    with self._tempFile("""1
surface 10 10 10 # coming from a periodic file
Si 0 0 0
""") as f:
      self._loader(v_sim.DataLoader.xyz_getStatic(),
                   f.name, {"success": True,
                            "data.getBox().getBoundary()": v_sim.BoxBoundaries.SURFACE_ZX})
  def test_xyz_box(self):
    with self._tempFile("""1
periodic 1 2 3
Si 0 0 0
""") as f:
      self._loader(v_sim.DataLoader.xyz_getStatic(),
                   f.name, {"success": True,
                            "data.getBox().getGeometry(v_sim.BoxVector.DXX)": 1.,
                            "data.getBox().getGeometry(v_sim.BoxVector.DYX)": 0.,
                            "data.getBox().getGeometry(v_sim.BoxVector.DYY)": 2.,
                            "data.getBox().getGeometry(v_sim.BoxVector.DZX)": 0.,
                            "data.getBox().getGeometry(v_sim.BoxVector.DZY)": 0.,
                            "data.getBox().getGeometry(v_sim.BoxVector.DZZ)": 3.})
  def test_xyz_coord(self):
    with self._tempFile("""1
periodic 10 10 10
Si 1 2 3
""") as f:
      self._loader(v_sim.DataLoader.xyz_getStatic(),
                   f.name, {"success": True,
                            "data.getNodeCoordinates(data.getFromId(0), False)": (1., 2., 3.)})
  def test_xyz_coord_user(self):
    with self._tempFile("""1

Si 1 2 3
""") as f:
      self._loader(v_sim.DataLoader.xyz_getStatic(),
                   f.name, {"success": True,
                            "data.getNodeCoordinates(data.getFromId(0), True)": (1., 2., 3.),
                            "data.getNodeCoordinates(data.getFromId(0), False)": (0., 0., 0.)})
  def test_xyz_props(self):
    with self._tempFile("""2

Si 1 2 3 hello
C 0 0 0 {"IGSpin": -1}
""") as f:
      self._loader(v_sim.DataLoader.xyz_getStatic(),
                   f.name, {"success": True,
                            "data.getNodeLabelAt(data.getFromId(0))": "hello",
                            "data.getNodeProperties(\"IGSpin\").getAt(data.getFromId(1))[1]": -1})
  def test_xyz_forces(self):
    with self._tempFile("""2

Si 1 2 3
Si 0 0 0
FoRceS
Si 1 2 3
Si 4 5 6
""") as f:
      self._loader(v_sim.DataLoader.xyz_getStatic(),
                   f.name, {"success": True,
                            "data.getForces(False).getAt(data.getFromId(0))": [1, 2, 3],
                            "data.getForces(False).getAt(data.getFromId(1))": [4, 5, 6]})
  def test_xyz_forces_none(self):
    with self._tempFile("""2

Si 1 2 3
Si 0 0 0
""") as f:
      self._loader(v_sim.DataLoader.xyz_getStatic(),
                   f.name, {"success": True,
                            "data.getForces(False)": None,
                            "data.getForces(True).getAt(data.getFromId(0))": [0, 0, 0]})
    
  def test_yaml_file(self):
    self._loader(v_sim.DataLoader.yaml_getStatic(),
                 os.path.join(os.path.dirname(sys.argv[0]),
                                "../../examples/cinchonidine.yaml"),
                   {"success": True,
                    "data.getNNodes()": 44,
                    "data.getNElements(True)": 4,
                    "data.containsElement(v_sim.Element.retrieveFromName(\"Ni\")[0])": False,
                    "data.containsElement(v_sim.Element.lookup(\"H\"))": True,
                    "data.containsElement(v_sim.Element.lookup(\"C\"))": True,
                    "data.containsElement(v_sim.Element.lookup(\"O\"))": True,
                    "data.containsElement(v_sim.Element.lookup(\"N\"))": True,
                    "data.getBox().getUnit()": v_sim.Units.BOHR,
                    "data.getBox().getBoundary()": v_sim.BoxBoundaries.FREE,
                    "data.getBox().getPeriodicity()": [False, False, False],
                   })
  def test_yaml_wrong_file(self):
    self._loader(v_sim.DataLoader.yaml_getStatic(),
                 os.path.join(os.path.dirname(sys.argv[0]),
                                "../../examples/demo.ascii"),
                   {"success": False})
  def test_yaml_unit_bohr(self):
    with self._tempFile("""positions:
  - H: [0., 0., 0.]
""") as f:
      self._loader(v_sim.DataLoader.yaml_getStatic(),
                   f.name, {"success": True,
                            "data.getBox().getUnit()": v_sim.Units.BOHR})
    with self._tempFile("""units: bohr
positions:
  - H: [0., 0., 0.]
""") as f:
      self._loader(v_sim.DataLoader.yaml_getStatic(),
                   f.name, {"success": True,
                            "data.getBox().getUnit()": v_sim.Units.BOHR})
  def test_yaml_unit_angstroem(self):
    with self._tempFile("""units: angstroem
positions:
  - H: [0., 0., 0.]
""") as f:
      self._loader(v_sim.DataLoader.yaml_getStatic(),
                   f.name, {"success": True,
                            "data.getBox().getUnit()": v_sim.Units.ANGSTROEM})
  def test_yaml_box_periodic(self):
    with self._tempFile("""cell: [10., 10., 10.]
positions:
  - H: [0., 0., 0.]
""") as f:
      self._loader(v_sim.DataLoader.yaml_getStatic(),
                   f.name, {"success": True,
                            "data.getBox().getBoundary()": v_sim.BoxBoundaries.PERIODIC})
  def test_yaml_box_free(self):
    with self._tempFile("""positions:
  - H: [0., 0., 0.]
""") as f:
      self._loader(v_sim.DataLoader.yaml_getStatic(),
                   f.name, {"success": True,
                            "data.getBox().getBoundary()": v_sim.BoxBoundaries.FREE})
    with self._tempFile("""cell: [.inf, .inf, .inf]
positions:
  - H: [0., 0., 0.]
""") as f:
      self._loader(v_sim.DataLoader.yaml_getStatic(),
                   f.name, {"success": True,
                            "data.getBox().getBoundary()": v_sim.BoxBoundaries.FREE})
  def test_yaml_box_surface(self):
    with self._tempFile("""cell: [10., .inf, 10.]
positions:
  - H: [0., 0., 0.]
""") as f:
      self._loader(v_sim.DataLoader.yaml_getStatic(),
                   f.name, {"success": True,
                            "data.getBox().getBoundary()": v_sim.BoxBoundaries.SURFACE_ZX})
  def test_yaml_box(self):
    with self._tempFile("""cell: [1., 2., 3.]
positions:
  - H: [0., 0., 0.]
""") as f:
      self._loader(v_sim.DataLoader.yaml_getStatic(),
                   f.name, {"success": True,
                            "data.getBox().getGeometry(v_sim.BoxVector.DXX)": 1.,
                            "data.getBox().getGeometry(v_sim.BoxVector.DYX)": 0.,
                            "data.getBox().getGeometry(v_sim.BoxVector.DYY)": 2.,
                            "data.getBox().getGeometry(v_sim.BoxVector.DZX)": 0.,
                            "data.getBox().getGeometry(v_sim.BoxVector.DZY)": 0.,
                            "data.getBox().getGeometry(v_sim.BoxVector.DZZ)": 3.})
  def test_yaml_coord(self):
    with self._tempFile("""cell: [10., 10., 10.]
positions:
  - Si: [1., 2., 3.]
""") as f:
      self._loader(v_sim.DataLoader.yaml_getStatic(),
                   f.name, {"success": True,
                            "data.getNodeCoordinates(data.getFromId(0), False)": (1., 2., 3.)})
  def test_yaml_coord_user(self):
    with self._tempFile("""positions:
  - Si: [1., 2., 3.]
""") as f:
      self._loader(v_sim.DataLoader.yaml_getStatic(),
                   f.name, {"success": True,
                            "data.getNodeCoordinates(data.getFromId(0), True)": (1., 2., 3.),
                            "data.getNodeCoordinates(data.getFromId(0), False)": (0., 0., 0.)})
  def test_yaml_props(self):
    with self._tempFile("""positions:
  - {Si: [1., 2., 3.]}
  - {C: [0., 0., 0.], IGSpin: -1}
""") as f:
      self._loader(v_sim.DataLoader.yaml_getStatic(),
                   f.name, {"success": True,
                            "data.getNodeLabelAt(data.getFromId(0))": None,
                            "data.getNodeProperties(\"IGSpin\").getAt(data.getFromId(1))[1]": -1})
  def test_yaml_forces(self):
    with self._tempFile("""positions:
  - Si: [1., 2., 3.]
  - C: [0., 0., 0.]
forces:
  Values:
  - Si: [1., 2., 3.]
  - C: [4., 5., 6.]
""") as f:
      self._loader(v_sim.DataLoader.yaml_getStatic(),
                   f.name, {"success": True,
                            "data.getForces(False).getAt(data.getFromId(0))": [1, 2, 3],
                            "data.getForces(False).getAt(data.getFromId(1))": [4, 5, 6]})
  def test_yaml_forces_none(self):
    with self._tempFile("""positions:
  - Si: [1., 2., 3.]
  - C: [0., 0., 0.]
""") as f:
      self._loader(v_sim.DataLoader.yaml_getStatic(),
                   f.name, {"success": True,
                            "data.getForces(False)": None,
                            "data.getForces(True).getAt(data.getFromId(0))": [0, 0, 0]})

  def test_d3_file(self):
    self._loader(v_sim.DataLoader.d3_getStatic(),
                 os.path.join(os.path.dirname(sys.argv[0]),
                                "../../examples/aluminium.d3"),
                   {"success": True,
                    "data.getNNodes()": 134,
                    "data.getNElements(True)": 1,
                    "data.containsElement(v_sim.Element.retrieveFromName(\"Ni\")[0])": False,
                    "data.containsElement(v_sim.Element.lookup(\"Al\"))": True,
                    "data.getBox().getUnit()": v_sim.Units.UNDEFINED,
                    "data.getBox().getBoundary()": v_sim.BoxBoundaries.PERIODIC,
                    "data.getBox().getGeometry(v_sim.BoxVector.DXX)": 24.72848892,
                    "data.getBox().getGeometry(v_sim.BoxVector.DYX)": -12.36424446,
                    "data.getBox().getGeometry(v_sim.BoxVector.DYY)": 21.41550064,
                    "data.getBox().getGeometry(v_sim.BoxVector.DZX)": 0.,
                    "data.getBox().getGeometry(v_sim.BoxVector.DZY)": 0.,
                    "data.getBox().getGeometry(v_sim.BoxVector.DZZ)": 24.
                   })

# Run all the tests of this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
