1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113
|
import importlib
import os.path as osp
import torch
# Package version string; it is also listed in ``__all__``.
__version__ = '0.6.18'

# Load every compiled extension library that sits next to this file.
#
# Each library is searched for only in the directory containing this module.
# When both a ``<name>_cuda`` and a ``<name>_cpu`` build are found, the CUDA
# one is loaded. If neither is found, importing the package fails with an
# ``ImportError``.
_package_dir = osp.dirname(__file__)
_find_spec = importlib.machinery.PathFinder.find_spec
_library_names = (
    '_version',
    '_convert',
    '_diag',
    '_spmm',
    '_metis',
    '_rw',
    '_saint',
    '_sample',
    '_ego_sample',
    '_hgt_sample',
    '_neighbor_sample',
    '_relabel',
)

for library in _library_names:
    # Both variants are always probed, CUDA first.
    cuda_spec = _find_spec(f'{library}_cuda', [_package_dir])
    cpu_spec = _find_spec(f'{library}_cpu', [_package_dir])
    spec = cpu_spec if cuda_spec is None else cuda_spec

    if spec is None:  # pragma: no cover
        raise ImportError(f"Could not find module '{library}_cpu' in "
                          f"{osp.dirname(__file__)}")

    # ``spec.origin`` is the location of the library file that was found.
    torch.ops.load_library(spec.origin)
# Refuse to continue when the loaded extension and PyTorch were built
# against different CUDA *major* versions. A minor-version mismatch is
# tolerated; the minor numbers are only used in the error message.
cuda_version = torch.ops.torch_sparse.cuda_version()

# The check is skipped when PyTorch reports no CUDA version or when the
# extension returns -1 (presumably its marker for a build without CUDA --
# confirm against the extension source).
if torch.version.cuda is not None and cuda_version != -1:  # pragma: no cover
    # The integer is decoded as ``1000 * major + 10 * minor`` (e.g. 9020 is
    # 9.2 and 11080 is 11.8). Arithmetic is used rather than slicing the
    # decimal string at fixed positions, so the result stays correct when
    # the minor component is not a single digit.
    major, minor = cuda_version // 1000, (cuda_version % 1000) // 10

    # Only the first two components are read, so a version string carrying
    # an extra patch component does not break the unpacking.
    t_major, t_minor = (int(x) for x in torch.version.cuda.split('.')[:2])

    if t_major != major:
        raise RuntimeError(
            f'Detected that PyTorch and torch_sparse were compiled with '
            f'different CUDA versions. PyTorch has CUDA version '
            f'{t_major}.{t_minor} and torch_sparse has CUDA version '
            f'{major}.{minor}. Please reinstall the torch_sparse that '
            f'matches your PyTorch install.')
from .storage import SparseStorage # noqa
from .tensor import SparseTensor # noqa
from .transpose import t # noqa
from .narrow import narrow, __narrow_diag__ # noqa
from .select import select # noqa
from .index_select import index_select, index_select_nnz # noqa
from .masked_select import masked_select, masked_select_nnz # noqa
from .permute import permute # noqa
from .diag import remove_diag, set_diag, fill_diag, get_diag # noqa
from .add import add, add_, add_nnz, add_nnz_ # noqa
from .mul import mul, mul_, mul_nnz, mul_nnz_ # noqa
from .reduce import sum, mean, min, max # noqa
from .matmul import matmul # noqa
from .cat import cat # noqa
from .rw import random_walk # noqa
from .metis import partition # noqa
from .bandwidth import reverse_cuthill_mckee # noqa
from .saint import saint_subgraph # noqa
from .sample import sample, sample_adj # noqa
from .convert import to_torch_sparse, from_torch_sparse # noqa
from .convert import to_scipy, from_scipy # noqa
from .coalesce import coalesce # noqa
from .transpose import transpose # noqa
from .eye import eye # noqa
from .spmm import spmm # noqa
from .spspmm import spspmm # noqa
from .spadd import spadd # noqa
# Public names exported by a star-import of this package. The order matches
# the import statements above; entries are grouped by the submodule they are
# imported from.
# NOTE(review): ``sample`` and ``sample_adj`` are imported from ``.sample``
# above but are not listed here -- confirm whether that is intentional.
__all__ = [
    # .storage / .tensor
    'SparseStorage', 'SparseTensor',
    # .transpose (t) / .narrow / .select
    't', 'narrow', '__narrow_diag__', 'select',
    # .index_select / .masked_select / .permute
    'index_select', 'index_select_nnz',
    'masked_select', 'masked_select_nnz',
    'permute',
    # .diag
    'remove_diag', 'set_diag', 'fill_diag', 'get_diag',
    # .add / .mul
    'add', 'add_', 'add_nnz', 'add_nnz_',
    'mul', 'mul_', 'mul_nnz', 'mul_nnz_',
    # .reduce (these shadow the builtins of the same name on star-import)
    'sum', 'mean', 'min', 'max',
    # .matmul / .cat
    'matmul', 'cat',
    # .rw / .metis / .bandwidth / .saint
    'random_walk', 'partition', 'reverse_cuthill_mckee', 'saint_subgraph',
    # .convert
    'to_torch_sparse', 'from_torch_sparse', 'to_scipy', 'from_scipy',
    # .coalesce / .transpose / .eye / .spmm / .spspmm / .spadd
    'coalesce', 'transpose', 'eye', 'spmm', 'spspmm', 'spadd',
    # package metadata
    '__version__',
]
|