import struct

import mpitestutil as testutil
import mpiunittest as unittest

from mpi4py import MPI

import platform
# Host architecture string (e.g. "x86_64", "ppc64le"); used by skipIf guards.
test_machine = platform.uname().machine

# Predefined datatypes grouped by origin.  Each group is built from the
# MPI attribute names below; the C, C99, F77 and F90 groups are then pruned
# to the types this MPI library actually provides.
datatypes_c = [
    getattr(MPI, typename)
    for typename in (
        "CHAR",
        "WCHAR",
        "SIGNED_CHAR",
        "SHORT",
        "INT",
        "LONG",
        "UNSIGNED_CHAR",
        "UNSIGNED_SHORT",
        "UNSIGNED",
        "UNSIGNED_LONG",
        "LONG_LONG",
        "UNSIGNED_LONG_LONG",
        "FLOAT",
        "DOUBLE",
        "LONG_DOUBLE",
    )
]
datatypes_c99 = [
    getattr(MPI, typename)
    for typename in (
        "C_BOOL",
        "INT8_T",
        "INT16_T",
        "INT32_T",
        "INT64_T",
        "UINT8_T",
        "UINT16_T",
        "UINT32_T",
        "UINT64_T",
        "FLOAT16_T",
        "FLOAT32_T",
        "FLOAT64_T",
        "BFLOAT16_T",
        "C_COMPLEX",
        "C_FLOAT_COMPLEX",
        "C_DOUBLE_COMPLEX",
        "C_LONG_DOUBLE_COMPLEX",
    )
]
datatypes_f77 = [
    getattr(MPI, typename)
    for typename in (
        "CHARACTER",
        "LOGICAL",
        "INTEGER",
        "REAL",
        "DOUBLE_PRECISION",
        "COMPLEX",
        "DOUBLE_COMPLEX",
    )
]
# Fixed-size Fortran 90 types: <KIND><width>, e.g. LOGICAL1 ... COMPLEX32.
datatypes_f90 = [
    getattr(MPI, f"{kind}{width}")
    for kind, widths in (
        ("LOGICAL", (1, 2, 4, 8, 16)),
        ("INTEGER", (1, 2, 4, 8, 16)),
        ("REAL", (2, 4, 8, 16)),
        ("COMPLEX", (4, 8, 16, 32)),
    )
    for width in widths
]
datatypes_mpi = [
    getattr(MPI, typename)
    for typename in ("PACKED", "BYTE", "AINT", "OFFSET")
]

# Prune in place (slice assignment) so every group keeps its identity.
for typelist in (datatypes_c, datatypes_c99, datatypes_f77, datatypes_f90):
    typelist[:] = list(filter(testutil.has_datatype, typelist))
del typelist

# Flat list of every datatype exercised by the tests below.
datatypes = [
    *datatypes_c,
    *datatypes_c99,
    *datatypes_f77,
    *datatypes_f90,
    *datatypes_mpi,
]

# Overrides for the combiner check_recreate expects: maps an
# MPI.COMBINER_* value to the one to expect instead, and None means "skip".
# The vendor adjustments at the end of the module may fill it in or rebind
# it to None.
combiner_map = {}


class TestDatatypeNull(unittest.TestCase):
    """Behaviour of the null datatype handle."""

    def testConstructor(self):
        fresh = MPI.Datatype()
        # A default-constructed handle equals the predefined null handle
        # but is a distinct Python object.
        self.assertEqual(fresh, MPI.DATATYPE_NULL)
        self.assertIsNot(fresh, MPI.DATATYPE_NULL)
        # Arbitrary non-datatype arguments are rejected.
        with self.assertRaises(TypeError):
            MPI.Datatype((1, 2, 3))

    def testGetName(self):
        self.assertEqual(MPI.DATATYPE_NULL.Get_name(), "MPI_DATATYPE_NULL")


