File: test_cco_ngh_obj.py

package info (click to toggle)
mpi4py 1.3.1+hg20131106-2
  • links: PTS, VCS
  • area: main
  • in suites: jessie, jessie-kfreebsd
  • size: 2,224 kB
  • ctags: 6,415
  • sloc: python: 12,056; ansic: 7,022; makefile: 697; f90: 158; cpp: 103; sh: 60
file content (115 lines) | stat: -rw-r--r-- 3,210 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from mpi4py import MPI
import mpiunittest as unittest

# Scalar payloads: None, booleans, integers (including 2**31 and the signed
# 64-bit extremes), floats, complex numbers and a string.
_basic = [None,
          True, False,
          -7, 0, 7, 2**31,
          -2**63, 2**63-1,
          -2.17, 0.0, 3.14,
          1+2j, 2-3j,
          'mpi4py',
          ]
# Start from a copy, not an alias: the in-place ``+=`` below would otherwise
# append the container payloads to ``_basic`` as well.
messages = list(_basic)
# Container payloads wrapping the scalars: a list, a tuple and a dict keyed
# 'k0', 'k1', ...
messages += [list(_basic),
             tuple(_basic),
             dict([('k%d' % key, val)
                   for key, val in enumerate(_basic)])
             ]
# Last payload: everything built so far, nested as a single list message.
messages = messages + [messages]

def create_topo_comms(comm):
    """Yield three topology communicators created from *comm*.

    In order, the generator produces the results of:

    1. ``comm.Create_cart`` with every dimension periodic.  The grid is
       ``[m, m, m]`` when the size of *comm* is a perfect cube, ``[n, n]``
       when it is a perfect square, and ``[size]`` otherwise.
    2. ``comm.Create_graph`` describing a ring: for each rank ``i`` the
       edges ``(i-1) % size`` and ``(i+1) % size`` are listed.
    3. ``comm.Create_dist_graph_adjacent`` where the calling rank names the
       two preceding ranks as sources and the two following ranks as
       destinations (modulo size).

    The communicators are yielded as created; releasing them is left to the
    caller.
    """
    size = comm.Get_size()
    rank = comm.Get_rank()
    # Cartesian
    # The exponent has to be parenthesized: ``size**1/2.0`` parses as
    # ``(size**1)/2.0`` (half the size), not as a square root.  The float
    # root is rounded because e.g. ``64 ** (1/3.0)`` is slightly below 4;
    # the exact integer checks below reject non-perfect powers.
    n = int(round(size ** (1.0 / 2.0)))
    m = int(round(size ** (1.0 / 3.0)))
    if m*m*m == size:
        dims = [m, m, m]
    elif n*n == size:
        dims = [n, n]
    else:
        dims = [size]
    periods = [True] * len(dims)
    yield comm.Create_cart(dims, periods=periods)
    # Graph
    # ``index`` holds the running edge count after each node (two per node),
    # ``edges`` the flattened neighbor lists.
    index, edges = [0], []
    for i in range(size):
        pos = index[-1]
        index.append(pos+2)
        edges.append((i-1)%size)
        edges.append((i+1)%size)
    yield comm.Create_graph(index, edges)
    # Dist Graph
    sources = [(rank-2)%size, (rank-1)%size]
    destinations = [(rank+1)%size, (rank+2)%size]
    yield comm.Create_dist_graph_adjacent(sources, destinations)

def get_neighbors_count(comm):
    """Return ``(incoming, outgoing)`` neighbor counts for *comm*.

    The counts depend on ``comm.Get_topology()``:

    * ``MPI.CART``: twice ``comm.Get_dim()`` in both directions;
    * ``MPI.GRAPH``: ``comm.Get_neighbors_count`` of the calling rank, in
      both directions;
    * ``MPI.DIST_GRAPH``: the first two values returned by
      ``comm.Get_dist_neighbors_count()``;
    * anything else: ``(0, 0)``.
    """
    kind = comm.Get_topology()
    if kind == MPI.CART:
        degree = 2 * comm.Get_dim()
        return degree, degree
    elif kind == MPI.GRAPH:
        me = comm.Get_rank()
        degree = comm.Get_neighbors_count(me)
        return degree, degree
    elif kind == MPI.DIST_GRAPH:
        # The third value of the triple is not needed here.
        indegree, outdegree, _unused = comm.Get_dist_neighbors_count()
        return indegree, outdegree
    return 0, 0


