import filecmp
import functools
import importlib
import importlib.util
import os
import unittest

import numpy as np
import petsc4py
from petsc4py import PETSc

# --------------------------------------------------------------------

# PETSc error code PETSC_ERR_ARG_OUTOFRANGE.  testAdapt tolerates a
# PETSc.Error whose ``ierr`` equals this value and re-raises any other.
ERR_ARG_OUTOFRANGE = 63


class BaseTestPlex:
    """Mesh-independent DMPlex checks shared by every concrete suite.

    The class is a mixin: concrete suites combine it with
    ``unittest.TestCase`` and override the class attributes below (or
    ``setUp``) to choose the mesh the tests run on.
    """

    # Communicator the mesh is created on.
    COMM = PETSc.COMM_WORLD
    # Topological dimension of the mesh.
    DIM = 1
    # Cell-to-vertex connectivity; None when setUp generates the mesh itself.
    CELLS = [[0, 1], [1, 2]]
    # One coordinate row per vertex; None when setUp generates the mesh itself.
    COORDS = [[0.0], [0.5], [1.0]]
    # Component count of the single field passed to createSection.
    COMP = 1
    # DOFS[d] is the number of dofs on each mesh point of depth d.
    DOFS = [1, 0]

    def setUp(self):
        """Create ``self.plex`` from the explicit cell list and coordinates."""
        self.plex = PETSc.DMPlex().createFromCellList(
            self.DIM, self.CELLS, self.COORDS, comm=self.COMM
        )

    def tearDown(self):
        """Destroy the mesh and run PETSc's garbage cleanup."""
        self.plex.destroy()
        self.plex = None
        PETSc.garbage_cleanup()

    def testTopology(self):
        """Check dimension and depth count; rank 0 also checks cells and vertices."""
        rank = self.COMM.rank
        dim = self.plex.getDimension()
        cStart, cEnd = self.plex.getHeightStratum(0)
        vStart, vEnd = self.plex.getDepthStratum(0)
        numDepths = self.plex.getLabelSize('depth')
        coords_raw = self.plex.getCoordinates().getArray()
        coords = np.reshape(coords_raw, (vEnd - vStart, dim))
        self.assertEqual(dim, self.DIM)
        self.assertEqual(numDepths, self.DIM + 1)
        # Only rank 0 compares against the explicit input tables.
        if rank == 0 and self.CELLS is not None:
            self.assertEqual(cEnd - cStart, len(self.CELLS))
        if rank == 0 and self.COORDS is not None:
            self.assertEqual(vEnd - vStart, len(self.COORDS))
            self.assertTrue(np.array_equal(coords, self.COORDS))

    def testClosure(self):
        """Closures must contain their members' cones, stars their supports."""
        pStart, pEnd = self.plex.getChart()
        for p in range(pStart, pEnd):
            closure = self.plex.getTransitiveClosure(p)[0]
            for c in closure:
                cone = self.plex.getCone(c)
                self.assertEqual(self.plex.getConeSize(c), len(cone))
                for i in cone:
                    self.assertIn(i, closure)
            star = self.plex.getTransitiveClosure(p, useCone=False)[0]
            for s in star:
                support = self.plex.getSupport(s)
                self.assertEqual(self.plex.getSupportSize(s), len(support))
                for i in support:
                    self.assertIn(i, star)

    def testAdjacency(self):
        """Round-trip the adjacency flags, then query every point's adjacency."""
        self.plex.setAdjacencyUseAnchors(False)
        self.assertFalse(self.plex.getAdjacencyUseAnchors())
        self.plex.setAdjacencyUseAnchors(True)
        self.assertTrue(self.plex.getAdjacencyUseAnchors())
        self.plex.setBasicAdjacency(False, False)
        flagA, flagB = self.plex.getBasicAdjacency()
        self.assertFalse(flagA)
        self.assertFalse(flagB)
        self.plex.setBasicAdjacency(True, True)
        flagA, flagB = self.plex.getBasicAdjacency()
        self.assertTrue(flagA)
        self.assertTrue(flagB)
        pStart, pEnd = self.plex.getChart()
        for p in range(pStart, pEnd):
            adjacency = self.plex.getAdjacency(p)
            # Each point is adjacent to itself and to at least one other point.
            self.assertIn(p, adjacency)
            self.assertGreater(len(adjacency), 1)

    def testSectionDofs(self):
        """The section's storage size must equal the dofs summed over all depths."""
        self.plex.setNumFields(1)
        section = self.plex.createSection([self.COMP], [self.DOFS])
        size = section.getStorageSize()
        entity_dofs = [
            self.plex.getStratumSize('depth', d) * self.DOFS[d]
            for d in range(self.DIM + 1)
        ]
        self.assertEqual(sum(entity_dofs), size)

    def testSectionClosure(self):
        """Values gathered by vecGetClosure must belong to points of the closure."""
        section = self.plex.createSection([self.COMP], [self.DOFS])
        self.plex.setSection(section)
        vec = self.plex.createLocalVec()
        pStart, pEnd = self.plex.getChart()
        # Tag every dof with the number of the point that owns it.
        for p in range(pStart, pEnd):
            off = section.getOffset(p)
            for i in range(section.getDof(p)):
                vec.setValue(off + i, p)

        for p in range(pStart, pEnd):
            point_closure = self.plex.getTransitiveClosure(p)[0]
            dof_closure = self.plex.vecGetClosure(section, vec, p)
            # A distinct name keeps the outer loop variable ``p`` intact.
            for owner in dof_closure:
                self.assertIn(owner, point_closure)

    def testBoundaryLabel(self):
        """Label boundary closures 1 and the rest 2; both sets must be non-trivial."""
        pStart, pEnd = self.plex.getChart()
        if pEnd - pStart == 0:
            # Nothing to label when this rank's chart is empty.
            return

        self.assertFalse(self.plex.hasLabel('boundary'))
        self.plex.markBoundaryFaces('boundary')
        self.assertTrue(self.plex.hasLabel('boundary'))

        # Extend the label from the marked faces to everything in their closure.
        faces = self.plex.getStratumIS('boundary', 1)
        for f in faces.getIndices():
            points = self.plex.getTransitiveClosure(f, useCone=True)[0]
            for p in points:
                self.plex.setLabelValue('boundary', p, 1)

        # Every remaining point is labelled as interior.
        for p in range(pStart, pEnd):
            if self.plex.getLabelValue('boundary', p) != 1:
                self.plex.setLabelValue('boundary', p, 2)

        numBoundary = self.plex.getStratumSize('boundary', 1)
        numInterior = self.plex.getStratumSize('boundary', 2)
        self.assertNotEqual(numBoundary, pEnd - pStart)
        self.assertNotEqual(numInterior, pEnd - pStart)
        self.assertEqual(numBoundary + numInterior, pEnd - pStart)

    def testMetric(self):
        """Round-trip the metric options and exercise the metric operations.

        Returns early (without asserting anything) for 1D meshes and for
        generated meshes that are not simplicial.
        """
        if self.DIM == 1:
            return
        self.plex.distribute()
        if self.CELLS is None and not self.plex.isSimplex():
            return
        self.plex.orient()

        h_min = 1.0e-30
        h_max = 1.0e30
        a_max = 1.0e10
        target = 8.0
        p = 1.0
        beta = 1.3
        hausd = 0.01
        self.plex.metricSetUniform(False)
        self.plex.metricSetIsotropic(False)
        self.plex.metricSetRestrictAnisotropyFirst(False)
        self.plex.metricSetNoInsertion(False)
        self.plex.metricSetNoSwapping(False)
        self.plex.metricSetNoMovement(False)
        self.plex.metricSetNoSurf(False)
        self.plex.metricSetVerbosity(-1)
        self.plex.metricSetNumIterations(3)
        self.plex.metricSetMinimumMagnitude(h_min)
        self.plex.metricSetMaximumMagnitude(h_max)
        self.plex.metricSetMaximumAnisotropy(a_max)
        self.plex.metricSetTargetComplexity(target)
        self.plex.metricSetNormalizationOrder(p)
        self.plex.metricSetGradationFactor(beta)
        self.plex.metricSetHausdorffNumber(hausd)

        # Every getter must report the value that was just set.
        self.assertFalse(self.plex.metricIsUniform())
        self.assertFalse(self.plex.metricIsIsotropic())
        self.assertFalse(self.plex.metricRestrictAnisotropyFirst())
        self.assertFalse(self.plex.metricNoInsertion())
        self.assertFalse(self.plex.metricNoSwapping())
        self.assertFalse(self.plex.metricNoMovement())
        self.assertFalse(self.plex.metricNoSurf())
        self.assertEqual(self.plex.metricGetVerbosity(), -1)
        self.assertEqual(self.plex.metricGetNumIterations(), 3)
        self.assertTrue(np.isclose(self.plex.metricGetMinimumMagnitude(), h_min))
        self.assertTrue(np.isclose(self.plex.metricGetMaximumMagnitude(), h_max))
        self.assertTrue(np.isclose(self.plex.metricGetMaximumAnisotropy(), a_max))
        self.assertTrue(np.isclose(self.plex.metricGetTargetComplexity(), target))
        self.assertTrue(np.isclose(self.plex.metricGetNormalizationOrder(), p))
        self.assertTrue(np.isclose(self.plex.metricGetGradationFactor(), beta))
        self.assertTrue(np.isclose(self.plex.metricGetHausdorffNumber(), hausd))

        metric1 = self.plex.metricCreateUniform(0.5)
        metric2 = self.plex.metricCreateUniform(1.0)
        metric = self.plex.metricCreate()
        det = self.plex.metricDeterminantCreate()
        self.plex.metricAverage2(metric1, metric2, metric)
        # Scaling metric1 (0.5) by 1.5 gives 0.75, the expected average.
        metric1.array[:] *= 1.5
        self.assertTrue(np.allclose(metric.array, metric1.array))
        self.plex.metricIntersection2(metric1, metric2, metric)
        self.assertTrue(np.allclose(metric.array, metric2.array))
        self.plex.metricEnforceSPD(metric, metric1, det[0])
        self.assertTrue(np.allclose(metric.array, metric1.array))

        if self.DIM == 2 and PETSc.COMM_WORLD.getSize() > 6:
            # Error with 7 processes in 2D: normalization factor is -1
            return

        self.plex.metricNormalize(
            metric, metric1, det[0], restrictSizes=False, restrictAnisotropy=False
        )
        metric2.scale(pow(target, 2.0 / self.DIM))
        self.assertTrue(np.allclose(metric1.array, metric2.array))

    def testAdapt(self):
        """Refine the mesh and adapt it to a uniform metric.

        A ``PETSc.Error`` whose ``ierr`` is ``ERR_ARG_OUTOFRANGE`` is
        tolerated; any other error is re-raised.  The refined mesh is
        destroyed on every exit path.
        """
        if self.DIM == 1:
            return
        if self.DIM == 3 and PETSc.COMM_WORLD.getSize() > 4:
            # Error with 5 processes in 3D
            # ----------------------------
            # Warning: MMG5_mmgIntextmet: Unable to diagonalize at least 1 metric.
            # Error: MMG3D_defsiz_ani: unable to intersect metrics at point 8.
            # Metric undefined. Exit program.
            # MMG remeshing problem. Exit program.
            return
        self.plex.orient()
        plex = self.plex.refine()
        try:
            plex.distribute()
            if self.CELLS is None and not plex.isSimplex():
                return
            if sum(self.DOFS) > 1:
                return
            metric = plex.metricCreateUniform(9.0)
            try:
                newplex = plex.adaptMetric(metric, '')
            except PETSc.Error as exc:
                if exc.ierr != ERR_ARG_OUTOFRANGE:
                    raise
            else:
                newplex.destroy()
        finally:
            # Runs for the early returns too, which previously leaked ``plex``.
            plex.destroy()

    def testNatural(self):
        """Distribute with natural ordering enabled and map global <-> natural."""
        dim = self.plex.getDimension()
        ct = self.plex.getCellType(0)
        fe = PETSc.FE().createByCell(dim, 1, ct)
        self.plex.setField(0, fe)
        self.plex.createDS()
        # Must be requested before distribute() for the natural maps to exist.
        self.plex.setUseNatural(True)
        self.plex.distribute()
        self.plex.view()
        gv = self.plex.createGlobalVec()
        nv = self.plex.createNaturalVec()
        self.plex.globalToNaturalBegin(gv, nv)
        self.plex.globalToNaturalEnd(gv, nv)
        self.plex.naturalToGlobalBegin(nv, gv)
        self.plex.naturalToGlobalEnd(nv, gv)

