File: test_comm_inter.py

package info (click to toggle)
mpi4py 2.0.0-2.1+deb9u1
  • links: PTS, VCS
  • area: main
  • in suites: stretch
  • size: 2,680 kB
  • sloc: python: 15,291; ansic: 7,099; makefile: 719; f90: 158; sh: 156; cpp: 121
file content (97 lines) | stat: -rw-r--r-- 3,075 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from mpi4py import MPI
import mpiunittest as unittest

class BaseTestIntercomm(object):
    """Mixin holding intercommunicator tests, parameterized by ``BASECOMM``.

    A concrete test class combines this mixin with ``unittest.TestCase``
    and sets ``BASECOMM`` to a usable intracommunicator.  ``setUp`` splits
    ``BASECOMM`` into two groups and connects them through an
    intercommunicator stored in ``self.INTERCOMM``.

    The test methods are documented with ``#`` comments rather than
    docstrings because unittest-style runners may display a docstring in
    place of the test name.
    """

    # Placeholders; BASECOMM is overridden by concrete subclasses and the
    # other two are (re)assigned per test in setUp().
    BASECOMM  = MPI.COMM_NULL
    INTRACOMM = MPI.COMM_NULL
    INTERCOMM = MPI.COMM_NULL

    def setUp(self):
        """Split ``BASECOMM`` into two groups and build the intercomm."""
        size = self.BASECOMM.Get_size()
        rank = self.BASECOMM.Get_rank()
        if rank < size // 2 :
            # Lower group: its peer leader is the first rank of the upper
            # group, expressed as a rank in BASECOMM.
            self.COLOR = 0
            self.LOCAL_LEADER = 0
            self.REMOTE_LEADER = size // 2
        else:
            # Upper group: its peer leader is rank 0 of BASECOMM.
            self.COLOR = 1
            self.LOCAL_LEADER = 0
            self.REMOTE_LEADER = 0
        self.INTRACOMM = self.BASECOMM.Split(self.COLOR, key=0)
        Create_intercomm = MPI.Intracomm.Create_intercomm
        self.INTERCOMM = Create_intercomm(self.INTRACOMM,
                                          self.LOCAL_LEADER,
                                          self.BASECOMM,
                                          self.REMOTE_LEADER)

    def tearDown(self):
        """Free the per-test communicators and drop the instance attributes."""
        self.INTRACOMM.Free()
        self.INTERCOMM.Free()
        # Deleting the instance attributes re-exposes the class-level
        # COMM_NULL placeholders.
        del self.INTRACOMM
        del self.INTERCOMM

    def testFortran(self):
        # Round-trip the handle through its Fortran integer form and check
        # that both the value and the Python type survive.
        intercomm = self.INTERCOMM
        fint = intercomm.py2f()
        newcomm = MPI.Comm.f2py(fint)
        self.assertEqual(newcomm, intercomm)
        self.assertTrue(type(newcomm) is MPI.Intercomm)

    def testLocalGroupSizeRank(self):
        # The local group must report the same size and rank as the
        # intercommunicator itself.
        intercomm = self.INTERCOMM
        local_group = intercomm.Get_group()
        try:
            self.assertEqual(local_group.size, intercomm.Get_size())
            self.assertEqual(local_group.size, intercomm.size)
            self.assertEqual(local_group.rank, intercomm.Get_rank())
            self.assertEqual(local_group.rank, intercomm.rank)
        finally:
            # Release the group handle even when an assertion fails.
            local_group.Free()

    def testRemoteGroupSize(self):
        # The remote group must report the same size as the
        # intercommunicator's remote size.
        intercomm = self.INTERCOMM
        remote_group = intercomm.Get_remote_group()
        try:
            self.assertEqual(remote_group.size, intercomm.Get_remote_size())
            self.assertEqual(remote_group.size, intercomm.remote_size)
        finally:
            # Release the group handle even when an assertion fails.
            remote_group.Free()

    def testMerge(self):
        # Merge both sides back into one intracommunicator; the test
        # expects its size and ranks to match BASECOMM.
        basecomm  = self.BASECOMM
        intercomm = self.INTERCOMM
        # The upper group passes high=True, the lower group high=False.
        high = basecomm.rank >= basecomm.size // 2
        intracomm = intercomm.Merge(high)
        try:
            self.assertEqual(intracomm.size, basecomm.size)
            self.assertEqual(intracomm.rank, basecomm.rank)
        finally:
            # Always free the merged communicator so a failed assertion on
            # one process neither leaks it nor skips the matching Free().
            intracomm.Free()


class TestIntercomm(BaseTestIntercomm, unittest.TestCase):
    # Concrete test case: runs the BaseTestIntercomm tests with
    # MPI.COMM_WORLD as the base communicator.
    BASECOMM = MPI.COMM_WORLD

class TestIntercommDup(TestIntercomm):
    """Run the inherited tests on a duplicate of the base communicator."""

    def setUp(self):
        # Shadow the class-level BASECOMM with a per-test duplicate before
        # the inherited setUp derives the other communicators from it.
        duplicate = self.BASECOMM.Dup()
        self.BASECOMM = duplicate
        super(TestIntercommDup, self).setUp()

    def tearDown(self):
        # Free the duplicate first; deleting the instance attribute then
        # makes the class-level BASECOMM visible again.
        duplicate = self.BASECOMM
        duplicate.Free()
        del self.BASECOMM
        super(TestIntercommDup, self).tearDown()

class TestIntercommDupDup(TestIntercomm):
    """Run the inherited tests on a duplicated intercommunicator."""

    def setUp(self):
        super(TestIntercommDupDup, self).setUp()
        # Swap the freshly created intercommunicator for a duplicate of it
        # and release the original handle.
        original = self.INTERCOMM
        self.INTERCOMM = original.Dup()
        original.Free()


# Remove the test classes from the module namespace when COMM_WORLD has
# fewer than two processes (presumably because the split in setUp would
# then leave one side of the intercommunicator empty).
if MPI.COMM_WORLD.Get_size() < 2:
    del TestIntercomm, TestIntercommDup, TestIntercommDupDup


# Run the test suite when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()