1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126
|
from mpi4py import MPI
import mpiunittest as unittest
MPI_ERR_OP = MPI.ERR_OP  # short module-level alias for MPI.ERR_OP

# The array module is optional: default to None and overwrite it with the
# module when the import succeeds.  testCreate is deleted at the end of the
# module when array is unavailable.
array = None
try:
    import array
except ImportError:
    pass

# Very old Pythons have no ``bytes`` builtin; fall back to ``str`` there.
try:
    bytes
except NameError:
    bytes = str
if array:
    # Resolve the array <-> bytes converters once, at import time.  Newer
    # Pythons spell them tobytes/frombytes; older ones tostring/fromstring.
    try:
        tobytes = array.array.tobytes
    except AttributeError:
        tobytes = array.array.tostring
    try:
        _frombytes = array.array.frombytes
    except AttributeError:
        _frombytes = array.array.fromstring

    def frombytes(typecode, data):
        """Return a new ``array.array(typecode)`` filled from *data*.

        *data* may be a plain byte string, or any object exposing a
        ``tobytes()`` method (e.g. a ``memoryview``, or the buffers passed
        to ``mysum``); such objects are flattened to bytes first.
        """
        result = array.array(typecode)
        try:
            data = data.tobytes()
        except AttributeError:
            pass  # no tobytes(): already a plain byte string
        _frombytes(result, data)
        return result
def mysum_py(a, b):
    """Accumulate *a* into *b* elementwise and return *b*.

    Each ``b[i]`` is replaced by ``a[i] + b[i]`` (in that operand order) for
    every index of *a*; *b* is modified in place and returned itself.
    """
    for idx, val in enumerate(a):
        b[idx] = val + b[idx]
    return b
def mysum(ba, bb, dt):
    """User-defined reduction body: elementwise ``bb[i] = ba[i] + bb[i]``.

    With a datatype *dt*, *ba* and *bb* are raw buffers of C ints.  They are
    decoded, summed, and the result is written back into *bb* in place
    (returns None).  With ``dt is None`` -- presumably the direct-call path
    exercised by ``myop(a, b)`` in testCreate -- the sequences are combined
    by ``mysum_py`` and *bb* is returned.
    """
    if dt is not None:
        # Buffer path: only MPI.INT payloads of matching length are expected.
        assert dt == MPI.INT
        assert len(ba) == len(bb)
        total = mysum_py(frombytes('i', ba), frombytes('i', bb))
        bb[:] = tobytes(total)
        return None
    return mysum_py(ba, bb)
class TestOp(unittest.TestCase):
    """Tests for ``MPI.Op``: null handle, user-defined ops, and direct calls."""

    def testConstructor(self):
        # A default-constructed Op is the null handle and is falsy.
        op = MPI.Op()
        self.assertFalse(op)
        self.assertEqual(op, MPI.OP_NULL)

    def testCreate(self):
        for comm in [MPI.COMM_SELF, MPI.COMM_WORLD]:
            size = comm.Get_size()
            rank = comm.Get_rank()
            # Rank r contributes i*(r+1) at index i, so the reduced value at
            # index i is i * (1 + 2 + ... + size).
            scale = sum(range(1, size+1))
            for commute in [True, False]:
                # N starts at 1: buffer(empty_array) returns
                # the same non-NULL pointer !!!
                for N in range(1, 4):
                    myop = MPI.Op.Create(mysum, commute)
                    try:
                        a = array.array('i', [i*(rank+1) for i in range(N)])
                        b = array.array('i', [0]*len(a))
                        comm.Allreduce([a, MPI.INT], [b, MPI.INT], myop)
                        for i in range(N):
                            self.assertEqual(b[i], scale*i)
                        # Calling the op directly accumulates a into b in
                        # place and hands back b itself.
                        ret = myop(a, b)
                        self.assertTrue(ret is b)
                        for i in range(N):
                            self.assertEqual(b[i], a[i]+scale*i)
                    finally:
                        # Free even when an assertion fails: the number of
                        # user-defined ops is limited (see testCreateMany).
                        myop.Free()

    def testCreateMany(self):
        N = 16 # max user-defined operations
        # Exhaust all user-op slots, check that one more is refused, then
        # free them; the second round verifies freed slots are reusable.
        for _ in range(2):
            ops = []
            try:
                for i in range(N):
                    ops.append(MPI.Op.Create(mysum))
                self.assertRaises(RuntimeError, MPI.Op.Create, mysum)
            finally:
                for o in ops:
                    o.Free() # cleanup

    def _test_call(self, op, args, res):
        # Calling a predefined op from Python must yield the expected value.
        self.assertEqual(op(*args), res)

    def testCall(self):
        self._test_call(MPI.MIN,  (2,3), 2)
        self._test_call(MPI.MAX,  (2,3), 3)
        self._test_call(MPI.SUM,  (2,3), 5)
        self._test_call(MPI.PROD, (2,3), 6)
        def xor(x,y): return bool(x) ^ bool(y)
        for x, y in ((0, 0),
                     (0, 1),
                     (1, 0),
                     (1, 1)):
            self._test_call(MPI.LAND, (x,y), x and y)
            self._test_call(MPI.LOR,  (x,y), x or y)
            self._test_call(MPI.LXOR, (x,y), xor(x, y))
            self._test_call(MPI.BAND, (x,y), x & y)
            self._test_call(MPI.BOR,  (x,y), x | y)
            # Bitwise XOR; the logical XOR is checked three lines above.
            self._test_call(MPI.BXOR, (x,y), x ^ y)
        # REPLACE / NO_OP are presumably null (falsy) handles when the
        # underlying MPI library does not provide them -- skip in that case.
        if MPI.REPLACE:
            self._test_call(MPI.REPLACE, (2,3), 3)
            self._test_call(MPI.REPLACE, (3,2), 2)
        if MPI.NO_OP:
            self._test_call(MPI.NO_OP, (2,3), 2)
            self._test_call(MPI.NO_OP, (3,2), 3)
# testCreate builds its buffers with the array module; drop that test when
# the module could not be imported.
if not array:
    delattr(TestOp, 'testCreate')

if __name__ == '__main__':
    unittest.main()
|