# --------------------------------------------------------------------


class BaseTestPlex_2D(BaseTestPlex):
    """Unit square meshed as a 2 x 2 grid of squares, two triangles each.

    Vertices lie on a 3 x 3 lattice with spacing 0.5 and are numbered row by
    row, so the vertex at (column, row) has index ``3 * row + column``.
    """

    DIM = 2
    # ``v`` is the lower-left vertex of each square; the square is cut along
    # the diagonal joining ``v + 1`` and ``v + 3``.
    CELLS = [
        triangle
        for v in (0, 1, 3, 4)
        for triangle in ([v, v + 1, v + 3], [v + 1, v + 3, v + 4])
    ]
    # x varies fastest, matching the row-by-row vertex numbering.
    COORDS = [[0.5 * col, 0.5 * row] for row in range(3) for col in range(3)]
    # One dof per vertex, none on edges or cells.
    DOFS = [1, 0, 0]


class BaseTestPlex_3D(BaseTestPlex):
    """Unit cube split into six tetrahedra that all share the edge 0-7.

    The eight vertices are the cube corners, numbered with x varying fastest,
    then y, then z.
    """

    DIM = 3
    # Every tetrahedron joins the shared diagonal (0, 7) to one more pair of
    # cube corners.
    CELLS = [
        [0, a, b, 7]
        for a, b in ((2, 3), (2, 6), (4, 6), (1, 3), (1, 5), (4, 5))
    ]
    COORDS = [
        [float(x), float(y), float(z)]
        for z in (0, 1)
        for y in (0, 1)
        for x in (0, 1)
    ]
    # One dof per vertex, none on edges, faces or cells.
    DOFS = [1, 0, 0, 0]