class TestDatatype(unittest.TestCase):
    """Queries and round trips on the predefined types in ``datatypes``."""

    def testBoolEqNe(self):
        # Every predefined handle is truthy and compares equal (and not
        # unequal) to a handle copy-constructed from it.
        for dtype in datatypes:
            self.assertTrue(not not dtype)
            eq = dtype == MPI.Datatype(dtype)
            ne = dtype != MPI.Datatype(dtype)
            self.assertTrue(eq)
            self.assertFalse(ne)

    def testGetExtent(self):
        # The lb/ub/extent properties must agree with Get_extent().
        for dtype in datatypes:
            lb, ext = dtype.Get_extent()
            self.assertEqual(dtype.lb, lb)
            self.assertEqual(dtype.ub, lb + ext)
            self.assertEqual(dtype.extent, ext)

    def testGetSize(self):
        # The size property must agree with Get_size().
        for dtype in datatypes:
            size = dtype.Get_size()
            self.assertEqual(dtype.size, size)

    def testGetTrueExtent(self):
        # The true_lb/true_ub/true_extent properties must agree with
        # Get_true_extent(); skip when the library does not implement it.
        for dtype in datatypes:
            try:
                lb, ext = dtype.Get_true_extent()
                self.assertEqual(dtype.true_lb, lb)
                self.assertEqual(dtype.true_ub, lb + ext)
                self.assertEqual(dtype.true_extent, ext)
            except NotImplementedError:
                self.skipTest("mpi-type-get_true_extent")

    # Byte sizes probed by testMatchSize for each type class.  These are
    # lists (not tuples) because vendor adjustments may clear them in place.
    match_size_logical = [1, 2, 4, 8, 16]
    match_size_integer = [1, 2, 4, 8, 16]
    match_size_real = [2, 4, 8]
    match_size_complex = [4, 8, 16]

    @unittest.skipMPI("MPI(<2.0)")
    @unittest.skipMPI("openmpi", not testutil.has_datatype(MPI.INTEGER))
    def testMatchSize(self):
        for key in ("logical", "integer", "real", "complex"):
            match_size = getattr(self, f"match_size_{key}")
            typeclass = getattr(MPI, f"TYPECLASS_{key.upper()}")
            if typeclass == MPI.UNDEFINED:
                continue
            for size in match_size:
                try:
                    datatype = MPI.Datatype.Match_size(typeclass, size)
                except MPI.Exception:
                    # A failed match is only an error when the matching
                    # fixed-size type (e.g. MPI.INTEGER8) is available.
                    datatype = getattr(MPI, f"{key.upper()}{size}")
                    if testutil.has_datatype(datatype):
                        raise
                else:
                    if testutil.has_datatype(datatype):
                        self.assertEqual(datatype.size, size)

    def testGetValueIndex(self):
        # (value, MPI.INT) maps to the predefined MPI.<VALUE>_INT pair type.
        # With MPI.FLOAT as the index there is no predefined pair, so the
        # result is DATATYPE_NULL.  Null arguments raise ERR_TYPE.
        typenames = ("SHORT", "INT", "LONG", "FLOAT", "DOUBLE", "LONG_DOUBLE")
        value_types = [getattr(MPI, attr) for attr in typenames]
        pair_types = [getattr(MPI, f"{attr}_INT") for attr in typenames]
        for value, pair in zip(value_types, pair_types):
            result = MPI.Datatype.Get_value_index(value, MPI.INT)
            self.assertEqual(result, pair)
        for value in value_types:
            result = MPI.Datatype.Get_value_index(value, MPI.FLOAT)
            self.assertEqual(result, MPI.DATATYPE_NULL)
        with self.assertRaises(MPI.Exception) as catcher:
            MPI.Datatype.Get_value_index(MPI.DATATYPE_NULL, MPI.INT)
        self.assertEqual(catcher.exception.Get_error_class(), MPI.ERR_TYPE)
        with self.assertRaises(MPI.Exception) as catcher:
            MPI.Datatype.Get_value_index(MPI.INT, MPI.DATATYPE_NULL)
        self.assertEqual(catcher.exception.Get_error_class(), MPI.ERR_TYPE)

    def testGetEnvelope(self):
        # Predefined types report the NAMED combiner, carry no
        # integer/address/count/datatype arguments, and decode to themselves.
        for dtype in datatypes:
            try:
                envelope = dtype.Get_envelope()
            except NotImplementedError:
                self.skipTest("mpi-type-get_envelope")
            ni, na, nc, nd, combiner = envelope
            self.assertEqual(combiner, MPI.COMBINER_NAMED)
            self.assertEqual(ni, 0)
            self.assertEqual(na, 0)
            self.assertEqual(nc, 0)
            self.assertEqual(nd, 0)
            self.assertEqual(dtype.envelope, envelope)
            self.assertEqual(dtype.combiner, combiner)
            self.assertTrue(dtype.is_named)
            self.assertTrue(dtype.is_predefined)
            otype, combiner, params = dtype.decode()
            self.assertIs(dtype, otype)
            self.assertEqual(combiner, "NAMED")
            self.assertEqual(params, {})

    def testGetSetName(self):
        # Every predefined type has a non-empty name that survives being
        # written back via Set_name() and via the name property.
        name = MPI.DATATYPE_NULL.Get_name()
        self.assertEqual(name, "MPI_DATATYPE_NULL")
        for dtype in datatypes:
            try:
                name = dtype.Get_name()
                self.assertTrue(name)
                dtype.Set_name(name)
                self.assertEqual(name, dtype.Get_name())
                dtype.name = dtype.name
            except NotImplementedError:
                self.skipTest("mpi-type-name")

    def testCommit(self):
        for dtype in datatypes:
            dtype.Commit()

    @unittest.skipIf('ppc' in test_machine, "testCodeCharStr fails on ppc arches")
    def testCodeCharStr(self):
        # Collect parameterized Fortran 90 types when the library supports
        # them; zero-sized or null results are discarded.
        f90datatypes = []
        try:
            try:
                for r in (1, 2, 4):
                    f90datatypes.append(MPI.Datatype.Create_f90_integer(r))
                for p, r in ((6, 30), (15, 300)):
                    f90datatypes.append(MPI.Datatype.Create_f90_real(p, r))
                for p, r in ((6, 30), (15, 300)):
                    f90datatypes.append(MPI.Datatype.Create_f90_complex(p, r))
            except MPI.Exception:
                # MS-MPI may reject some of these; keep the ones created
                # before the failure.
                if not unittest.is_mpi("msmpi"):
                    raise
            f90datatypes = [
                dtype for dtype in f90datatypes if dtype and dtype.size > 0
            ]
        except NotImplementedError:
            f90datatypes = []
        # Types excluded from the tocode()/fromcode() round trip below.
        largef90datatypes = []
        if MPI.INTEGER16 != MPI.DATATYPE_NULL:
            largef90datatypes += [MPI.INTEGER16]
        if (
            struct.calcsize("P") == 4
            or MPI.DOUBLE.extent == MPI.LONG_DOUBLE.extent
        ):
            largef90datatypes += [MPI.REAL16, MPI.COMPLEX32]
        for dtype in datatypes + f90datatypes:
            with self.subTest(datatype=dtype.name or "f90"):
                if dtype == MPI.BFLOAT16_T:
                    continue
                if dtype in largef90datatypes:
                    continue
                code = dtype.tocode()
                self.assertIsNotNone(code)
                mpitype = MPI.Datatype.fromcode(code)
                self.assertEqual(dtype.typechar, mpitype.typechar)
                self.assertEqual(dtype.typestr, mpitype.typestr)
                # Pre-bind both duplicates so the cleanup never touches an
                # unbound name (or a stale handle left over from a previous
                # iteration) when Dup() or an assertion fails part-way.
                mpitypedup1 = mpitypedup2 = None
                try:
                    mpitypedup1 = mpitype.Dup()
                    self.assertEqual(mpitypedup1.tocode(), mpitype.tocode())
                    self.assertEqual(mpitypedup1.typestr, mpitype.typestr)
                    self.assertEqual(mpitypedup1.typechar, mpitype.typechar)
                    mpitypedup2 = mpitypedup1.Dup()
                    self.assertEqual(mpitypedup2.tocode(), mpitype.tocode())
                    self.assertEqual(mpitypedup2.typestr, mpitype.typestr)
                    self.assertEqual(mpitypedup2.typechar, mpitype.typechar)
                finally:
                    for duptype in (mpitypedup1, mpitypedup2):
                        if duptype is not None:
                            duptype.Free()
        with self.assertRaises(ValueError):
            MPI.Datatype.fromcode("abc@xyz")
        with self.assertRaises(ValueError):
            MPI.DATATYPE_NULL.tocode()
        with self.assertRaises(ValueError):
            MPI.INT_INT.tocode()
        self.assertEqual(MPI.INT_INT.typechar, "V")
        self.assertEqual(MPI.INT_INT.typestr, f"V{MPI.INT.extent * 2}")

    def testContiguousBigMPI(self):
        # Largest count representable in a C int.
        int_max = (1 << (struct.calcsize("i") * 8 - 1)) - 1
        mpitype = MPI.BYTE.Create_contiguous(int_max)
        self.assertEqual(mpitype.size, int_max)
        basetype, combiner, params = mpitype.decode()
        self.assertEqual(basetype, MPI.BYTE)
        self.assertEqual(combiner, "CONTIGUOUS")
        self.assertEqual(params, {"count": int_max})
        mpitype.Free()
        if MPI.Get_version() < (3, 1):
            return
        # A count beyond INT_MAX.  Depending on the library version it
        # decodes either as a plain CONTIGUOUS type or as a STRUCT made of a
        # VECTOR (whole INT_MAX chunks) plus a CONTIGUOUS remainder.
        count = int_max + int_max // 1024
        mpitype = MPI.BYTE.Create_contiguous(count)
        self.assertEqual(mpitype.size, count)
        basetype, combiner, params = mpitype.decode()
        mpitype.Free()
        if MPI.Get_version() >= (4, 0) or unittest.is_mpi("impi(>2021.12.0)"):
            self.assertEqual(basetype, MPI.BYTE)
            self.assertEqual(combiner, "CONTIGUOUS")
            self.assertEqual(params, {"count": count})
        else:
            self.assertEqual(basetype, MPI.DATATYPE_NULL)
            self.assertEqual(combiner, "STRUCT")
            qtype, rtype = params["datatypes"]
            self.assertEqual(qtype.combiner, MPI.COMBINER_VECTOR)
            self.assertEqual(qtype.size, (count // int_max) * int_max)
            self.assertEqual(rtype.combiner, MPI.COMBINER_CONTIGUOUS)
            self.assertEqual(rtype.size, count % int_max)
            qtype.free()
            rtype.free()


class BaseTestDatatypeCreateMixin:
    """Datatype-constructor tests shared by the concrete test cases.

    Subclasses mix this into ``unittest.TestCase`` and provide
    ``check(oldtype, factory, *args)``, which every ``test*`` method below
    calls to build a derived datatype and verify it.
    """

    def free(self, newtype):
        """Release *newtype* with ``Free()`` when that is appropriate.

        Null handles and types whose combiner is NAMED or one of the F90
        combiners are left alone; everything else is freed.
        """
        if newtype == MPI.DATATYPE_NULL:
            return
        *_, combiner = newtype.Get_envelope()
        if combiner in (
            MPI.COMBINER_NAMED,
            MPI.COMBINER_F90_INTEGER,
            MPI.COMBINER_F90_REAL,
            MPI.COMBINER_F90_COMPLEX,
        ):
            return
        newtype.Free()

    def check_contents(self, _factory, newtype, _oldtype):
        """Check that envelope, contents and properties of *newtype* agree.

        Every datatype handed back by the contents queries is released
        again via ``self.free``.
        """
        try:
            envelope = newtype.Get_envelope()
            contents = newtype.Get_contents()
        except NotImplementedError:
            self.skipTest("mpi-type-get_envelope")
        ni, na, nc, nd, combiner = envelope
        i, a, c, d = contents
        # The envelope counts describe the lengths of the contents arrays.
        self.assertEqual(ni, len(i))
        self.assertEqual(na, len(a))
        self.assertEqual(nc, len(c))
        self.assertEqual(nd, len(d))
        self.assertNotEqual(combiner, MPI.COMBINER_NAMED)
        self.assertEqual(newtype.envelope, envelope)
        self.assertEqual(newtype.combiner, combiner)
        self.assertFalse(newtype.is_named)
        # F90 parameterized types count as predefined; other derived
        # types do not.
        if combiner in (
            MPI.COMBINER_F90_INTEGER,
            MPI.COMBINER_F90_REAL,
            MPI.COMBINER_F90_COMPLEX,
        ):
            self.assertTrue(newtype.is_predefined)
        else:
            self.assertFalse(newtype.is_predefined)
        for dt in d:
            self.free(dt)
        contents = newtype.contents
        self.assertEqual(contents[:-1], (i, a, c))
        for dt in contents[-1]:
            self.free(dt)

    def check_recreate(self, factory, newtype):
        """Decode *newtype*, rebuild it, and compare both decodings.

        The expected combiner is derived from the factory name (e.g.
        ``Create_vector`` -> ``COMBINER_VECTOR``).  The check is skipped
        when that combiner is undefined, when ``combiner_map`` is None, or
        when ``combiner_map`` maps the combiner to None.
        """
        name = factory.__name__
        name = name.replace("Get_value_index", "Create_value_index")
        NAME = name.replace("Create_", "").upper()
        symbol = getattr(MPI, "COMBINER_" + NAME)
        if symbol == MPI.UNDEFINED:
            return
        if combiner_map is None:
            return
        symbol = combiner_map.get(symbol, symbol)
        if symbol is None:
            return
        self.assertEqual(symbol, newtype.combiner)
        decoded1 = newtype.decode()
        oldtype, constructor, kwargs = decoded1
        # VALUE_INDEX is rebuilt with Get_value_index; all other combiners
        # with the matching Create_* method.
        prefix = "create" if constructor != "VALUE_INDEX" else "get"
        constructor = prefix.title() + "_" + constructor.lower()
        newtype2 = getattr(oldtype, constructor)(**kwargs)
        decoded2 = newtype2.decode()
        # Compare the datatype arguments structurally (handles differ),
        # releasing each one afterwards.
        types1 = decoded1[2].pop("datatypes", [])
        types2 = decoded2[2].pop("datatypes", [])
        for dt1, dt2 in zip(types1, types2):
            self.assertEqual(dt1.combiner, dt2.combiner)
            self.assertEqual(dt1.typechar, dt2.typechar)
            self.assertEqual(dt1.typestr, dt2.typestr)
            self.free(dt1)
            self.free(dt2)
        self.assertEqual(decoded1[1], decoded2[1])
        # The remaining (non-datatype) constructor parameters of the
        # original and the re-created type must match.
        self.assertEqual(decoded1[2], decoded2[2])
        for dec in (decoded1, decoded2):
            self.free(dec[0])
        self.free(newtype2)

    def testDup(self):
        for dtype in datatypes:
            factory = MPI.Datatype.Dup
            self.check(dtype, factory)

    def testContiguous(self):
        for dtype in datatypes:
            for count in range(5):
                factory = MPI.Datatype.Create_contiguous
                args = (count,)
                self.check(dtype, factory, *args)

    def testVector(self):
        for dtype in datatypes:
            for count in range(5):
                for blocklength in range(5):
                    for stride in range(5):
                        factory = MPI.Datatype.Create_vector
                        args = (count, blocklength, stride)
                        self.check(dtype, factory, *args)

    def testHvector(self):
        for dtype in datatypes:
            for count in range(5):
                for blocklength in range(5):
                    for stride in range(5):
                        factory = MPI.Datatype.Create_hvector
                        args = (count, blocklength, stride)
                        self.check(dtype, factory, *args)

    def testIndexed(self):
        for dtype in datatypes:
            for block in range(5):
                blocklengths = list(range(block, block + 5))
                displacements = [0]
                for b in blocklengths[:-1]:
                    stride = displacements[-1] + b * dtype.extent + 1
                    displacements.append(stride)
                factory = MPI.Datatype.Create_indexed
                args = (blocklengths, displacements)
                self.check(dtype, factory, *args)
                # args = (block, displacements) XXX
                # self.check(dtype, factory, *args)  XXX

    def testIndexedBlock(self):
        for dtype in datatypes:
            for block in range(5):
                blocklengths = list(range(block, block + 5))
                displacements = [0]
                for b in blocklengths[:-1]:
                    stride = displacements[-1] + b * dtype.extent + 1
                    displacements.append(stride)
                factory = MPI.Datatype.Create_indexed_block
                args = (block, displacements)
                self.check(dtype, factory, *args)

    def testHindexed(self):
        for dtype in datatypes:
            for block in range(5):
                blocklengths = list(range(block, block + 5))
                displacements = [0]
                for b in blocklengths[:-1]:
                    stride = displacements[-1] + b * dtype.extent + 1
                    displacements.append(stride)

                factory = MPI.Datatype.Create_hindexed
                args = (blocklengths, displacements)
                self.check(dtype, factory, *args)
                # args = (block, displacements) XXX
                # self.check(dtype, factory, *args)  XXX

    @unittest.skipMPI("openmpi(<=1.8.1)", MPI.Get_version()[0] == 3)
    def testHindexedBlock(self):
        for dtype in datatypes:
            for block in range(5):
                displacements = [0]
                for _i in range(5):
                    stride = displacements[-1] + block * dtype.extent + 1
                    displacements.append(stride)
                factory = MPI.Datatype.Create_hindexed_block
                args = (block, displacements)
                self.check(dtype, factory, *args)

    def testStruct(self):
        # Every ordered pair of predefined types, laid out back to back.
        for dtype1 in datatypes:
            for dtype2 in datatypes:
                dtypes = (dtype1, dtype2)
                blocklengths = (2, 3)
                displacements = [0]
                for dtype in dtypes[:-1]:
                    stride = displacements[-1] + dtype.extent
                    displacements.append(stride)
                factory = MPI.Datatype.Create_struct
                args = (blocklengths, displacements, dtypes)
                self.check(None, factory, *args)
        # Structs over a chain of nested derived types.
        for dtype in datatypes:
            factory = MPI.Datatype.Create_struct
            dtypes = [dtype.Dup()]
            dtypes.append(dtypes[-1].Create_contiguous(2))
            dtypes.append(dtypes[-1].Dup())
            dtypes.append(dtypes[-1].Create_struct([1], [0], [dtypes[-1]]))
            dtypes.append(dtypes[-1].Dup())
            dtypes.append(dtypes[-1].Create_resized(0, dtypes[-1].extent))
            dtypes.append(dtypes[-1].Dup())
            for dt in dtypes:
                args = [[1, 1], [0, dt.extent * 2], (dt, dt)]
                self.check(None, factory, *args)
                dt.Free()
        # Mismatched argument lengths are rejected.
        with self.assertRaises(ValueError):
            factory = MPI.Datatype.Create_struct
            factory([1], [0], [MPI.INT, MPI.FLOAT])

    def testSubarray(self):
        for dtype in datatypes:
            for ndim in range(1, 5):
                for size in range(1, 5):
                    for subsize in range(1, size):
                        for start in range(size - subsize):
                            for order in [
                                MPI.ORDER_C,
                                MPI.ORDER_FORTRAN,
                                MPI.ORDER_F,
                            ]:
                                sizes = [size] * ndim
                                subsizes = [subsize] * ndim
                                starts = [start] * ndim
                                factory = MPI.Datatype.Create_subarray
                                args = sizes, subsizes, starts, order
                                self.check(dtype, factory, *args)

    def testDarray(self):
        for dtype in datatypes:
            for ndim in range(1, 3 + 1):
                for size in (4, 8, 9, 27):
                    for rank in (0, size - 1):
                        for dist in [
                            MPI.DISTRIBUTE_BLOCK,
                            MPI.DISTRIBUTE_CYCLIC,
                        ]:
                            for order in [MPI.ORDER_C, MPI.ORDER_F]:
                                gsizes = [size] * ndim
                                distribs = [dist] * ndim
                                dargs = [MPI.DISTRIBUTE_DFLT_DARG] * ndim
                                psizes = MPI.Compute_dims(size, [0] * ndim)
                                factory = MPI.Datatype.Create_darray
                                args = (
                                    size,
                                    rank,
                                    gsizes,
                                    distribs,
                                    dargs,
                                    psizes,
                                    order,
                                )
                                self.check(dtype, factory, *args)

    def testF90Integer(self):
        for r in (1, 2, 4):
            factory = MPI.Datatype.Create_f90_integer
            args = (r,)
            self.check(None, factory, *args)

    @unittest.skipMPI("openmpi(<3.0.0)")
    @unittest.skipMPI("msmpi")
    def testF90RealSingle(self):
        (p, r) = (6, 30)
        factory = MPI.Datatype.Create_f90_real
        args = (p, r)
        self.check(None, factory, *args)

    @unittest.skipMPI("openmpi(<3.0.0)")
    @unittest.skipMPI("msmpi")
    @unittest.skipIf('ppc' in test_machine, "testF90RealDouble fails on ppc arches")
    def testF90RealDouble(self):
        (p, r) = (15, 300)
        factory = MPI.Datatype.Create_f90_real
        args = (p, r)
        self.check(None, factory, *args)

    @unittest.skipMPI("openmpi(<3.0.0)")
    @unittest.skipMPI("msmpi")
    def testF90ComplexSingle(self):
        (p, r) = (6, 30)
        factory = MPI.Datatype.Create_f90_complex
        args = (p, r)
        self.check(None, factory, *args)

    @unittest.skipMPI("openmpi(<3.0.0)")
    @unittest.skipMPI("msmpi")
    @unittest.skipIf('ppc' in test_machine, "testF90ComplexDouble fails on ppc arches")
    def testF90ComplexDouble(self):
        (p, r) = (15, 300)
        factory = MPI.Datatype.Create_f90_complex
        args = (p, r)
        self.check(None, factory, *args)

    def testResized(self):
        for dtype in datatypes:
            for lb in range(-10, 10):
                for extent in range(1, 10):
                    factory = MPI.Datatype.Create_resized
                    args = lb, extent
                    self.check(dtype, factory, *args)

    def testValueIndex(self):
        # Select candidates by name rather than by slicing datatypes_c /
        # datatypes_c99: those lists are filtered at import time, so
        # positional slices shift whenever a type is unavailable.
        integral_names = (
            "SIGNED_CHAR", "SHORT", "INT", "LONG",
            "UNSIGNED_CHAR", "UNSIGNED_SHORT", "UNSIGNED", "UNSIGNED_LONG",
            "LONG_LONG", "UNSIGNED_LONG_LONG",
            "INT8_T", "INT16_T", "INT32_T", "INT64_T",
            "UINT8_T", "UINT16_T", "UINT32_T", "UINT64_T",
        )
        floating_names = ("FLOAT", "DOUBLE", "LONG_DOUBLE")

        def available(names):
            # Same availability filter that prunes the module-level lists.
            types = [getattr(MPI, typename) for typename in names]
            return [t for t in types if testutil.has_datatype(t)]

        integral_types = available(integral_names)
        floating_types = available(floating_names)
        value_types = integral_types + floating_types
        index_types = integral_types
        for value in value_types:
            if value == MPI.DATATYPE_NULL:
                continue
            for index in index_types:
                if index == MPI.DATATYPE_NULL:
                    continue
                factory = MPI.Datatype.Get_value_index
                pair = factory(value, index)
                # Skip combinations without a pair type, and predefined
                # pair types (those are named, not derived).
                if pair == MPI.DATATYPE_NULL:
                    continue
                if pair.is_named:
                    continue
                self.check(None, factory, value, index)


class TestDatatypeCreate(BaseTestDatatypeCreateMixin, unittest.TestCase):
    """Run the shared constructor tests directly on freshly built types."""

    def check(self, oldtype, factory, *args):
        # A None oldtype means the factory is called without a base type.
        call_args = args if oldtype is None else (oldtype, *args)
        try:
            newtype = factory(*call_args)
            if newtype == MPI.DATATYPE_NULL:
                return
        except NotImplementedError:
            self.skipTest("mpi-type-constructor")
        # Inspect the type once before and once after committing it.
        for committed in (False, True):
            if committed:
                newtype.Commit()
            self.check_contents(factory, newtype, oldtype)
            self.check_recreate(factory, newtype)
        self.free(newtype)


class TestDatatypePickle(BaseTestDatatypeCreateMixin, unittest.TestCase):
    """Run the shared constructor tests on types round-tripped via pickle."""

    def check(self, oldtype, factory, *args):
        import pickle

        # A None oldtype means the factory is called without a base type.
        call_args = args if oldtype is None else (oldtype, *args)
        try:
            original = factory(*call_args)
            if original == MPI.DATATYPE_NULL:
                return
        except NotImplementedError:
            self.skipTest("mpi-type-constructor")
        restored = pickle.loads(pickle.dumps(original))
        self.check_contents(factory, restored, oldtype)
        self.free(restored)
        self.free(original)

    def testNamed(self):
        import pickle

        def roundtrip(obj):
            return pickle.loads(pickle.dumps(obj))

        for dtype in (MPI.DATATYPE_NULL, *datatypes):
            # Predefined handles unpickle to the very same object ...
            self.assertIs(roundtrip(dtype), dtype)
            # ... while a copy-constructed handle unpickles to an equal but
            # distinct object.
            clone = roundtrip(MPI.Datatype(dtype))
            self.assertIsNot(clone, dtype)
            self.assertEqual(clone, dtype)


# Vendor- and version-specific adjustments, applied after the test classes
# are defined.  ``version`` is compared against tuples below.
name, version = MPI.get_vendor()
if name == "MPICH1":
    # A None entry in combiner_map makes check_recreate skip its
    # recreate check for that combiner.
    combiner_map[MPI.COMBINER_VECTOR] = None
    combiner_map[MPI.COMBINER_HVECTOR] = None
    combiner_map[MPI.COMBINER_INDEXED] = None
    combiner_map[MPI.COMBINER_HINDEXED_BLOCK] = None
    # Drop the Fortran 77 types from the tested lists entirely.
    for t in datatypes_f77:
        if t in datatypes:
            datatypes.remove(t)
    datatypes_f77.clear()
elif MPI.Get_version() < (2, 0):
    # With combiner_map set to None, check_recreate returns immediately.
    combiner_map = None
if name == "Open MPI":
    if (1, 6, 0) < version < (1, 7, 0):
        # Disable Match_size probing for the complex type class.
        TestDatatype.match_size_complex[:] = []
    if version < (1, 5, 2):
        # Remove the fixed-size COMPLEX<n> types from the tested lists.
        for t in [getattr(MPI, f"COMPLEX{i}") for i in (4, 8, 16, 32)]:
            if t in datatypes:
                datatypes.remove(t)
            if t in datatypes_f90:
                datatypes_f90.remove(t)
if name in ("MPICH", "Intel MPI"):
    # Give FLOAT16_T a name when the library leaves it empty;
    # testGetSetName asserts that every tested type has a non-empty name.
    if MPI.FLOAT16_T != MPI.DATATYPE_NULL:
        if MPI.FLOAT16_T.Get_name() == "":
            MPI.FLOAT16_T.Set_name("MPIX_C_FLOAT16")


# Allow running this module directly as a script.
if __name__ == "__main__":
    unittest.main()
