File: decorators.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (289 lines) | stat: -rw-r--r-- 8,309 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
import os
import sys
import warnings
from importlib import import_module
from importlib.util import find_spec
from typing import Callable

import torch
from packaging.requirements import Requirement
from packaging.version import Version

import torch_geometric
from torch_geometric.typing import WITH_METIS, WITH_PYG_LIB, WITH_TORCH_SPARSE
from torch_geometric.visualization.graph import has_graphviz


def is_full_test() -> bool:
    r"""Whether to run the full but time-consuming test suite.

    Returns :obj:`True` only if the environment variable ``FULL_TEST`` is set
    to exactly ``"1"``.
    """
    flag = os.environ.get('FULL_TEST', '0')
    return flag == '1'


def onlyFullTest(func: Callable) -> Callable:
    r"""A decorator to specify that this function belongs to the full test
    suite.

    The decorated test is skipped unless :meth:`is_full_test` reports that
    the full suite was requested.
    """
    import pytest

    skip_in_fast_run = pytest.mark.skipif(
        not is_full_test(),
        reason="Fast test run",
    )
    return skip_in_fast_run(func)


def is_distributed_test() -> bool:
    r"""Whether to run the distributed test suite.

    Distributed tests are requested either through the full test suite or by
    setting ``DIST_TEST=1``. They additionally require a Linux platform and
    an installed :obj:`pyg_lib` package.
    """
    requested = is_full_test() or os.getenv('DIST_TEST', '0') == '1'
    if not requested:
        return False
    if sys.platform != 'linux':
        return False
    return has_package('pyg_lib')


def onlyDistributedTest(func: Callable) -> Callable:
    r"""A decorator to specify that this function belongs to the distributed
    test suite.

    The decorated test is skipped whenever :meth:`is_distributed_test`
    returns :obj:`False`.
    """
    import pytest

    run_distributed = is_distributed_test()
    marker = pytest.mark.skipif(not run_distributed, reason="Fast test run")
    return marker(func)


def onlyLinux(func: Callable) -> Callable:
    r"""A decorator to specify that this function should only execute on
    Linux systems, i.e. when :obj:`sys.platform` is ``"linux"``.
    """
    import pytest

    not_linux = sys.platform != 'linux'
    return pytest.mark.skipif(not_linux, reason="No Linux system")(func)


def noWindows(func: Callable) -> Callable:
    r"""A decorator to specify that this function should not execute on
    Windows systems (detected via :obj:`os.name` being ``"nt"``).
    """
    import pytest

    on_windows = os.name == 'nt'
    return pytest.mark.skipif(on_windows, reason="Windows system")(func)


def noMac(func: Callable) -> Callable:
    r"""A decorator to specify that this function should not execute on
    macOS systems (detected via :obj:`sys.platform` being ``"darwin"``).
    """
    import pytest

    on_macos = sys.platform == 'darwin'
    return pytest.mark.skipif(on_macos, reason="macOS system")(func)


def minPython(version: str) -> Callable:
    r"""A decorator to run tests on specific :python:`Python` versions only.

    The decorated test is skipped if the running interpreter is older than
    :obj:`version`.

    Args:
        version (str): The minimum required version as a dot-separated string
            with one to three numeric components, e.g. ``"3"``, ``"3.10"`` or
            ``"3.10.4"``. Components beyond the micro version are ignored.
    """
    def decorator(func: Callable) -> Callable:
        import pytest

        # Lexicographic tuple comparison checks major first, then minor, then
        # micro. For a "major.minor" string this is exactly "major is older,
        # or major matches and minor is older":
        required = tuple(int(part) for part in version.split('.')[:3])
        skip = sys.version_info[:len(required)] < required

        return pytest.mark.skipif(
            skip,
            reason=f"Python {version} required",
        )(func)

    return decorator


def onlyCUDA(func: Callable) -> Callable:
    r"""A decorator to skip tests if CUDA is not found, as reported by
    :meth:`torch.cuda.is_available`.
    """
    import pytest

    cuda_missing = not torch.cuda.is_available()
    return pytest.mark.skipif(cuda_missing, reason="CUDA not available")(func)