# --------------------------------------------------------------------


class TestPlex_1D(BaseTestPlex, unittest.TestCase):
    """Run the shared DMPlex tests on the default 1D two-cell mesh."""


class TestPlex_2D(BaseTestPlex_2D, unittest.TestCase):
    """Shared tests on the 2D triangle mesh, plus a DMPlexTransform check."""

    def testTransform(self):
        """A REFINEALFELD transform must triple the local cell count."""
        plex = self.plex
        cstart, cend = plex.getHeightStratum(0)
        tr = PETSc.DMPlexTransform().create(comm=PETSc.COMM_WORLD)
        try:
            tr.setType(PETSc.DMPlexTransformType.REFINEALFELD)
            tr.setDM(plex)
            tr.setUp()
            newplex = tr.apply(plex)
        finally:
            # Release the transform even if configuring or applying it fails.
            tr.destroy()
        try:
            newcstart, newcend = newplex.getHeightStratum(0)
        finally:
            newplex.destroy()
        # assertEqual reports both counts on failure, unlike assertTrue(a == b).
        self.assertEqual(newcend - newcstart, 3 * (cend - cstart))


class TestPlex_3D(BaseTestPlex_3D, unittest.TestCase):
    """Run the shared DMPlex tests on the six-tetrahedron cube mesh."""


