import sys

import mpiunittest as unittest

from mpi4py import MPI


class TestInfoNull(unittest.TestCase):
    """Behaviour of the predefined null handle ``MPI.INFO_NULL``."""

    def testTruth(self):
        # The null handle must evaluate as false in a boolean context.
        self.assertFalse(bool(MPI.INFO_NULL))

    def testPickle(self):
        from pickle import dumps, loads

        def roundtrip(obj):
            # Serialize and immediately deserialize ``obj``.
            return loads(dumps(obj))

        # The predefined handle unpickles to the very same singleton.
        self.assertIs(roundtrip(MPI.INFO_NULL), MPI.INFO_NULL)
        # A default-constructed Info unpickles to a distinct object that
        # nevertheless compares equal to the null handle.
        restored = roundtrip(MPI.Info())
        self.assertIsNot(restored, MPI.INFO_NULL)
        self.assertEqual(restored, MPI.INFO_NULL)

    def testPyMethods(self):
        null = MPI.INFO_NULL

        # Read-only size/membership queries treat the handle as empty.
        self.assertEqual(len(null), 0)
        self.assertNotIn("key", null)

        # Item access and mutating mapping methods raise KeyError.
        with self.assertRaises(KeyError):
            null["k"]
        with self.assertRaises(KeyError):
            null["k"] = "v"
        with self.assertRaises(KeyError):
            del null["k"]
        with self.assertRaises(KeyError):
            null.update([])
        with self.assertRaises(KeyError):
            null.pop("k")
        with self.assertRaises(KeyError):
            null.popitem()

        # Lookups with an explicit fallback return the fallback.
        self.assertIsNone(null.get("key", None))
        self.assertIsNone(null.pop("key", None))

        # Views are empty lists; copy() equals the original; clear() is a no-op.
        self.assertEqual(null.keys(), [])
        self.assertEqual(null.values(), [])
        self.assertEqual(null.items(), [])
        self.assertEqual(null.copy(), null)
        null.clear()


class TestInfoEnv(unittest.TestCase):
    """Tests for the predefined ``MPI.INFO_ENV`` handle and ``Info.Create_env``."""

    # Keys queried on the environment info object by the tests below.
    KEYS = (
        "command",
        "argv",
        "maxprocs",
        "soft",
        "host",
        "arch",
        "wdir",
        "file",
        "thread_level",
    )

    def testTruth(self):
        # The environment handle is a valid (truthy) handle.
        self.assertTrue(bool(MPI.INFO_ENV))

    def testPickle(self):
        from pickle import dumps, loads

        # The predefined handle unpickles to the same singleton object.
        ienv = loads(dumps(MPI.INFO_ENV))
        self.assertIs(ienv, MPI.INFO_ENV)
        # An Info constructed from INFO_ENV unpickles to a distinct object
        # that still compares equal to INFO_ENV.
        ienv = loads(dumps(MPI.Info(MPI.INFO_ENV)))
        self.assertIsNot(ienv, MPI.INFO_ENV)
        self.assertEqual(ienv, MPI.INFO_ENV)

    def testPyMethods(self):
        env = MPI.INFO_ENV
        # Only checks that querying each key does not raise; the returned
        # value is not inspected (it may be None if the key is unset).
        for key in self.KEYS:
            env.Get(key)

    def testDup(self):
        env = MPI.INFO_ENV
        dup = env.Dup()
        try:
            # A duplicate must report the same value for every key.
            for key in self.KEYS:
                self.assertEqual(env.Get(key), dup.Get(key))
        finally:
            dup.Free()

    def testCreateEnv(self):
        try:
            env = MPI.Info.Create_env()
        except NotImplementedError:
            # Missing support is only acceptable before MPI 4.0.
            if MPI.Get_version() >= (4, 0):
                raise
            raise unittest.SkipTest("mpi-info-create-env") from None
        # Every use of ``env`` is inside the try so the handle is always
        # freed, even when a query or assertion fails.
        try:
            for key in self.KEYS:
                env.Get(key)
            dup = env.Dup()
            try:
                for key in self.KEYS:
                    self.assertEqual(env.Get(key), dup.Get(key))
            finally:
                dup.Free()
        finally:
            env.Free()
        # Create_env accepts None, empty sequences, a string, or a sequence
        # of strings, both positionally and by keyword.
        for args in (
            None,
            [],
            (),
            sys.executable,
            [sys.executable],
            (sys.executable,),
        ):
            MPI.Info.Create_env(args).Free()
            MPI.Info.Create_env(args=args).Free()


