1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77
|
import sys

try:
    import threading
    HAVE_THREADING = True
except ImportError:
    # Fallback for interpreters built without thread support.
    # NOTE(review): dummy_threading was removed from the stdlib in
    # Python 3.9, so on modern interpreters this branch cannot succeed.
    import dummy_threading as threading
    HAVE_THREADING = False

# When True, the thread-main checks log their result to stderr.
VERBOSE = False
# VERBOSE = True

# Order matters: mpi4py.rc options must be set before mpi4py.MPI is
# imported, because the requested thread level is consumed when MPI is
# initialized on that import (hence the E402 suppressions below).
import mpi4py.rc  # noqa: E402
mpi4py.rc.thread_level = "multiple"
# presumably a project test helper wrapping stdlib unittest -- verify
import mpiunittest as unittest  # noqa: E402
from mpi4py import MPI  # noqa: E402
class TestMPIThreads(unittest.TestCase):
    """Tests for MPI thread-support queries (Query_thread, Is_thread_main)."""

    def testThreadLevels(self):
        """Check the ordering of MPI thread levels and the provided level.

        The four thread-level constants must be strictly increasing, and
        ``MPI.Query_thread()`` must report one of them. Skips when the MPI
        implementation does not support ``Query_thread``.
        """
        levels = [
            MPI.THREAD_SINGLE,
            MPI.THREAD_FUNNELED,
            MPI.THREAD_SERIALIZED,
            MPI.THREAD_MULTIPLE,
        ]
        for lower, higher in zip(levels, levels[1:]):
            self.assertLess(lower, higher)
        try:
            provided = MPI.Query_thread()
        except NotImplementedError:
            self.skipTest("mpi-query_thread")
        self.assertIn(provided, levels)

    def testIsThreadMain(self):
        """Check that ``MPI.Is_thread_main()`` is true only on the main thread.

        Also invoked from worker threads by ``testIsThreadMainInThread``.
        Skips when ``Is_thread_main`` is not supported.
        """
        try:
            flag = MPI.Is_thread_main()
        except NotImplementedError:
            self.skipTest("mpi-is_thread_main")
        current = threading.current_thread()
        name = current.name
        # Compare thread objects rather than the "MainThread" name, which can
        # be reassigned. Without real thread support (dummy_threading), thread
        # targets run synchronously in the caller, so the expected flag is
        # always True; HAVE_THREADING is checked first for that reason.
        main = (not HAVE_THREADING) or (current is threading.main_thread())
        self.assertEqual(flag, main)
        if VERBOSE:
            sys.stderr.write(f"{name}: MPI.Is_thread_main() -> {flag}\n")

    def testIsThreadMainInThread(self):
        """Run ``testIsThreadMain`` on the main thread and on worker threads.

        With THREAD_MULTIPLE the workers run concurrently; with
        THREAD_SERIALIZED they run one at a time; lower levels skip.

        An exception raised inside a ``threading.Thread`` target is only
        reported by threading's exception hook (printed to stderr) and never
        propagates to ``join()``. Worker failures and skips are therefore
        collected and re-raised on the main thread, so they actually reach
        the test runner.
        """
        try:
            provided = MPI.Query_thread()
        except NotImplementedError:
            self.skipTest("mpi-query_thread")
        self.testIsThreadMain()

        errors = []

        def worker():
            try:
                self.testIsThreadMain()
            except Exception as exc:  # re-raised on the main thread below
                errors.append(exc)

        threads = [threading.Thread(target=worker) for _ in range(5)]
        if provided == MPI.THREAD_MULTIPLE:
            for t in threads:
                t.start()
            for t in threads:
                t.join()
        elif provided == MPI.THREAD_SERIALIZED:
            # SERIALIZED allows only one thread at a time to call MPI, so
            # each worker finishes before the next one starts.
            for t in threads:
                t.start()
                t.join()
        else:
            self.skipTest("mpi-thread_level")
        if errors:
            raise errors[0]
if __name__ == "__main__":
    # Run this module's tests (via the mpiunittest alias) when executed
    # directly as a script.
    unittest.main()
|