class TestPlex_2D_P3(BaseTestPlex_2D, unittest.TestCase):
    """Shared 2D tests with a denser dof layout."""

    # Dofs per point of depth 0, 1 and 2.  The "P3" in the name presumably
    # refers to a cubic Lagrange layout -- TODO confirm.
    # Because sum(DOFS) > 1, the inherited testAdapt returns early.
    DOFS = [1, 2, 1]


class TestPlex_3D_P3(BaseTestPlex_3D, unittest.TestCase):
    """Shared 3D tests with a denser dof layout."""

    # Dofs per point of depth 0, 1, 2 and 3.  The "P3" in the name presumably
    # refers to a cubic Lagrange layout -- TODO confirm.
    # Because sum(DOFS) > 1, the inherited testAdapt returns early.
    DOFS = [1, 2, 1, 0]


class TestPlex_3D_P4(BaseTestPlex_3D, unittest.TestCase):
    """Shared 3D tests with dofs on points of every depth."""

    # Dofs per point of depth 0, 1, 2 and 3.  The "P4" in the name presumably
    # refers to a quartic Lagrange layout -- TODO confirm.
    # Because sum(DOFS) > 1, the inherited testAdapt returns early.
    DOFS = [1, 3, 3, 1]


class TestPlex_2D_BoxTensor(BaseTestPlex_2D, unittest.TestCase):
    """Shared 2D tests on a generated 3 x 3 box mesh of non-simplex cells."""

    # No explicit tables: the rank-0 comparisons in testTopology are skipped.
    CELLS = COORDS = None

    def setUp(self):
        """Generate ``self.plex`` with createBoxMesh instead of a cell list."""
        faces_per_side = [3, 3]
        self.plex = PETSc.DMPlex().createBoxMesh(faces_per_side, simplex=False)


