from petsc4py import PETSc
import unittest
import numpy as np


class TestDMShell(unittest.TestCase):
    """Tests for the user-supplied objects and callbacks of ``PETSc.DMShell``.

    Each test installs either a concrete object (vector, matrix, scatter) or a
    Python callback on a fresh shell DM. It then checks that the matching
    ``DM`` operation returns or produces what was installed.
    """

    COMM = PETSc.COMM_WORLD

    def setUp(self):
        self.dm = PETSc.DMShell().create(comm=self.COMM)

    def tearDown(self):
        self.dm.destroy()
        self.dm = None
        # Release unused PETSc objects left behind by the test.
        PETSc.garbage_cleanup()

    # ------------------------------------------------------------------
    # shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _make_vec(comm, local_size):
        """Create and set up a Vec on ``comm`` with the given local size."""
        vec = PETSc.Vec().create(comm=comm)
        vec.setSizes((local_size, None))
        vec.setUp()
        return vec

    @staticmethod
    def _copy_begin(dm, ivec, mode, ovec):
        """Transfer begin-callback: copy (INSERT) or accumulate (ADD) ivec into ovec."""
        if mode == PETSc.InsertMode.INSERT_VALUES:
            ovec[...] = ivec[...]
        elif mode == PETSc.InsertMode.ADD_VALUES:
            ovec[...] += ivec[...]

    @staticmethod
    def _noop_end(dm, ivec, mode, ovec):
        """Transfer end-callback: all work is done in ``_copy_begin``."""

    def _check_insert_then_add(self, transfer, vec, ovec):
        """Run ``transfer`` with INSERT then ADD; expect ovec == vec, then 2 * vec."""
        transfer(vec, ovec, addv=PETSc.InsertMode.INSERT_VALUES)
        np.testing.assert_allclose(ovec.getArray(), vec.getArray())
        transfer(vec, ovec, addv=PETSc.InsertMode.ADD_VALUES)
        np.testing.assert_allclose(ovec.getArray(), 2 * vec.getArray())

    # ------------------------------------------------------------------
    # vectors and matrices
    # ------------------------------------------------------------------

    def testSetGlobalVector(self):
        vec = self._make_vec(self.COMM, 10)
        self.dm.setGlobalVector(vec)
        gvec = self.dm.createGlobalVector()
        self.assertEqual(vec.getSizes(), gvec.getSizes())
        self.assertEqual(vec.comm, gvec.comm)

    def testSetCreateGlobalVector(self):
        def create_vec(dm):
            return self._make_vec(dm.comm, 10)

        self.dm.setCreateGlobalVector(create_vec)
        gvec = self.dm.createGlobalVector()
        self.assertEqual(gvec.comm, self.dm.comm)
        self.assertEqual(gvec.getLocalSize(), 10)

    def testSetLocalVector(self):
        vec = self._make_vec(PETSc.COMM_SELF, 1 + 10 * self.COMM.rank)
        self.dm.setLocalVector(vec)
        lvec = self.dm.createLocalVector()
        self.assertEqual(vec.getSizes(), lvec.getSizes())
        lsize, gsize = lvec.getSizes()
        self.assertEqual(lsize, gsize)
        self.assertEqual(lvec.comm, PETSc.COMM_SELF)

    def testSetCreateLocalVector(self):
        def create_vec(dm):
            return self._make_vec(PETSc.COMM_SELF, 1 + 10 * dm.comm.rank)

        self.dm.setCreateLocalVector(create_vec)
        lvec = self.dm.createLocalVector()
        lsize, gsize = lvec.getSizes()
        self.assertEqual(lsize, gsize)
        self.assertEqual(lsize, 1 + 10 * self.dm.comm.rank)
        self.assertEqual(lvec.comm, PETSc.COMM_SELF)

    def testSetMatrix(self):
        mat = PETSc.Mat().create(comm=self.COMM)
        mat.setSizes(((10, None), (2, None)))
        mat.setUp()
        mat.assemble()
        self.dm.setMatrix(mat)
        nmat = self.dm.createMatrix()
        self.assertEqual(nmat.getSizes(), mat.getSizes())

    def testSetCreateMatrix(self):
        def create_mat(dm):
            # Use the DM's communicator, like the other factory callbacks.
            mat = PETSc.Mat().create(comm=dm.comm)
            mat.setSizes(((10, None), (2, None)))
            mat.setUp()
            return mat

        self.dm.setCreateMatrix(create_mat)
        nmat = self.dm.createMatrix()
        # Compare against the sizes requested in create_mat instead of
        # building a second matrix from the same factory (a circular check).
        # Global sizes are the sum of the per-rank local sizes.
        nprocs = self.dm.comm.size
        self.assertEqual(nmat.getSizes(), ((10, 10 * nprocs), (2, 2 * nprocs)))

    # ------------------------------------------------------------------
    # field decomposition
    # ------------------------------------------------------------------

    def testSetCreateFieldDecomposition(self):
        nfields = 2

        def create_field_decomposition_only_is(dm):
            return None, [PETSc.IS().createStride(1) for _ in range(nfields)], None

        def create_field_decomposition_only_is_names(dm):
            _, ises, _ = create_field_decomposition_only_is(dm)
            names = [f'f_{i}' for i in range(len(ises))]
            return names, ises, None

        def create_field_decomposition_only_is_dms(dm):
            _, ises, _ = create_field_decomposition_only_is(dm)
            dms = [dm.clone() for _ in range(len(ises))]
            # No names: return None explicitly rather than the leftover `_`.
            return None, ises, dms

        def create_field_decomposition_only_full(dm):
            names, ises, _ = create_field_decomposition_only_is_names(dm)
            dms = [dm.clone() for _ in range(len(ises))]
            return names, ises, dms

        tests = [
            (create_field_decomposition_only_is, False, False),
            (create_field_decomposition_only_is_names, True, False),
            (create_field_decomposition_only_is_dms, False, True),
            (create_field_decomposition_only_full, True, True),
        ]
        for test_f, has_names, has_dms in tests:
            # subTest reports which factory failed and keeps checking the rest.
            with self.subTest(callback=test_f.__name__):
                self.dm.setCreateFieldDecomposition(test_f)
                names, ises, dms = self.dm.createFieldDecomposition()
                # Guard against an empty result making the checks below vacuous.
                self.assertEqual(len(ises), nfields)
                self.assertEqual(len(names), len(ises))
                self.assertEqual(len(dms), len(ises))
                if has_names:
                    self.assertEqual(names, [f'f_{i}' for i in range(len(ises))])
                else:
                    for name in names:
                        self.assertIsNone(name)
                for sub_dm in dms:
                    if has_dms:
                        self.assertIs(type(sub_dm), PETSc.DM)
                    else:
                        self.assertIsNone(sub_dm)

    # ------------------------------------------------------------------
    # user-callback transfers
    # ------------------------------------------------------------------

    def testGlobalToLocal(self):
        vec = self._make_vec(self.COMM, 10)
        vec[...] = self.dm.comm.rank + 1
        ovec = self._make_vec(PETSc.COMM_SELF, 10)
        self.dm.setGlobalToLocal(self._copy_begin, self._noop_end)
        self._check_insert_then_add(self.dm.globalToLocal, vec, ovec)

    def testLocalToGlobal(self):
        vec = self._make_vec(PETSc.COMM_SELF, 10)
        vec[...] = self.dm.comm.rank + 1
        ovec = self._make_vec(self.COMM, 10)
        self.dm.setLocalToGlobal(self._copy_begin, self._noop_end)
        self._check_insert_then_add(self.dm.localToGlobal, vec, ovec)

    def testLocalToLocal(self):
        vec = self._make_vec(PETSc.COMM_SELF, 10)
        vec[...] = self.dm.comm.rank + 1
        ovec = vec.duplicate()
        self.dm.setLocalToLocal(self._copy_begin, self._noop_end)
        self._check_insert_then_add(self.dm.localToLocal, vec, ovec)

    # ------------------------------------------------------------------
    # VecScatter-based transfers (smoke tests: they only check the call succeeds)
    # ------------------------------------------------------------------

    def testGlobalToLocalVecScatter(self):
        vec = self._make_vec(self.COMM, 10)
        sct, ovec = PETSc.Scatter.toAll(vec)
        self.dm.setGlobalToLocalVecScatter(sct)
        self.dm.globalToLocal(vec, ovec, addv=PETSc.InsertMode.INSERT_VALUES)

    def testLocalToGlobalVecScatter(self):
        vec = self._make_vec(self.COMM, 10)
        sct, ovec = PETSc.Scatter.toAll(vec)
        self.dm.setLocalToGlobalVecScatter(sct)
        self.dm.localToGlobal(vec, ovec, addv=PETSc.InsertMode.INSERT_VALUES)

    def testLocalToLocalVecScatter(self):
        vec = self._make_vec(self.COMM, 10)
        sct, ovec = PETSc.Scatter.toAll(vec)
        self.dm.setLocalToLocalVecScatter(sct)
        self.dm.localToLocal(vec, ovec, addv=PETSc.InsertMode.INSERT_VALUES)

    # ------------------------------------------------------------------
    # hierarchy and inter-grid operators
    # ------------------------------------------------------------------

    def testCoarsenRefine(self):
        cdm = PETSc.DMShell().create(comm=self.COMM)

        def coarsen(dm, comm):
            return cdm

        def refine(dm, comm):
            return self.dm

        cdm.setRefine(refine)
        self.dm.setCoarsen(coarsen)
        coarsened = self.dm.coarsen()
        self.assertEqual(coarsened, cdm)
        refined = coarsened.refine()
        self.assertEqual(refined, self.dm)

    def testCreateInterpolation(self):
        mat = PETSc.Mat().create(comm=self.COMM)
        mat.setSizes(((10, None), (10, None)))
        mat.setUp()
        vec = self._make_vec(self.COMM, 10)

        def create_interp(dm, dmf):
            return mat, vec

        self.dm.setCreateInterpolation(create_interp)
        m, v = self.dm.createInterpolation(self.dm)
        self.assertEqual(m, mat)
        self.assertEqual(v, vec)

    def testCreateInjection(self):
        mat = PETSc.Mat().create(comm=self.COMM)
        mat.setSizes(((10, None), (10, None)))
        mat.setUp()

        def create_inject(dm, dmf):
            return mat

        self.dm.setCreateInjection(create_inject)
        m = self.dm.createInjection(self.dm)
        self.assertEqual(m, mat)


if __name__ == "__main__":
    # When this file is executed directly, run every TestCase defined in it.
    unittest.main()