class TestInfo(unittest.TestCase):
    """Tests for user-created ``MPI.Info`` objects and their mapping API."""

    def setUp(self):
        self.INFO = MPI.Info.Create()

    def tearDown(self):
        # Free() resets the handle, so it must then compare equal to null.
        self.INFO.Free()
        self.assertEqual(self.INFO, MPI.INFO_NULL)
        self.INFO = None

    def testTruth(self):
        self.assertTrue(bool(self.INFO))

    def testCreate(self):
        data = {"key1": "value1", "key2": "value2"}
        for items in (None, {}, [], data, list(data.items())):
            info = MPI.Info.Create(items)
            try:
                # ``None`` means "no initial items", i.e. an empty info.
                expected = dict(items or ())
                self.assertEqual(info.Get_nkeys(), len(expected))
                for k, v in expected.items():
                    self.assertEqual(info.Get(k), v)
            finally:
                info.Free()

    def testCreateBad(self):
        with self.assertRaises(TypeError):
            MPI.Info.Create(items=123)

    def testDup(self):
        info = self.INFO.Dup()
        try:
            # A duplicate is a distinct handle with the same (empty) contents.
            self.assertNotEqual(self.INFO, info)
            self.assertEqual(info.Get_nkeys(), 0)
        finally:
            info.Free()
        self.assertFalse(info)

    def testGet(self):
        # Missing keys are reported as None.
        value = self.INFO.Get("key")
        self.assertIsNone(value)

    def testGetNKeys(self):
        self.assertEqual(self.INFO.Get_nkeys(), 0)

    def testGetSetDelete(self):
        INFO = self.INFO
        self.assertEqual(INFO.Get_nkeys(), 0)
        INFO.Set("key", "value")
        nkeys = INFO.Get_nkeys()
        self.assertEqual(nkeys, 1)
        key = INFO.Get_nthkey(0)
        self.assertEqual(key, "key")
        value = INFO.Get("key")
        self.assertEqual(value, "value")
        INFO.Delete("key")
        nkeys = INFO.Get_nkeys()
        self.assertEqual(nkeys, 0)
        value = INFO.Get("key")
        self.assertIsNone(value)

    def testPickle(self):
        from pickle import dumps, loads

        mold = self.INFO
        # First round trips the empty info, second round the populated one.
        for contents in ([], [("foo", "bar"), ("answer", "42")]):
            mold.update(contents)
            info = loads(dumps(mold))
            try:
                # Unpickling yields a new handle holding the same items.
                self.assertIsNot(info, mold)
                self.assertNotEqual(info, mold)
                self.assertEqual(info.items(), mold.items())
            finally:
                info.Free()

    def testPyMethods(self):
        INFO = self.INFO

        # A freshly created info behaves like an empty mapping.
        self.assertEqual(len(INFO), 0)
        self.assertNotIn("key", INFO)
        self.assertEqual(INFO.keys(), [])
        self.assertEqual(INFO.values(), [])
        self.assertEqual(INFO.items(), [])

        # __setitem__ / __getitem__ / __iter__ / __delitem__
        INFO["key"] = "value"
        self.assertEqual(len(INFO), 1)
        self.assertIn("key", INFO)
        self.assertEqual(INFO["key"], "value")
        # Iteration yields exactly the stored keys (and fails if it is empty).
        self.assertEqual(list(INFO), ["key"])
        self.assertEqual(INFO.keys(), ["key"])
        self.assertEqual(INFO.values(), ["value"])
        self.assertEqual(INFO.items(), [("key", "value")])
        del INFO["key"]
        self.assertEqual(len(INFO), 0)

        # pop() with and without a default
        INFO["key"] = "value"
        self.assertEqual(INFO.pop("key"), "value")
        self.assertEqual(len(INFO), 0)
        self.assertEqual(INFO.pop("key", "value"), "value")
        self.assertRaises(KeyError, INFO.pop, "key")
        INFO["key1"] = "value1"
        INFO["key2"] = "value2"
        self.assertEqual(INFO.pop("key1"), "value1")
        self.assertEqual(len(INFO), 1)
        self.assertEqual(INFO.pop("key2"), "value2")
        self.assertEqual(len(INFO), 0)

        # popitem() pops the last key first (key2 before key1).
        INFO["key"] = "value"
        self.assertEqual(INFO.popitem(), ("key", "value"))
        self.assertEqual(len(INFO), 0)
        self.assertRaises(KeyError, INFO.popitem)
        INFO["key1"] = "value1"
        INFO["key2"] = "value2"
        self.assertEqual(INFO.popitem(), ("key2", "value2"))
        self.assertEqual(len(INFO), 1)
        self.assertEqual(INFO.popitem(), ("key1", "value1"))
        self.assertEqual(len(INFO), 0)

        # Back to empty.
        self.assertNotIn("key", INFO)
        self.assertEqual(INFO.keys(), [])
        self.assertEqual(INFO.values(), [])
        self.assertEqual(INFO.items(), [])

        # Missing keys raise KeyError on item access and deletion.
        with self.assertRaises(KeyError):
            INFO["key"]
        with self.assertRaises(KeyError):
            del INFO["key"]

        # update() accepts an iterable of pairs, keyword arguments, both
        # at once, or a mapping; later values overwrite earlier ones.
        INFO.clear()
        INFO.update([("key1", "value1")])
        self.assertEqual(len(INFO), 1)
        self.assertEqual(INFO["key1"], "value1")
        self.assertEqual(INFO.get("key1"), "value1")
        self.assertIsNone(INFO.get("key2"))
        self.assertEqual(INFO.get("key2", "value2"), "value2")
        INFO.update(key2="value2")
        self.assertEqual(len(INFO), 2)
        self.assertEqual(INFO["key1"], "value1")
        self.assertEqual(INFO["key2"], "value2")
        self.assertEqual(INFO.get("key1"), "value1")
        self.assertEqual(INFO.get("key2"), "value2")
        self.assertIsNone(INFO.get("key3"))
        self.assertEqual(INFO.get("key3", "value3"), "value3")
        INFO.update([("key1", "newval1")], key2="newval2")
        self.assertEqual(len(INFO), 2)
        self.assertEqual(INFO["key1"], "newval1")
        self.assertEqual(INFO["key2"], "newval2")
        self.assertEqual(INFO.get("key1"), "newval1")
        self.assertEqual(INFO.get("key2"), "newval2")
        self.assertIsNone(INFO.get("key3"))
        self.assertEqual(INFO.get("key3", "newval3"), "newval3")
        INFO.update({"key1": "val1", "key2": "val2", "key3": "val3"})
        self.assertEqual(len(INFO), 3)
        self.assertEqual(INFO["key1"], "val1")
        self.assertEqual(INFO["key2"], "val2")
        self.assertEqual(INFO["key3"], "val3")

        # copy() returns a separate handle with identical items.
        dupe = INFO.copy()
        try:
            self.assertEqual(INFO.items(), dupe.items())
        finally:
            dupe.Free()

        # clear() removes everything; get() then falls back to its default.
        INFO.clear()
        self.assertEqual(len(INFO), 0)
        self.assertIsNone(INFO.get("key1"))
        self.assertIsNone(INFO.get("key2"))
        self.assertIsNone(INFO.get("key3"))
        self.assertEqual(INFO.get("key1", "value1"), "value1")
        self.assertEqual(INFO.get("key2", "value2"), "value2")
        self.assertEqual(INFO.get("key3", "value3"), "value3")


# Probe for basic MPI_Info support: if creating (and freeing) an info object
# raises NotImplementedError, disable the suites that depend on it.
try:
    MPI.Info.Create().Free()
except NotImplementedError:
    unittest.disable(TestInfo, "mpi-info")
    unittest.disable(TestInfoNull, "mpi-info")
# The INFO_ENV suite runs only on MPI >= 3.0 and only when the library
# exposes a non-null INFO_ENV handle.
if MPI.Get_version() < (3, 0) or MPI.INFO_ENV == MPI.INFO_NULL:
    unittest.disable(TestInfoEnv, "mpi-info-env")


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