class TestPlex_3D_BoxTensor(BaseTestPlex_3D, unittest.TestCase):
    """Shared 3D tests on a generated 3 x 3 x 3 box mesh of non-simplex cells."""

    # No explicit tables: the rank-0 comparisons in testTopology are skipped.
    CELLS = COORDS = None

    def setUp(self):
        """Generate ``self.plex`` with createBoxMesh instead of a cell list."""
        faces_per_side = [3, 3, 3]
        self.plex = PETSc.DMPlex().createBoxMesh(faces_per_side, simplex=False)


# FIXME: the unconditional ``raise PETSc.Error`` below sends control straight
# to the ``except`` branch.  The probe call after it is unreachable and the
# ``else`` branch never runs, so none of the simplex Box/Boundary suites
# defined there currently exist at import time.
# NOTE(review): presumably the probe was meant to detect whether simplex box
# meshes can be generated in this PETSc build -- confirm before re-enabling.
try:
    raise PETSc.Error
    PETSc.DMPlex().createBoxMesh([2, 2], simplex=True, comm=PETSc.COMM_SELF).destroy()
except PETSc.Error:
    pass
else:

    # Shared 2D tests on a generated 1 x 1 simplex box mesh.
    class TestPlex_2D_Box(BaseTestPlex_2D, unittest.TestCase):
        CELLS = None
        COORDS = None

        def setUp(self):
            self.plex = PETSc.DMPlex().createBoxMesh([1, 1], simplex=True)

    # Shared 2D tests on a mesh generated from a square boundary.
    class TestPlex_2D_Boundary(BaseTestPlex_2D, unittest.TestCase):
        CELLS = None
        COORDS = None

        def setUp(self):
            # The boundary mesh has dimension DIM - 1; generate() builds the
            # mesh stored in self.plex from it.
            boundary = PETSc.DMPlex().create(self.COMM)
            boundary.createSquareBoundary([0.0, 0.0], [1.0, 1.0], [2, 2])
            boundary.setDimension(self.DIM - 1)
            self.plex = PETSc.DMPlex().generate(boundary)

    # Shared 3D tests on a generated 1 x 1 x 1 simplex box mesh.
    class TestPlex_3D_Box(BaseTestPlex_3D, unittest.TestCase):
        CELLS = None
        COORDS = None

        def setUp(self):
            self.plex = PETSc.DMPlex().createBoxMesh([1, 1, 1], simplex=True)

    # Shared 3D tests on a mesh generated from a cube boundary.
    class TestPlex_3D_Boundary(BaseTestPlex_3D, unittest.TestCase):
        CELLS = None
        COORDS = None

        def setUp(self):
            # The boundary mesh has dimension DIM - 1; generate() builds the
            # mesh stored in self.plex from it.
            boundary = PETSc.DMPlex().create(self.COMM)
            boundary.createCubeBoundary([0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [1, 1, 1])
            boundary.setDimension(self.DIM - 1)
            self.plex = PETSc.DMPlex().generate(boundary)

# --------------------------------------------------------------------

# PETSc directory taken from petsc4py's configuration; the HDF5 tests below
# use it to locate the input mesh and the reference output files.
PETSC_DIR = petsc4py.get_config()['PETSC_DIR']


def check_dtype(method):
    """Make a test method a no-op when PETSc uses complex scalars.

    The returned wrapper calls ``method`` and returns its result unless
    ``PETSc.ScalarType is PETSc.ComplexType``, in which case it returns
    ``None`` without running the test.  ``functools.wraps`` keeps the
    wrapped method's name and docstring.
    """

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        # NOTE(review): returning instead of skipping means the test is
        # reported as passed, not skipped.
        if PETSc.ScalarType is PETSc.ComplexType:
            return None
        return method(self, *args, **kwargs)

    return wrapper


def check_package(method):
    """Make a test method a no-op when a required external package is missing.

    The wrapper returns ``None`` without running ``method`` when PETSc has
    no 'hdf5' package, or when ``self.PARTITIONERTYPE`` is anything other
    than 'simple' and PETSc lacks a package of that name.  Otherwise it
    returns whatever ``method`` returns.  ``functools.wraps`` keeps the
    wrapped method's name and docstring.
    """

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        # NOTE(review): returning instead of skipping means the test is
        # reported as passed, not skipped.
        if not PETSc.Sys.hasExternalPackage('hdf5'):
            return None
        # 'simple' is exempt from the package lookup.
        if self.PARTITIONERTYPE != 'simple' and not PETSc.Sys.hasExternalPackage(
            self.PARTITIONERTYPE
        ):
            return None
        return method(self, *args, **kwargs)

    return wrapper


def check_nsize(method):
    """Make a test method a no-op unless COMM_WORLD has ``self.NSIZE`` ranks.

    The wrapper returns ``None`` without running ``method`` when
    ``PETSc.COMM_WORLD.size`` differs from ``self.NSIZE``; otherwise it
    returns whatever ``method`` returns.  ``functools.wraps`` keeps the
    wrapped method's name and docstring.
    """

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        # NOTE(review): returning instead of skipping means the test is
        # reported as passed, not skipped.
        if PETSc.COMM_WORLD.size != self.NSIZE:
            return None
        return method(self, *args, **kwargs)

    return wrapper


class BaseTestPlexHDF5:
    """Load / redistribute / save round trips of a DMPlex through HDF5.

    Mixin for ``unittest.TestCase``.  Concrete suites must define ``SUFFIX``,
    ``HETEROGENEOUS``, ``OUTFORMAT`` and ``PARTITIONERTYPE``.
    """

    # COMM_WORLD size the test needs; check_nsize makes it a no-op otherwise.
    NSIZE = 4
    # Number of load/save cycles performed by testViewLoadCycle.
    NTIMES = 3

    def tearDown(self):
        """Remove the files the test wrote; only rank 0 touches the disk."""
        if PETSc.COMM_WORLD.rank:
            return
        for path in (self.outfile(), self.tmp_output_file()):
            if os.path.exists(path):
                os.remove(path)

    def _name(self):
        """Return the tag that distinguishes this suite's file names."""
        return f'{self.SUFFIX}_outformat-{self.OUTFORMAT}_{self.PARTITIONERTYPE}'

    def infile(self):
        """Return the path of the input mesh under ``PETSC_DIR``."""
        return os.path.join(
            PETSC_DIR, 'share/petsc/datafiles/', 'meshes/blockcylinder-50.h5'
        )

    def outfile(self):
        """Return the path of the temporary HDF5 file written each cycle."""
        return f'./temp_test_dmplex_{self._name()}.h5'

    def informat(self):
        """Return the viewer format used to read ``infile()``."""
        return PETSc.Viewer.Format.HDF5_XDMF

    def outformat(self):
        """Translate ``self.OUTFORMAT`` into a viewer format.

        Raises ``KeyError`` for a name other than 'hdf5_petsc' or 'hdf5_xdmf'.
        """
        formats = {
            'hdf5_petsc': PETSc.Viewer.Format.HDF5_PETSC,
            'hdf5_xdmf': PETSc.Viewer.Format.HDF5_XDMF,
        }
        return formats[self.OUTFORMAT]

    def partitionerType(self):
        """Translate ``self.PARTITIONERTYPE`` into a partitioner type.

        Raises ``KeyError`` for a name other than 'simple', 'ptscotch' or
        'parmetis'.
        """
        types = {
            'simple': PETSc.Partitioner.Type.SIMPLE,
            'ptscotch': PETSc.Partitioner.Type.PTSCOTCH,
            'parmetis': PETSc.Partitioner.Type.PARMETIS,
        }
        return types[self.PARTITIONERTYPE]

    def ref_output_file(self):
        """Return the path of the reference output shipped under ``PETSC_DIR``."""
        return os.path.join(
            PETSC_DIR,
            'src/dm/impls/plex/tutorials/',
            f'output/ex5_{self._name()}.out',
        )

    def tmp_output_file(self):
        """Return the path of the text log compared against the reference."""
        return f'./temp_test_dmplex_{self._name()}.out'

    def outputText(self, msg, comm):
        """Append ``msg`` to the text log from rank 0 of ``comm``."""
        if not comm.rank:
            with open(self.tmp_output_file(), 'a') as f:
                f.write(msg)

    def outputPlex(self, plex):
        """Append the ASCII view of ``plex`` to the text log."""
        txtvwr = PETSc.Viewer().createASCII(
            self.tmp_output_file(), mode='a', comm=plex.comm
        )
        plex.view(viewer=txtvwr)
        txtvwr.destroy()

    def _runCycle(self, i, infname, informt, comm):
        """Load a mesh on ``comm``, interpolate, redistribute and save it.

        ``i`` is the cycle number written to the log, ``infname`` and
        ``informt`` are the file and viewer format to load from.  The result
        is written to ``outfile()`` in ``outformat()``.
        """
        self.outputText('Begin cycle %d\n' % i, comm)
        plex = PETSc.DMPlex()
        vwr = PETSc.ViewerHDF5()
        # Create plex
        plex.create(comm=comm)
        plex.setName('DMPlex Object')
        # Load data from XDMF into dm in parallel
        vwr.create(infname, mode='r', comm=comm)
        vwr.pushFormat(format=informt)
        plex.load(viewer=vwr)
        plex.setOptionsPrefix('loaded_')
        plex.distributeSetDefault(False)
        plex.setFromOptions()
        vwr.popFormat()
        vwr.destroy()
        self.outputPlex(plex)
        # Test DM is indeed distributed
        flg = plex.isDistributed()
        self.outputText('Loaded mesh distributed? %s\n' % str(flg).upper(), comm)
        # Interpolate
        plex.interpolate()
        plex.setOptionsPrefix('interpolated_')
        plex.setFromOptions()
        self.outputPlex(plex)
        # Redistribute
        part = plex.getPartitioner()
        part.setType(self.partitionerType())
        sf = plex.distribute(overlap=0)
        if sf:
            sf.destroy()
        part.destroy()
        plex.setName('DMPlex Object')
        plex.setOptionsPrefix('redistributed_')
        plex.setFromOptions()
        self.outputPlex(plex)
        # Save redistributed dm to XDMF in parallel
        vwr.create(self.outfile(), mode='w', comm=comm)
        vwr.pushFormat(format=self.outformat())
        plex.setName('DMPlex Object')
        plex.view(viewer=vwr)
        vwr.popFormat()
        vwr.destroy()
        # Destroy plex
        plex.destroy()
        self.outputText('End   cycle %d\n--------\n' % i, comm)

    @check_dtype
    @check_package
    @check_nsize
    def testViewLoadCycle(self):
        """Run ``NTIMES`` load/save cycles and compare the log to the reference.

        The first cycle reads ``infile()``; later cycles read the file the
        previous cycle wrote.  With ``HETEROGENEOUS`` set, ranks above
        ``NTIMES - i`` sit out cycle ``i``, so each cycle runs on a smaller
        group than the one that saved the file.
        """
        if importlib.util.find_spec('mpi4py') is None:
            self.skipTest('mpi4py')  # throws special exception to signal test skip
        grank = PETSc.COMM_WORLD.rank
        for i in range(self.NTIMES):
            if i == 0:
                infname, informt = self.infile(), self.informat()
            else:
                infname, informt = self.outfile(), self.outformat()
            if self.HETEROGENEOUS:
                mycolor = grank > self.NTIMES - i
            else:
                mycolor = 0
            # Every rank takes part in the split; only colour 0 does the work.
            mpicomm = PETSc.COMM_WORLD.tompi4py()
            comm = PETSc.Comm(comm=mpicomm.Split(color=mycolor, key=grank))
            if mycolor == 0:
                self._runCycle(i, infname, informt, comm)
            comm.tompi4py().Free()
            PETSc.COMM_WORLD.Barrier()
        # Check that the output is identical to that of plex/tutorial/ex5.c.
        self.assertTrue(
            filecmp.cmp(self.tmp_output_file(), self.ref_output_file(), shallow=False),
            f'Contents of the files not the same. Reference file: {self.ref_output_file()}',
        )
        # Keep every rank here until all have compared, before tearDown deletes.
        PETSc.COMM_WORLD.Barrier()


class BaseTestPlexHDF5Homogeneous(BaseTestPlexHDF5):
    """Test save on N / load on N."""

    # Leading part of the tag built by _name(), which appears in the
    # temporary and reference file names.
    SUFFIX = 0
    # False: every rank takes part in every cycle of testViewLoadCycle.
    HETEROGENEOUS = False


class BaseTestPlexHDF5Heterogeneous(BaseTestPlexHDF5):
    """Test save on N / load on M."""

    # Leading part of the tag built by _name(), which appears in the
    # temporary and reference file names.
    SUFFIX = 1
    # True: testViewLoadCycle leaves some ranks out of the later cycles.
    HETEROGENEOUS = True


class TestPlexHDF5PETSCSimpleHomogeneous(
    BaseTestPlexHDF5Homogeneous, unittest.TestCase
):
    """Homogeneous cycle: 'hdf5_petsc' output, 'simple' partitioner."""

    OUTFORMAT = 'hdf5_petsc'
    PARTITIONERTYPE = 'simple'


"""
Skipping. PTScotch produces different distributions when run
in a sequence in a single session.

class TestPlexHDF5PETSCPTScotchHomogeneous(BaseTestPlexHDF5Homogeneous,
                                           unittest.TestCase):
    OUTFORMAT = "hdf5_petsc"
    PARTITIONERTYPE = "ptscotch"
"""


class TestPlexHDF5PETSCParmetisHomogeneous(
    BaseTestPlexHDF5Homogeneous, unittest.TestCase
):
    """Homogeneous cycle: 'hdf5_petsc' output, 'parmetis' partitioner."""

    OUTFORMAT = 'hdf5_petsc'
    PARTITIONERTYPE = 'parmetis'


class TestPlexHDF5XDMFSimpleHomogeneous(BaseTestPlexHDF5Homogeneous, unittest.TestCase):
    """Homogeneous cycle: 'hdf5_xdmf' output, 'simple' partitioner."""

    OUTFORMAT = 'hdf5_xdmf'
    PARTITIONERTYPE = 'simple'


"""
Skipping. PTScotch produces different distributions when run
in a sequence in a single session.

class TestPlexHDF5XDMFPTScotchHomogeneous(BaseTestPlexHDF5Homogeneous,
                                          unittest.TestCase):
    OUTFORMAT = "hdf5_xdmf"
    PARTITIONERTYPE = "ptscotch"
"""


class TestPlexHDF5XDMFParmetisHomogeneous(
    BaseTestPlexHDF5Homogeneous, unittest.TestCase
):
    """Homogeneous cycle: 'hdf5_xdmf' output, 'parmetis' partitioner."""

    OUTFORMAT = 'hdf5_xdmf'
    PARTITIONERTYPE = 'parmetis'


class TestPlexHDF5PETSCSimpleHeterogeneous(
    BaseTestPlexHDF5Heterogeneous, unittest.TestCase
):
    """Heterogeneous cycle: 'hdf5_petsc' output, 'simple' partitioner."""

    OUTFORMAT = 'hdf5_petsc'
    PARTITIONERTYPE = 'simple'


"""
Skipping. PTScotch produces different distributions when run
in a sequence in a single session.

class TestPlexHDF5PETSCPTScotchHeterogeneous(BaseTestPlexHDF5Heterogeneous,
                                             unittest.TestCase):
    OUTFORMAT = "hdf5_petsc"
    PARTITIONERTYPE = "ptscotch"
"""


class TestPlexHDF5PETSCParmetisHeterogeneous(
    BaseTestPlexHDF5Heterogeneous, unittest.TestCase
):
    """Heterogeneous cycle: 'hdf5_petsc' output, 'parmetis' partitioner."""

    OUTFORMAT = 'hdf5_petsc'
    PARTITIONERTYPE = 'parmetis'


class TestPlexHDF5XDMFSimpleHeterogeneous(
    BaseTestPlexHDF5Heterogeneous, unittest.TestCase
):
    """Heterogeneous cycle: 'hdf5_xdmf' output, 'simple' partitioner."""

    OUTFORMAT = 'hdf5_xdmf'
    PARTITIONERTYPE = 'simple'


class TestPlexHDF5XDMFPTScotchHeterogeneous(
    BaseTestPlexHDF5Heterogeneous, unittest.TestCase
):
    """Heterogeneous cycle: 'hdf5_xdmf' output, 'ptscotch' partitioner."""

    # NOTE(review): the other three PTScotch suites in this file are disabled
    # inside string literals because PTScotch results vary within a session;
    # confirm that this one is meant to stay enabled.
    OUTFORMAT = 'hdf5_xdmf'
    PARTITIONERTYPE = 'ptscotch'


class TestPlexHDF5XDMFParmetisHeterogeneous(
    BaseTestPlexHDF5Heterogeneous, unittest.TestCase
):
    """Heterogeneous cycle: 'hdf5_xdmf' output, 'parmetis' partitioner."""

    OUTFORMAT = 'hdf5_xdmf'
    PARTITIONERTYPE = 'parmetis'


# --------------------------------------------------------------------

# Run every suite in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
