1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145
|
from collections import namedtuple
import contextlib
import unittest
class TestCase(unittest.TestCase):
    """``unittest.TestCase`` with helpers used throughout the MPI tests."""

    def assertAlmostEqual(self, first, second, *, rtol=1e-2):
        """Fail unless *first* and *second* agree to relative tolerance *rtol*.

        Both values are converted with ``complex()`` so ints, floats and
        complex numbers can be mixed.  The relative error is measured
        against the larger magnitude of the two values (or against 1.0 when
        both are zero).

        Exactly equal values always pass; this includes equal infinities,
        whose difference would otherwise be NaN.  A NaN on either side always
        fails, matching ``unittest.TestCase.assertAlmostEqual``.

        Note: this overrides the base-class method with a different
        signature (no ``places``/``msg``/``delta`` arguments).

        :param rtol: maximum allowed relative error (default ``1e-2``).
        :raises self.failureException: when the values are not close enough.
        """
        a = complex(first)
        b = complex(second)
        if a == b:
            return
        den = max(abs(b), abs(a)) or 1.0
        rel = abs((b - a) / den)
        # Written as ``not <=`` so a NaN relative error (NaN operand, or
        # inf compared with a finite value) fails instead of slipping through.
        if not rel <= rtol:
            raise self.failureException(f'{first!r} != {second!r}')

    @contextlib.contextmanager
    def catchNotImplementedError(self, version=None, subversion=0):
        """Run the ``with`` body, tolerating ``NotImplementedError``.

        Any ``NotImplementedError`` raised in the body is swallowed.  When
        *version* is given, this additionally asserts that the runtime MPI
        standard version ``(MPI.VERSION, MPI.SUBVERSION)`` is strictly older
        than ``(version, subversion)``.  In other words, the feature may only
        be missing on MPI libraries predating that standard version.
        """
        try:
            yield
        except NotImplementedError:
            if version is not None:
                # Imported lazily so the version-free path needs no MPI.
                from mpi4py import MPI
                mpi_version = (MPI.VERSION, MPI.SUBVERSION)
                self.assertLess(mpi_version, (version, subversion))
_Version = namedtuple("_Version", ["major", "minor", "patch"])
def _parse_version(version):
    """Parse a dotted version string into a three-field ``_Version``.

    Missing components are padded with zeros and extra components are
    dropped, e.g. ``'3'`` -> ``(3, 0, 0)`` and ``'1.2.3.4'`` -> ``(1, 2, 3)``.
    Non-integer components raise ``ValueError`` (from ``int``).
    """
    numbers = [int(token) for token in version.split('.')]
    major, minor, patch = (numbers + [0, 0, 0])[:3]
    return _Version(major, minor, patch)
class _VersionPredicate:
    """Parse and evaluate predicates such as ``'mpich(>=3.0,<4.1)'``.

    A predicate is a (possibly dotted) name optionally followed by a
    parenthesised, comma-separated list of ``<op><version>`` clauses.
    ``<op>`` is one of ``<``, ``<=``, ``>``, ``>=``, ``==`` or ``!=``.
    Spaces anywhere in the string are ignored.  A bare name has no clauses
    and is satisfied by every version.

    Attributes:
        name: the leading name exactly as written (matching is
            case-insensitive, but no case folding is applied).
        pred: list of ``(op, _Version)`` pairs; all must hold (logical AND).

    Raises ``ValueError`` for malformed predicate strings.
    """

    def __init__(self, versionPredicateStr):
        import re
        re_name = re.compile(r"(?i)^([a-z_]\w*(?:\.[a-z_]\w*)*)(.*)$")
        re_pred = re.compile(r"^(<=|>=|<|>|!=|==)(.*)$")

        def split(item):
            # One "<op><version>" clause -> (op, _Version).
            m = re_pred.match(item)
            if m is None:
                raise ValueError(
                    f"invalid version clause {item!r} "
                    f"in predicate {versionPredicateStr!r}")
            op, version = m.groups()
            return op, _parse_version(version)

        vpstr = versionPredicateStr.replace(' ', '')
        m = re_name.match(vpstr)
        if m is None:
            raise ValueError(
                f"invalid version predicate {versionPredicateStr!r}")
        name, plist = m.groups()
        if plist:
            # An explicit check instead of ``assert`` so validation still
            # happens under ``python -O``.
            if not (plist.startswith('(') and plist.endswith(')')):
                raise ValueError(
                    f"version clauses must be enclosed in parentheses "
                    f"in predicate {versionPredicateStr!r}")
            plist = plist[1:-1]
        # Empty items (e.g. from '()' or a trailing comma) are ignored.
        pred = [split(p) for p in plist.split(',') if p]
        self.name = name
        self.pred = pred

    def __str__(self):
        """Render as ``name(op x.y.z,...)``, or just ``name`` with no clauses."""
        if self.pred:
            items = [f"{op}{'.'.join(map(str, ver))}" for op, ver in self.pred]
            return f"{self.name}({','.join(items)})"
        else:
            return self.name

    def satisfied_by(self, version):
        """Return True if dotted *version* satisfies every clause.

        A predicate with no clauses is satisfied by any parsable version.
        """
        from operator import lt, le, gt, ge, eq, ne
        opmap = {'<': lt, '<=': le, '>': gt, '>=': ge, '==': eq, '!=': ne}
        version = _parse_version(version)
        for op, ver in self.pred:
            if not opmap[op](version, ver):
                return False
        return True
