1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88
|
# Threading support: fall back to dummy_threading when the real threading
# module cannot be imported; _HAS_THREADING records which one is in use.
try:
    import threading as _threading
except ImportError:
    import dummy_threading as _threading
    _HAS_THREADING = False
else:
    _HAS_THREADING = True

Thread = _threading.Thread

# Py 3.X spells it current_thread(); old Py 2.X only offers currentThread().
current_thread = getattr(_threading, 'current_thread', None)
if current_thread is None:
    current_thread = _threading.currentThread
import mpi4py.rc
mpi4py.rc.thread_level = 'multiple'
from mpi4py import MPI
import mpiunittest as unittest
class TestMPIThreads(unittest.TestCase):
    """Tests for MPI thread-level constants and MPI.Is_thread_main()."""

    # Minimum level reported by MPI.Query_thread() before testIsThreadMain
    # exercises extra threads; raised for some vendors at module level.
    REQUIRED = MPI.THREAD_SERIALIZED

    def testThreadLevels(self):
        """Thread-level constants are strictly increasing, and
        MPI.Query_thread() (when implemented) returns one of them."""
        levels = [MPI.THREAD_SINGLE,
                  MPI.THREAD_FUNNELED,
                  MPI.THREAD_SERIALIZED,
                  MPI.THREAD_MULTIPLE]
        # If any constant is None (presumably unavailable in this build),
        # there is nothing meaningful to compare.
        if None in levels:
            return
        for lower, higher in zip(levels, levels[1:]):
            self.assertTrue(lower < higher)
        try:
            provided = MPI.Query_thread()
        except NotImplementedError:
            return
        self.assertTrue(provided in levels)

    def _test_is(self, main=False):
        """Assert that MPI.Is_thread_main() equals *main* in this thread.

        Returns without checking if MPI.Is_thread_main() raises
        NotImplementedError.  When the module-level ``_VERBOSE`` flag is
        set, the result is written to stderr with the current thread name.
        """
        try:
            flag = MPI.Is_thread_main()
        except NotImplementedError:
            return
        self.assertEqual(flag, main)
        if _VERBOSE:
            from sys import stderr
            thread = current_thread()
            try:
                name = thread.name  # Py 2.6+ / 3.X (getName() is deprecated)
            except AttributeError:
                name = thread.getName()  # older Py 2.X
            stderr.write("%s: MPI.Is_thread_main() -> %s\n" % (name, flag))

    def testIsThreadMain(self):
        """MPI.Is_thread_main() is True here and, in spawned threads, is
        True only when dummy_threading is in use (False otherwise).

        The spawned threads are only exercised when MPI.Query_thread()
        reports at least ``REQUIRED``.
        """
        from sys import exc_info
        self._test_is(main=True)
        try:
            provided = MPI.Query_thread()
        except NotImplementedError:
            return
        if provided < self.REQUIRED:
            return
        # unittest only sees failures raised in the thread running the test;
        # an assertion failing inside a worker thread would just be printed.
        # Collect worker errors here and re-raise one in this thread.
        errors = []

        def check():
            try:
                # With dummy_threading, "threads" run in the main thread.
                self._test_is(not _HAS_THREADING)
            except Exception:
                errors.append(exc_info()[1])

        threads = [Thread(target=check) for _ in range(5)]
        if provided == MPI.THREAD_SERIALIZED:
            # Run the threads strictly one after another.
            for t in threads:
                t.start()
                t.join()
        elif provided == MPI.THREAD_MULTIPLE:
            # Let all threads run concurrently.
            for t in threads:
                t.start()
            for t in threads:
                t.join()
        if errors:
            raise errors[0]
# For these vendors, testIsThreadMain only spawns threads under
# THREAD_MULTIPLE (presumably THREAD_SERIALIZED is not reliable enough
# there -- TODO confirm).
name, version = MPI.get_vendor()
if name in ('Open MPI', 'LAM/MPI'):
    TestMPIThreads.REQUIRED = MPI.THREAD_MULTIPLE

# When True, _test_is logs each Is_thread_main() result to stderr;
# passing '-v' on the command line turns it on.
_VERBOSE = False
if __name__ == '__main__':
    import sys
    # '-v' enables logging; it is left in sys.argv for unittest.main() too.
    _VERBOSE = _VERBOSE or '-v' in sys.argv
    unittest.main()
|