class BaseTestCCONghObj(object):
    """Mixin holding the neighborhood-collective tests on Python objects.

    Concrete test cases combine this mixin with ``unittest.TestCase`` and
    provide ``COMM``, the communicator from which the topology communicators
    under test are created.
    """

    # Parent communicator; subclasses override it (class attribute or setUp).
    COMM = MPI.COMM_NULL

    # NOTE: the test methods are documented with comments instead of
    # docstrings so that unittest keeps reporting them by method name.

    def testNeighborAllgather(self):
        # Every process sends the same object, so the result must be one
        # copy of that object per incoming neighbor.
        for comm in create_topo_comms(self.COMM):
            try:
                rsize, ssize = get_neighbors_count(comm)
                for smess in messages:
                    rmess = comm.neighbor_allgather(smess)
                    self.assertEqual(rmess, [smess] * rsize)
            finally:
                # Free the topology communicator even if a check fails.
                comm.Free()

    def testNeighborAlltoall(self):
        # One copy of the object is sent per outgoing neighbor and one copy
        # is expected back per incoming neighbor.
        for comm in create_topo_comms(self.COMM):
            try:
                rsize, ssize = get_neighbors_count(comm)
                for smess in messages:
                    rmess = comm.neighbor_alltoall([smess] * ssize)
                    self.assertEqual(rmess, [smess] * rsize)
            finally:
                # Free the topology communicator even if a check fails.
                comm.Free()


class TestCCONghObjSelf(BaseTestCCONghObj, unittest.TestCase):
    # Runs the inherited neighborhood-collective tests with MPI.COMM_SELF as
    # the parent communicator.
    COMM = MPI.COMM_SELF

class TestCCONghObjWorld(BaseTestCCONghObj, unittest.TestCase):
    # Runs the inherited neighborhood-collective tests with MPI.COMM_WORLD as
    # the parent communicator.
    COMM = MPI.COMM_WORLD

class TestCCONghObjSelfDup(BaseTestCCONghObj, unittest.TestCase):
    """Run the inherited tests on a duplicate of ``MPI.COMM_SELF``."""

    def setUp(self):
        # Duplicate the parent communicator for use as COMM.
        parent = MPI.COMM_SELF
        self.COMM = parent.Dup()

    def tearDown(self):
        # Release the duplicate created in setUp.
        duplicate = self.COMM
        duplicate.Free()

class TestCCONghObjWorldDup(BaseTestCCONghObj, unittest.TestCase):
    """Run the inherited tests on a duplicate of ``MPI.COMM_WORLD``."""

    def setUp(self):
        # Duplicate the parent communicator for use as COMM.
        parent = MPI.COMM_WORLD
        self.COMM = parent.Dup()

    def tearDown(self):
        # Release the duplicate created in setUp.
        duplicate = self.COMM
        duplicate.Free()


# Feature probe: try one neighborhood collective on a one-process periodic
# Cartesian communicator.  If the call raises NotImplementedError, remove
# every test class defined above from the module namespace (presumably so
# that the test runner finds nothing to run here).
cartcomm = MPI.COMM_SELF.Create_cart([1], periods=[1])
try:
    cartcomm.neighbor_allgather(None)
except NotImplementedError:
    del (BaseTestCCONghObj,
         TestCCONghObjSelf,
         TestCCONghObjWorld,
         TestCCONghObjSelfDup,
         TestCCONghObjWorldDup)
finally:
    # The probe communicator is released whether or not the call succeeded.
    cartcomm.Free()

# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()