def mpi_predicate(predicate):
    """Match *predicate* against the running MPI standard or vendor version.

    If the predicate's name is ``mpi`` (e.g. ``'mpi(>=3.1)'``), it is checked
    against ``MPI.Get_version()``.  Otherwise it is checked against the
    vendor name and version reported by ``MPI.get_vendor()``.  Both the
    predicate and the vendor name are normalized by ``key`` before
    comparison.

    Returns the parsed ``_VersionPredicate`` when the name matches and the
    version satisfies every clause; otherwise returns ``None``.
    """
    from mpi4py import MPI

    def key(text):
        # Separators are stripped before the vendor abbreviations are
        # applied, so e.g. 'Intel MPI' -> 'IntelMPI' -> 'IMPI' -> 'impi'.
        replacements = (
            (' ', ''),
            ('/', ''),
            ('-', ''),
            ('Intel', 'I'),
            ('Microsoft', 'MS'),
        )
        for old, new in replacements:
            text = text.replace(old, new)
        return text.lower()

    vp = _VersionPredicate(key(predicate))
    if vp.name == 'mpi':
        # The standard version has two components here; pad to three.
        name = 'mpi'
        version = MPI.Get_version() + (0,)
    else:
        name, version = MPI.get_vendor()
    if vp.name != key(name):
        return None
    major, minor, patch = version
    return vp if vp.satisfied_by(f'{major}.{minor}.{patch}') else None
def is_mpi(predicate):
    """Return the matched predicate object (truthy), or ``None`` on no match."""
    matched = mpi_predicate(predicate)
    return matched
def is_mpi_gpu(predicate, array):
    """Return True if *array* uses a GPU backend and *predicate* matches MPI.

    *array* is presumably a test-array wrapper exposing a ``backend`` name
    (TODO confirm).  The MPI predicate is only evaluated for the
    cupy/numba/dlpack-cupy backends.
    """
    gpu_backends = ('cupy', 'numba', 'dlpack-cupy')
    return array.backend in gpu_backends and bool(mpi_predicate(predicate))
# Convenience aliases for unittest's skip machinery.
SkipTest = unittest.SkipTest
skip, skipIf, skipUnless = (
    unittest.skip,
    unittest.skipIf,
    unittest.skipUnless,
)
def skipMPI(predicate, *conditions):
    """Return a decorator that skips a test on matching MPI libraries.

    The test is skipped when *predicate* matches the running MPI (see
    ``mpi_predicate``) and either no *conditions* were passed or at least
    one of them is truthy.  The skip reason is the predicate's string form.
    Otherwise the returned decorator leaves the test unchanged.
    """
    matched = mpi_predicate(predicate)
    no_op = unittest.skipIf(False, '')
    if not matched:
        return no_op
    if conditions and not any(conditions):
        return no_op
    return unittest.skip(str(matched))
def disable(what, reason):
    """Unconditionally mark test function or class *what* as skipped."""
    decorator = unittest.skip(reason)
    return decorator(what)
@contextlib.contextmanager
def capture_stderr():
    """Redirect ``sys.stderr`` into a fresh ``io.StringIO`` for the block.

    Yields the buffer.  The previous ``sys.stderr`` is restored on exit,
    even if the body raises.
    """
    import io
    buffer = io.StringIO()
    with contextlib.redirect_stderr(buffer):
        yield buffer
def main(*args, **kwargs):
    """Run the test-suite entry point from the ``main`` module.

    Arguments are forwarded unchanged.  ``SystemExit`` is suppressed so the
    caller keeps control after the run.
    """
    from main import main as run_suite
    with contextlib.suppress(SystemExit):
        run_suite(*args, **kwargs)
|