1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135
|
import contextlib
import importlib
import os
import sys
def github():
    """Return True when the ``GITHUB_ACTIONS`` environment variable is exactly ``"true"``.

    Used to detect runs on GitHub Actions CI.
    """
    flag = os.environ.get("GITHUB_ACTIONS")
    return flag == "true"
def azure():
    """Return True when the ``TF_BUILD`` environment variable is exactly ``"True"``.

    Used to detect runs on Azure Pipelines CI.
    """
    flag = os.environ.get("TF_BUILD")
    return flag == "True"
def import_MPI():
    """Import and return the ``mpi4py.MPI`` extension module.

    The import is deferred to call time so that merely loading this module
    does not pull in (and initialize) MPI.
    """
    module_name = "mpi4py.MPI"
    return importlib.import_module(module_name)
def has_datatype(datatype):
    """Report whether *datatype* is a usable, non-empty MPI datatype.

    Returns False for ``MPI.DATATYPE_NULL``, for datatypes whose
    ``Get_size()`` fails with error class ``MPI.ERR_TYPE``, and for
    datatypes whose size is ``0`` or ``MPI.UNDEFINED``; True otherwise.
    Any other ``MPI.Exception`` raised by ``Get_size()`` is re-raised.
    """
    # https://github.com/pmodels/mpich/issues/7341
    MPI = import_MPI()
    if datatype == MPI.DATATYPE_NULL:
        return False
    try:
        nbytes = datatype.Get_size()
    except MPI.Exception as err:
        if err.Get_error_class() == MPI.ERR_TYPE:
            return False
        raise
    return nbytes not in (0, MPI.UNDEFINED)
def appnum():
    """Return the ``MPI.APPNUM`` attribute of ``MPI.COMM_WORLD``, or None.

    None is returned when the ``MPI.APPNUM`` keyval equals
    ``MPI.KEYVAL_INVALID``; otherwise the result of
    ``COMM_WORLD.Get_attr(MPI.APPNUM)`` is returned as-is.
    """
    MPI = import_MPI()
    if MPI.APPNUM == MPI.KEYVAL_INVALID:
        return None
    return MPI.COMM_WORLD.Get_attr(MPI.APPNUM)
def has_mpi_appnum():
    """Return True if ``appnum()`` yields a value other than None."""
    value = appnum()
    return value is not None
def has_mpi_port():
    """Probe whether MPI port opening and closing both work.

    Opens a port with ``MPI.Open_port()`` and immediately closes it with
    ``MPI.Close_port()``. Returns False if either call raises
    ``MPI.Exception``, and True if both succeed.
    """
    MPI = import_MPI()
    try:
        port_name = MPI.Open_port()
        MPI.Close_port(port_name)
    except MPI.Exception:
        return False
    return True
def disable_mpi_spawn():
    """Return True if dynamic-process (spawn) tests should be skipped.

    Spawn is disabled when any of the following holds:

    * the ``MPI4PY_TEST_SPAWN`` environment variable is set to an opt-out
      value (``0``, ``n``, ``no``, ``off``, ``false``; case-insensitive);
    * the MPI vendor/version reported by ``MPI.get_vendor()`` matches one of
      the known-problematic configurations below (possibly combined with the
      platform or a CI environment);
    * the MPI standard version is older than 2.0;
    * ``cupy`` or ``numba`` has already been imported in this process.

    Otherwise returns False.
    """
    MPI = import_MPI()
    # Explicit user opt-out. Normalized so that e.g. "False" or " NO " are
    # honoured as well as the lowercase spellings.
    optout = os.environ.get("MPI4PY_TEST_SPAWN", "").strip().lower()
    if optout in ("0", "n", "no", "off", "false"):
        return True
    macos = sys.platform == "darwin"
    windows = sys.platform == "win32"
    name, version = MPI.get_vendor()
    if name == "Open MPI":
        if version < (3, 0, 0):
            return True
        if version == (4, 0, 0):
            return True
        if version in ((4, 0, 1), (4, 0, 2)) and macos:
            return True
        # Open MPI 4.1.x is only disabled when running on CI.
        if (4, 1, 0) <= version < (4, 2, 0) and (azure() or github()):
            return True
    if name == "MPICH":
        if (3, 4, 0) <= version < (4, 0, 0) and macos:
            return True
        if version < (4, 1) and not has_mpi_appnum():
            return True
        # Older MPICH: require a working Open_port/Close_port round trip.
        if version < (4, 3) and not has_mpi_port():
            return True
    if name == "Intel MPI":
        mpi4py = __import__("mpi4py")
        if mpi4py.rc.recv_mprobe:
            return True
        if MPI.COMM_WORLD.Get_size() > 1 and windows:
            return True
    if name == "Microsoft MPI":
        if version < (8, 1, 0):
            return True
        # has_mpi_appnum() returns a bool, so it must be tested for
        # truthiness; comparing it against None was always true and
        # disabled spawn on every MS-MPI version.
        if not has_mpi_appnum():
            return True
        if os.environ.get("PMI_APPNUM") is None:
            return True
    if name == "MVAPICH":
        if version < (3, 0, 0):
            return True
        if not has_mpi_appnum():
            return True
    if name == "MPICH2":
        if not has_mpi_appnum():
            return True
    if MPI.Get_version() < (2, 0):
        return True
    # Skip when CuPy or Numba is already loaded. NOTE(review): presumably
    # because they interfere with spawned processes -- confirm the reason.
    if any(map(sys.modules.get, ("cupy", "numba"))):
        return True
    return False
@contextlib.contextmanager
def capture_stderr():
    """Redirect ``sys.stderr`` into an in-memory text buffer for the block.

    Yields the ``io.StringIO`` buffer that receives everything written to
    ``sys.stderr`` inside the ``with`` block. The previous ``sys.stderr`` is
    restored on exit, including when the block raises.
    """
    import io

    saved = sys.stderr
    buffer = io.StringIO()
    sys.stderr = buffer
    try:
        yield buffer
    finally:
        sys.stderr = saved
|