def onlyXPU(func: Callable) -> Callable:
    r"""A decorator to skip tests if XPU is not found, as reported by
    :meth:`torch_geometric.is_xpu_available`.
    """
    import pytest

    xpu_missing = not torch_geometric.is_xpu_available()
    return pytest.mark.skipif(xpu_missing, reason="XPU not available")(func)


def onlyOnline(func: Callable) -> Callable:
    r"""A decorator to skip tests if there exists no connection to the
    internet.

    Connectivity is probed once, when the decorator is applied, by sending
    an HTTPS ``HEAD`` request to ``8.8.8.8`` with a 5 second timeout. Any
    exception raised by the probe is treated as "offline".
    """
    import http.client as httplib

    import pytest

    probe = httplib.HTTPSConnection('8.8.8.8', timeout=5)
    try:
        probe.request('HEAD', '/')
        online = True
    except Exception:
        # Deliberately best-effort: any failure just means "skip".
        online = False
    finally:
        probe.close()

    marker = pytest.mark.skipif(not online, reason="No internet connection")
    return marker(func)


def onlyGraphviz(func: Callable) -> Callable:
    r"""A decorator to specify that this function should only execute in case
    :obj:`graphviz` is installed, as reported by :meth:`has_graphviz`.
    """
    import pytest

    graphviz_missing = not has_graphviz()
    marker = pytest.mark.skipif(
        graphviz_missing,
        reason="Graphviz not installed",
    )
    return marker(func)


def onlyNeighborSampler(func: Callable) -> Callable:
    r"""A decorator to skip tests if no neighborhood sampler package is
    installed.

    The test runs if at least one of the :obj:`WITH_PYG_LIB` or
    :obj:`WITH_TORCH_SPARSE` flags is set.
    """
    import pytest

    has_sampler = WITH_PYG_LIB or WITH_TORCH_SPARSE
    marker = pytest.mark.skipif(
        not has_sampler,
        reason="No neighbor sampler installed",
    )
    return marker(func)


def has_package(package: str) -> bool:
    r"""Returns :obj:`True` in case :obj:`package` is installed.

    Args:
        package (str): A requirement string such as ``"torch_sparse"`` or
            ``"torch>=2.0"``. Alternatives may be separated by ``"|"``, in
            which case :obj:`True` is returned if any of them is installed.

    A package counts as installed if its module can be found and imported.
    If a version specifier is given and the module exposes ``__version__``,
    the base version of ``__version__`` must additionally satisfy the
    specifier. Any failure while locating, importing or parsing the version
    yields :obj:`False`.
    """
    if '|' in package:
        return any(has_package(p) for p in package.split('|'))

    req = Requirement(package)

    # `find_spec` imports the parent packages of dotted names and raises
    # `ModuleNotFoundError` if a parent is missing, so it is guarded too:
    try:
        if find_spec(req.name) is None:
            return False
        module = import_module(req.name)
    except Exception:
        return False

    # Without a version constraint, importability is all that matters. A
    # module without `__version__` cannot be checked and is accepted as-is:
    if not req.specifier or not hasattr(module, '__version__'):
        return True

    try:
        version = Version(module.__version__).base_version
    except Exception:  # e.g., a `__version__` that is not PEP 440 compliant.
        return False
    return version in req.specifier


def withPackage(*args: str) -> Callable:
    r"""A decorator to skip tests if certain packages are not installed.
    Also supports version specification.

    Each argument is a requirement string checked via :meth:`has_package`,
    e.g. ``"torch_sparse"``, ``"torch>=2.0"`` or ``"a|b"`` (either one).
    The availability check runs once, when :meth:`withPackage` is called.
    """
    # Collect missing packages in argument order, deduplicated, so that the
    # skip reason is reproducible across runs. Iterating a `set` of strings
    # depends on hash randomization.
    na_packages = list(
        dict.fromkeys(package for package in args
                      if not has_package(package)))

    if len(na_packages) == 1:
        reason = f"Package {na_packages[0]} not found"
    else:
        reason = f"Packages {', '.join(na_packages)} not found"

    def decorator(func: Callable) -> Callable:
        import pytest
        return pytest.mark.skipif(len(na_packages) > 0, reason=reason)(func)

    return decorator


