File: conftest.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (97 lines) | stat: -rw-r--r-- 3,354 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import functools
import logging
import os.path as osp
from typing import Callable, Iterator

import pytest
import torch

import torch_geometric.typing
from torch_geometric.data import Dataset
from torch_geometric.io import fs


def load_dataset(root: str, name: str, *args, **kwargs) -> Dataset:
    r"""Returns a variety of datasets according to :obj:`name`.

    Each dataset class is imported inside its own branch, so only the class
    that matches :obj:`name` is imported.

    Args:
        root (str): Base directory for datasets stored on disk. Each dataset
            family gets its own sub-directory below it. :class:`KarateClub`,
            :class:`BAShapes` and :class:`FakeHeteroDataset` do not use it.
        name (str): The dataset name. Karate club, Planetoid, BAShapes,
            Elliptic and ``"hetero"`` are matched case-insensitively (Karate
            club and Elliptic by substring); TU, SNAP and SuiteSparse names
            must match exactly.
        *args: Extra positional arguments forwarded to the dataset class.
        **kwargs: Extra keyword arguments forwarded to the dataset class.

    Raises:
        ValueError: If :obj:`name` matches none of the supported datasets.
    """
    key = name.lower()  # Shared by all case-insensitive checks below.

    if 'karate' in key:
        from torch_geometric.datasets import KarateClub
        return KarateClub(*args, **kwargs)
    if key in {'cora', 'citeseer', 'pubmed'}:
        from torch_geometric.datasets import Planetoid
        path = osp.join(root, 'Planetoid', name)
        return Planetoid(path, name, *args, **kwargs)
    if name in {'BZR', 'ENZYMES', 'IMDB-BINARY', 'MUTAG'}:
        from torch_geometric.datasets import TUDataset
        path = osp.join(root, 'TUDataset')
        return TUDataset(path, name, *args, **kwargs)
    if name in {'ego-facebook', 'soc-Slashdot0811', 'wiki-vote'}:
        from torch_geometric.datasets import SNAPDataset
        path = osp.join(root, 'SNAPDataset')
        return SNAPDataset(path, name, *args, **kwargs)
    if key == 'bashapes':
        from torch_geometric.datasets import BAShapes
        return BAShapes(*args, **kwargs)
    if name in {'citationCiteseer', 'illc1850'}:
        from torch_geometric.datasets import SuiteSparseMatrixCollection
        path = osp.join(root, 'SuiteSparseMatrixCollection')
        # `*args` goes before the `name=` keyword: star-unpacking after a
        # keyword argument is legal but misleading (flake8-bugbear B026).
        # Both spellings bind the arguments identically.
        return SuiteSparseMatrixCollection(path, *args, name=name, **kwargs)
    if 'elliptic' in key:
        from torch_geometric.datasets import EllipticBitcoinDataset
        path = osp.join(root, 'EllipticBitcoinDataset')
        return EllipticBitcoinDataset(path, *args, **kwargs)
    if key == 'hetero':
        from torch_geometric.testing import FakeHeteroDataset
        return FakeHeteroDataset(*args, **kwargs)

    raise ValueError(f"Cannot load dataset with name '{name}'")


@pytest.fixture(scope='session')
def get_dataset() -> Iterator[Callable]:
    r"""Session-scoped fixture that provides a dataset loader.

    Yields :func:`load_dataset` with its ``root`` argument already bound, so
    tests only pass the dataset name and any constructor arguments, e.g.
    ``get_dataset(name='Cora')``. The root is the in-memory location
    ``memory://pyg_test_datasets``, except on Windows, where an on-disk
    ``tmp/pyg_test_datasets`` path is used instead. At the end of the
    session the root is removed via :func:`fs.rm` if it exists.

    Because this is a generator fixture, it is annotated as an iterator of
    the loader rather than as the loader itself.
    """
    # TODO Support memory filesystem on Windows.
    if torch_geometric.typing.WITH_WINDOWS:
        root = osp.join('/', 'tmp', 'pyg_test_datasets')
    else:
        root = 'memory://pyg_test_datasets'

    yield functools.partial(load_dataset, root)

    # Teardown: drop everything that was downloaded/processed this session.
    if fs.exists(root):
        fs.rm(root)


@pytest.fixture
def enable_extensions():
    r"""Counterpart of :func:`disable_extensions` that changes nothing.

    The ``WITH_*`` flags in :mod:`torch_geometric.typing` keep whatever
    values the current environment gave them.
    """
    # No setup and no teardown are required.
    yield None


@pytest.fixture
def disable_extensions():
    r"""Temporarily sets the ``WITH_*`` flags of
    :mod:`torch_geometric.typing` to :obj:`False` for the duration of a test.

    Flags starting with ``WITH_PT`` (presumably PyTorch version checks) and
    ``WITH_WINDOWS`` (the platform flag) are left untouched. Every flag that
    was changed gets its original value back once the test finishes.
    """
    module = torch_geometric.typing

    # Snapshot the current value of every flag we are going to override:
    saved = {
        attr: getattr(module, attr)
        for attr in dir(module)
        if attr.startswith('WITH_')
        and not attr.startswith(('WITH_PT', 'WITH_WINDOWS'))
    }

    for attr in saved:
        setattr(module, attr, False)

    yield

    # Restore the snapshot so later tests see the real environment again:
    for attr, original in saved.items():
        setattr(module, attr, original)


@pytest.fixture
def without_extensions(request):
    r"""Parametrized switch between :func:`enable_extensions` and
    :func:`disable_extensions`.

    ``request.param`` is the name of the fixture to activate (presumably
    supplied via ``pytest.mark.parametrize(..., indirect=True)``).

    Returns:
        bool: :obj:`True` if the ``'disable_extensions'`` fixture was
        activated, :obj:`False` otherwise.
    """
    fixture_name = request.param
    # Activate the chosen fixture (including its teardown) for this test.
    request.getfixturevalue(fixture_name)
    return fixture_name == 'disable_extensions'


@pytest.fixture(scope='function')
def spawn_context():
    r"""Runs the requesting test with the ``'spawn'`` start method of
    :mod:`torch.multiprocessing`.

    The start method that was configured before the test is restored
    afterwards. That value is :obj:`None` if no method had been set
    explicitly. Restoring it keeps the global multiprocessing context from
    leaking into unrelated tests and making results depend on test order.
    """
    previous = torch.multiprocessing.get_start_method(allow_none=True)
    torch.multiprocessing.set_start_method('spawn', force=True)
    logging.info("Setting torch.multiprocessing context to 'spawn'")
    yield
    # `force=True` is needed to overwrite an already-fixed context; passing
    # `None` resets it so the platform default is chosen lazily again.
    torch.multiprocessing.set_start_method(previous, force=True)