import os
import sys
import warnings
from importlib import import_module
from importlib.util import find_spec
from typing import Callable
import torch
from packaging.requirements import Requirement
from packaging.version import Version
import torch_geometric
from torch_geometric.typing import WITH_METIS, WITH_PYG_LIB, WITH_TORCH_SPARSE
from torch_geometric.visualization.graph import has_graphviz
def is_full_test() -> bool:
r"""Whether to run the full but time-consuming test suite."""
return os.getenv('FULL_TEST', '0') == '1'
def onlyFullTest(func: Callable) -> Callable:
r"""A decorator to specify that this function belongs to the full test
suite.
"""
import pytest
return pytest.mark.skipif(
not is_full_test(),
reason="Fast test run",
)(func)
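# Illustrative usage (a sketch, not part of this module; the test name is
# made up): tests marked with `@onlyFullTest` are skipped unless the
# environment variable `FULL_TEST=1` is set, e.g. via `FULL_TEST=1 pytest`:
#
#   @onlyFullTest
#   def test_expensive_end_to_end():
#       ...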
def is_distributed_test() -> bool:
r"""Whether to run the distributed test suite."""
return ((is_full_test() or os.getenv('DIST_TEST', '0') == '1')
and sys.platform == 'linux' and has_package('pyg_lib'))
def onlyDistributedTest(func: Callable) -> Callable:
r"""A decorator to specify that this function belongs to the distributed
test suite.
"""
import pytest
return pytest.mark.skipif(
not is_distributed_test(),
reason="Fast test run",
)(func)
def onlyLinux(func: Callable) -> Callable:
r"""A decorator to specify that this function should only execute on
Linux systems.
"""
import pytest
return pytest.mark.skipif(
sys.platform != 'linux',
reason="No Linux system",
)(func)
def noWindows(func: Callable) -> Callable:
r"""A decorator to specify that this function should not execute on
Windows systems.
"""
import pytest
return pytest.mark.skipif(
os.name == 'nt',
reason="Windows system",
)(func)
def noMac(func: Callable) -> Callable:
r"""A decorator to specify that this function should not execute on
macOS systems.
"""
import pytest
return pytest.mark.skipif(
sys.platform == 'darwin',
reason="macOS system",
)(func)
def minPython(version: str) -> Callable:
r"""A decorator to run tests on specific :python:`Python` versions only."""
def decorator(func: Callable) -> Callable:
import pytest
major, minor = version.split('.')
skip = False
if sys.version_info.major < int(major):
skip = True
if (sys.version_info.major == int(major)
and sys.version_info.minor < int(minor)):
skip = True
return pytest.mark.skipif(
skip,
reason=f"Python {version} required",
)(func)
return decorator
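# Illustrative usage (a sketch; the test name is hypothetical): restrict a
# test to Python 3.10 and newer by passing a '<major>.<minor>' string:
#
#   @minPython('3.10')
#   def test_requires_modern_python():
#       ...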
def onlyCUDA(func: Callable) -> Callable:
r"""A decorator to skip tests if CUDA is not found."""
import pytest
return pytest.mark.skipif(
not torch.cuda.is_available(),
reason="CUDA not available",
)(func)
def onlyXPU(func: Callable) -> Callable:
r"""A decorator to skip tests if XPU is not found."""
import pytest
return pytest.mark.skipif(
not torch_geometric.is_xpu_available(),
reason="XPU not available",
)(func)
def onlyOnline(func: Callable) -> Callable:
r"""A decorator to skip tests if there exists no connection to the
internet.
"""
import http.client as httplib
import pytest
has_connection = True
connection = httplib.HTTPSConnection('8.8.8.8', timeout=5)
try:
connection.request('HEAD', '/')
except Exception:
has_connection = False
finally:
connection.close()
return pytest.mark.skipif(
not has_connection,
reason="No internet connection",
)(func)
def onlyGraphviz(func: Callable) -> Callable:
r"""A decorator to specify that this function should only execute in case
:obj:`graphviz` is installed.
"""
import pytest
return pytest.mark.skipif(
not has_graphviz(),
reason="Graphviz not installed",
)(func)
def onlyNeighborSampler(func: Callable) -> Callable:
r"""A decorator to skip tests if no neighborhood sampler package is
installed.
"""
import pytest
return pytest.mark.skipif(
not WITH_PYG_LIB and not WITH_TORCH_SPARSE,
reason="No neighbor sampler installed",
)(func)
def has_package(package: str) -> bool:
r"""Returns :obj:`True` in case :obj:`package` is installed."""
if '|' in package:
return any(has_package(p) for p in package.split('|'))
req = Requirement(package)
if find_spec(req.name) is None:
return False
try:
module = import_module(req.name)
if not hasattr(module, '__version__'):
return True
version = Version(module.__version__).base_version
return version in req.specifier
except Exception:
return False
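# Illustrative examples (only the package names are assumptions):
# `has_package` accepts plain names, PEP 440 version specifiers, and
# `|`-separated alternatives:
#
#   has_package('numpy')                  # installed at any version
#   has_package('numpy>=1.21')            # installed and new enough
#   has_package('pyg_lib|torch_sparse')   # at least one of the two is found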
def withPackage(*args: str) -> Callable:
r"""A decorator to skip tests if certain packages are not installed.
Also supports version specification.
"""
na_packages = {package for package in args if not has_package(package)}
if len(na_packages) == 1:
reason = f"Package {list(na_packages)[0]} not found"
else:
reason = f"Packages {na_packages} not found"
def decorator(func: Callable) -> Callable:
import pytest
return pytest.mark.skipif(len(na_packages) > 0, reason=reason)(func)
return decorator
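# Illustrative usage (a sketch; the package names and test name are just
# examples): skip the test unless all listed requirements are satisfied:
#
#   @withPackage('torch_sparse', 'networkx>=3.0')
#   def test_needs_optional_dependencies():
#       ...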
def withCUDA(func: Callable) -> Callable:
r"""A decorator to test both on CPU and CUDA (if available)."""
import pytest
devices = [pytest.param(torch.device('cpu'), id='cpu')]
if torch.cuda.is_available():
devices.append(pytest.param(torch.device('cuda:0'), id='cuda:0'))
return pytest.mark.parametrize('device', devices)(func)
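# Illustrative usage (a sketch; the test name is made up): the decorated test
# receives a `device` argument and runs once on the CPU and, if available,
# once on 'cuda:0':
#
#   @withCUDA
#   def test_runs_on_cpu_and_gpu(device):
#       x = torch.randn(4, 8, device=device)
#       assert x.device.type == device.type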
def withDevice(func: Callable) -> Callable:
r"""A decorator to test on all available tensor processing devices."""
import pytest
devices = [pytest.param(torch.device('cpu'), id='cpu')]
if torch.cuda.is_available():
devices.append(pytest.param(torch.device('cuda:0'), id='cuda:0'))
if torch_geometric.is_mps_available():
devices.append(pytest.param(torch.device('mps:0'), id='mps'))
if torch_geometric.is_xpu_available():
devices.append(pytest.param(torch.device('xpu:0'), id='xpu'))
# Additional devices can be registered through environment variables:
device = os.getenv('TORCH_DEVICE')
if device:
backend = os.getenv('TORCH_BACKEND')
if backend is None:
            warnings.warn("Please specify the backend via 'TORCH_BACKEND' "
                          f"in order to test against '{device}'")
else:
import_module(backend)
devices.append(pytest.param(torch.device(device), id=device))
return pytest.mark.parametrize('device', devices)(func)
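# Illustrative example (environment variable names as read above; the module
# name and device string are hypothetical): an additional device can be
# registered for `@withDevice` tests by pointing `TORCH_DEVICE` at the device
# string and `TORCH_BACKEND` at the module providing its backend, e.g.:
#
#   TORCH_BACKEND=my_device_backend TORCH_DEVICE=privateuseone:0 pytest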
def withMETIS(func: Callable) -> Callable:
r"""A decorator to only test in case a valid METIS method is available."""
import pytest
with_metis = WITH_METIS
if with_metis:
        try:  # Test that METIS can successfully execute:
            # TODO Using `pyg-lib` METIS partitioning leads to some weird
            # bugs in the CI. As such, we require `torch-sparse` for now.
rowptr = torch.tensor([0, 2, 4, 6])
col = torch.tensor([1, 2, 0, 2, 1, 0])
torch.ops.torch_sparse.partition(rowptr, col, None, 2, True)
except Exception:
with_metis = False
return pytest.mark.skipif(
not with_metis,
reason="METIS not enabled",
)(func)
def disableExtensions(func: Callable) -> Callable:
r"""A decorator to temporarily disable the usage of the
:obj:`torch_scatter`, :obj:`torch_sparse` and :obj:`pyg_lib` extension
packages.
"""
import pytest
return pytest.mark.usefixtures('disable_extensions')(func)
def withoutExtensions(func: Callable) -> Callable:
r"""A decorator to test both with and without the usage of extension
packages such as :obj:`torch_scatter`, :obj:`torch_sparse` and
:obj:`pyg_lib`.
"""
import pytest
return pytest.mark.parametrize(
'without_extensions',
['enable_extensions', 'disable_extensions'],
indirect=True,
)(func)
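# Illustrative usage (a sketch; the test name is made up): the decorated test
# requests the `without_extensions` fixture and runs once with and once
# without the extension packages enabled:
#
#   @withoutExtensions
#   def test_fallback_without_extensions(without_extensions):
#       ...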