def withCUDA(func: Callable) -> Callable:
    r"""A decorator to test both on CPU and CUDA (if available).

    The decorated test is parametrized over a ``device`` argument holding
    ``cpu`` and, when :meth:`torch.cuda.is_available` is true, ``cuda:0``.
    """
    import pytest

    targets = ['cpu']
    if torch.cuda.is_available():
        targets.append('cuda:0')

    devices = [pytest.param(torch.device(name), id=name) for name in targets]
    return pytest.mark.parametrize('device', devices)(func)


def withDevice(func: Callable) -> Callable:
    r"""A decorator to test on all available tensor processing devices.

    The decorated test is parametrized over a ``device`` argument that always
    contains the CPU, plus ``cuda:0``, ``mps:0`` and ``xpu:0`` whenever the
    respective availability check succeeds.

    An extra device can be registered via the ``TORCH_DEVICE`` environment
    variable. In that case ``TORCH_BACKEND`` must name a module to import
    before the device is added. If ``TORCH_BACKEND`` is unset or empty, a
    warning is emitted and the extra device is ignored.
    """
    import pytest

    devices = [pytest.param(torch.device('cpu'), id='cpu')]

    if torch.cuda.is_available():
        devices.append(pytest.param(torch.device('cuda:0'), id='cuda:0'))

    if torch_geometric.is_mps_available():
        devices.append(pytest.param(torch.device('mps:0'), id='mps'))

    if torch_geometric.is_xpu_available():
        devices.append(pytest.param(torch.device('xpu:0'), id='xpu'))

    # Additional devices can be registered through environment variables:
    device = os.getenv('TORCH_DEVICE')
    if device:
        backend = os.getenv('TORCH_BACKEND')
        # An empty value would otherwise crash in `import_module('')`:
        if not backend:
            warnings.warn(f"Please specify the backend via 'TORCH_BACKEND' in "
                          f"order to test against '{device}'")
        else:
            # Presumably importing the backend is what makes `device` usable
            # in `torch.device` -- confirm for out-of-tree backends.
            import_module(backend)
            devices.append(pytest.param(torch.device(device), id=device))

    return pytest.mark.parametrize('device', devices)(func)


def withMETIS(func: Callable) -> Callable:
    r"""A decorator to only test in case a valid METIS method is available.

    Besides requiring :obj:`WITH_METIS`, a tiny hard-coded graph is
    partitioned via :obj:`torch.ops.torch_sparse.partition`. Any exception
    raised during that probe marks METIS as unavailable.
    """
    import pytest

    def _metis_runs() -> bool:
        # TODO Using `pyg-lib` METIS partitioning leads to some weird bugs in
        # the CI. As such, we require `torch-sparse` for now.
        try:
            rowptr = torch.tensor([0, 2, 4, 6])
            col = torch.tensor([1, 2, 0, 2, 1, 0])
            torch.ops.torch_sparse.partition(rowptr, col, None, 2, True)
        except Exception:
            return False
        return True

    # Only probe when METIS claims to be available:
    metis_ok = WITH_METIS and _metis_runs()

    return pytest.mark.skipif(not metis_ok, reason="METIS not enabled")(func)


def disableExtensions(func: Callable) -> Callable:
    r"""A decorator to temporarily disable the usage of the
    :obj:`torch_scatter`, :obj:`torch_sparse` and :obj:`pyg_lib` extension
    packages.

    It requests the ``disable_extensions`` pytest fixture for the decorated
    test. The fixture itself is defined elsewhere.
    """
    import pytest

    use_fixture = pytest.mark.usefixtures('disable_extensions')
    return use_fixture(func)


def withoutExtensions(func: Callable) -> Callable:
    r"""A decorator to test both with and without the usage of extension
    packages such as :obj:`torch_scatter`, :obj:`torch_sparse` and
    :obj:`pyg_lib`.

    The ``without_extensions`` argument is parametrized indirectly. Each
    value (``enable_extensions`` / ``disable_extensions``) is handed to the
    ``without_extensions`` fixture, which is defined elsewhere.
    """
    import pytest

    modes = ['enable_extensions', 'disable_extensions']
    parametrize = pytest.mark.parametrize(
        'without_extensions',
        modes,
        indirect=True,
    )
    return parametrize(func)