File: typing.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (384 lines) | stat: -rw-r--r-- 13,865 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
import inspect
import os
import sys
import typing
import warnings
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import numpy as np
import torch
from torch import Tensor

# Parse the installed PyTorch `(major, minor)` version once so that every
# feature flag below is a plain tuple comparison. Comparing `(major, minor)`
# tuples (instead of checking the minor version in isolation) keeps the flags
# correct across major releases, e.g., `WITH_PT21` stays `True` on PyTorch 3.0.
_TORCH_VERSION = tuple(int(v) for v in torch.__version__.split('.')[:2])

WITH_PT20 = _TORCH_VERSION >= (2, 0)
WITH_PT21 = _TORCH_VERSION >= (2, 1)
WITH_PT22 = _TORCH_VERSION >= (2, 2)
WITH_PT23 = _TORCH_VERSION >= (2, 3)
WITH_PT24 = _TORCH_VERSION >= (2, 4)
WITH_PT25 = _TORCH_VERSION >= (2, 5)
WITH_PT111 = _TORCH_VERSION >= (1, 11)
WITH_PT112 = _TORCH_VERSION >= (1, 12)
WITH_PT113 = _TORCH_VERSION >= (1, 13)

WITH_WINDOWS = os.name == 'nt'
# MKL is treated as unavailable if PyTorch was built with `USE_MKL=OFF` or if
# we are running on Windows:
NO_MKL = 'USE_MKL=OFF' in torch.__config__.show() or WITH_WINDOWS

MAX_INT64 = torch.iinfo(torch.int64).max

# Set of integer dtypes for index tensors (consumers live outside this
# snippet). `torch.int32` is only included on PyTorch >= 2.0.
if WITH_PT20:
    INDEX_DTYPES: Set[torch.dtype] = {
        torch.int32,
        torch.int64,
    }
elif not typing.TYPE_CHECKING:  # pragma: no cover
    # Static type checkers treat `TYPE_CHECKING` as `True`, so they only see
    # the annotated definition above (presumably to avoid a redefinition
    # complaint); at runtime this branch applies to PyTorch < 2.0.
    INDEX_DTYPES: Set[torch.dtype] = {
        torch.int64,
    }

# Older PyTorch versions lack the `torch.sparse_csc` layout. Alias it to
# `torch.sparse_coo` so the attribute exists; presumably this lets code
# elsewhere reference `torch.sparse_csc` unconditionally.
if not hasattr(torch, 'sparse_csc'):
    torch.sparse_csc = torch.sparse_coo

# Optional dependency `pyg-lib`: import it and probe which operators and
# sampler features the installed version provides. If anything in the `try`
# block raises, every flag below falls back to `False` and `pyg_lib` is
# replaced by the `object` placeholder.
try:
    import pyg_lib  # noqa
    WITH_PYG_LIB = True
    # Grouped matmul is additionally gated on PyTorch >= 2.0:
    WITH_GMM = WITH_PT20 and hasattr(pyg_lib.ops, 'grouped_matmul')
    WITH_SEGMM = hasattr(pyg_lib.ops, 'segment_matmul')
    if WITH_SEGMM and 'pytest' in sys.modules and torch.cuda.is_available():
        # NOTE `segment_matmul` is currently bugged on older NVIDIA cards, which
        # makes our GPU tests on CI crash. Check whether this error occurs on
        # the current GPU and disable `WITH_SEGMM`/`WITH_GMM` if necessary.
        # This probe only runs when `pytest` has been imported and CUDA is
        # available.
        # TODO Drop this code block once `segment_matmul` is fixed.
        # NOTE(review): the probe tensors `x`, `ptr`, `weight` (and `out` on
        # success) stay bound as module-level names after this block;
        # consider moving the probe into a helper function.
        try:
            x = torch.randn(3, 4, device='cuda')
            ptr = torch.tensor([0, 2, 3], device='cuda')
            weight = torch.randn(2, 4, 4, device='cuda')
            out = pyg_lib.ops.segment_matmul(x, ptr, weight)
        except RuntimeError:
            WITH_GMM = False
            WITH_SEGMM = False
    # Feature detection by attribute presence:
    WITH_SAMPLED_OP = hasattr(pyg_lib.ops, 'sampled_add')
    WITH_SOFTMAX = hasattr(pyg_lib.ops, 'softmax_csr')
    WITH_INDEX_SORT = hasattr(pyg_lib.ops, 'index_sort')
    WITH_METIS = hasattr(pyg_lib, 'partition')
    # Feature detection by inspecting the sampler's keyword parameters:
    WITH_EDGE_TIME_NEIGHBOR_SAMPLE = ('edge_time' in inspect.signature(
        pyg_lib.sampler.neighbor_sample).parameters)
    WITH_WEIGHTED_NEIGHBOR_SAMPLE = ('edge_weight' in inspect.signature(
        pyg_lib.sampler.neighbor_sample).parameters)
except Exception as e:
    # A plain `ImportError` (package not installed) is silent; any other
    # failure is surfaced as a warning before disabling `pyg-lib`.
    if not isinstance(e, ImportError):  # pragma: no cover
        warnings.warn(f"An issue occurred while importing 'pyg-lib'. "
                      f"Disabling its usage. Stacktrace: {e}")
    pyg_lib = object
    WITH_PYG_LIB = False
    WITH_GMM = False
    WITH_SEGMM = False
    WITH_SAMPLED_OP = False
    WITH_SOFTMAX = False
    WITH_INDEX_SORT = False
    WITH_METIS = False
    WITH_EDGE_TIME_NEIGHBOR_SAMPLE = False
    WITH_WEIGHTED_NEIGHBOR_SAMPLE = False

# Optional dependency `torch-scatter`. When it is unavailable, `torch_scatter`
# is bound to the `object` placeholder and `WITH_TORCH_SCATTER` is `False`.
try:
    import torch_scatter  # noqa
except ImportError:
    # Simply not installed: disable it silently.
    torch_scatter = object
    WITH_TORCH_SCATTER = False
except Exception as err:  # pragma: no cover
    # Installed, but importing it failed for another reason: warn first.
    warnings.warn(f"An issue occurred while importing 'torch-scatter'. "
                  f"Disabling its usage. Stacktrace: {err}")
    torch_scatter = object
    WITH_TORCH_SCATTER = False
else:
    WITH_TORCH_SCATTER = True

# Optional dependency `torch-cluster`. When it is unavailable, `torch_cluster`
# is replaced by a placeholder whose attribute access raises `ImportError`.
try:
    import torch_cluster  # noqa
    WITH_TORCH_CLUSTER = True
    # `batch_size` support is detected from the `knn` docstring. `__doc__` is
    # `None` when docstrings are stripped (e.g., `python -OO`). Treat that as
    # "not supported" rather than raising a `TypeError`, which the handler
    # below would turn into disabling `torch-cluster` altogether.
    WITH_TORCH_CLUSTER_BATCH_SIZE = 'batch_size' in (
        torch_cluster.knn.__doc__ or '')
except Exception as e:
    # A plain `ImportError` (package not installed) is silent; any other
    # failure is surfaced as a warning before disabling `torch-cluster`.
    if not isinstance(e, ImportError):  # pragma: no cover
        warnings.warn(f"An issue occurred while importing 'torch-cluster'. "
                      f"Disabling its usage. Stacktrace: {e}")
    WITH_TORCH_CLUSTER = False
    WITH_TORCH_CLUSTER_BATCH_SIZE = False

    class TorchCluster:
        r"""Placeholder for the `torch_cluster` module. Accessing any attribute
        raises an `ImportError` naming the requested attribute."""
        def __getattr__(self, key: str) -> Any:
            raise ImportError(f"'{key}' requires 'torch-cluster'")

    torch_cluster = TorchCluster()

# Optional dependency `torch-spline-conv`; only the availability flag is
# exported (no placeholder module is defined).
try:
    import torch_spline_conv  # noqa
except ImportError:
    # Simply not installed: disable it silently.
    WITH_TORCH_SPLINE_CONV = False
except Exception as err:  # pragma: no cover
    # Installed, but importing it failed for another reason: warn first.
    warnings.warn(
        f"An issue occurred while importing 'torch-spline-conv'. "
        f"Disabling its usage. Stacktrace: {err}")
    WITH_TORCH_SPLINE_CONV = False
else:
    WITH_TORCH_SPLINE_CONV = True

# Optional dependency `torch-sparse`. When it is unavailable, the names
# `SparseStorage`, `SparseTensor` and `torch_sparse` are bound to local
# placeholders whose constructors, methods and functions raise `ImportError`.
try:
    import torch_sparse  # noqa
    from torch_sparse import SparseStorage, SparseTensor
    WITH_TORCH_SPARSE = True
except Exception as e:
    # A plain `ImportError` (package not installed) is silent; any other
    # failure is surfaced as a warning before disabling `torch-sparse`.
    if not isinstance(e, ImportError):  # pragma: no cover
        warnings.warn(f"An issue occurred while importing 'torch-sparse'. "
                      f"Disabling its usage. Stacktrace: {e}")
    WITH_TORCH_SPARSE = False

    # NOTE(review): the placeholder signatures below presumably mirror the
    # subset of the real `torch-sparse` API used by this package; keep them
    # in sync when updating.

    class SparseStorage:  # type: ignore
        r"""Placeholder for :class:`torch_sparse.SparseStorage`. Every
        constructor call and method raises an `ImportError`."""
        def __init__(
            self,
            row: Optional[Tensor] = None,
            rowptr: Optional[Tensor] = None,
            col: Optional[Tensor] = None,
            value: Optional[Tensor] = None,
            sparse_sizes: Optional[Tuple[Optional[int], Optional[int]]] = None,
            rowcount: Optional[Tensor] = None,
            colptr: Optional[Tensor] = None,
            colcount: Optional[Tensor] = None,
            csr2csc: Optional[Tensor] = None,
            csc2csr: Optional[Tensor] = None,
            is_sorted: bool = False,
            trust_data: bool = False,
        ):
            raise ImportError("'SparseStorage' requires 'torch-sparse'")

        def value(self) -> Optional[Tensor]:
            raise ImportError("'SparseStorage' requires 'torch-sparse'")

        def rowcount(self) -> Tensor:
            raise ImportError("'SparseStorage' requires 'torch-sparse'")

    class SparseTensor:  # type: ignore
        r"""Placeholder for :class:`torch_sparse.SparseTensor`. Every
        constructor call, alternate constructor, property and method raises
        an `ImportError`."""
        def __init__(
            self,
            row: Optional[Tensor] = None,
            rowptr: Optional[Tensor] = None,
            col: Optional[Tensor] = None,
            value: Optional[Tensor] = None,
            sparse_sizes: Optional[Tuple[Optional[int], Optional[int]]] = None,
            is_sorted: bool = False,
            trust_data: bool = False,
        ):
            raise ImportError("'SparseTensor' requires 'torch-sparse'")

        @classmethod
        def from_edge_index(
            cls,
            edge_index: Tensor,
            edge_attr: Optional[Tensor] = None,
            sparse_sizes: Optional[Tuple[Optional[int], Optional[int]]] = None,
            is_sorted: bool = False,
            trust_data: bool = False,
        ) -> 'SparseTensor':
            raise ImportError("'SparseTensor' requires 'torch-sparse'")

        @property
        def storage(self) -> SparseStorage:
            raise ImportError("'SparseTensor' requires 'torch-sparse'")

        @classmethod
        def from_dense(cls, mat: Tensor,
                       has_value: bool = True) -> 'SparseTensor':
            raise ImportError("'SparseTensor' requires 'torch-sparse'")

        def size(self, dim: int) -> int:
            raise ImportError("'SparseTensor' requires 'torch-sparse'")

        def nnz(self) -> int:
            raise ImportError("'SparseTensor' requires 'torch-sparse'")

        def is_cuda(self) -> bool:
            raise ImportError("'SparseTensor' requires 'torch-sparse'")

        def has_value(self) -> bool:
            raise ImportError("'SparseTensor' requires 'torch-sparse'")

        def set_value(self, value: Optional[Tensor],
                      layout: Optional[str] = None) -> 'SparseTensor':
            raise ImportError("'SparseTensor' requires 'torch-sparse'")

        def fill_value(self, fill_value: float,
                       dtype: Optional[torch.dtype] = None) -> 'SparseTensor':
            raise ImportError("'SparseTensor' requires 'torch-sparse'")

        def coo(self) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
            raise ImportError("'SparseTensor' requires 'torch-sparse'")

        def csr(self) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
            raise ImportError("'SparseTensor' requires 'torch-sparse'")

        def requires_grad(self) -> bool:
            raise ImportError("'SparseTensor' requires 'torch-sparse'")

        def to_torch_sparse_csr_tensor(
            self,
            dtype: Optional[torch.dtype] = None,
        ) -> Tensor:
            raise ImportError("'SparseTensor' requires 'torch-sparse'")

    class torch_sparse:  # type: ignore
        r"""Placeholder for the `torch_sparse` module's free functions. Each
        function raises an `ImportError` naming itself."""
        @staticmethod
        def matmul(src: SparseTensor, other: Tensor,
                   reduce: str = "sum") -> Tensor:
            raise ImportError("'matmul' requires 'torch-sparse'")

        @staticmethod
        def sum(src: SparseTensor, dim: Optional[int] = None) -> Tensor:
            raise ImportError("'sum' requires 'torch-sparse'")

        @staticmethod
        def mul(src: SparseTensor, other: Tensor) -> SparseTensor:
            raise ImportError("'mul' requires 'torch-sparse'")

        @staticmethod
        def set_diag(src: SparseTensor, values: Optional[Tensor] = None,
                     k: int = 0) -> SparseTensor:
            raise ImportError("'set_diag' requires 'torch-sparse'")

        @staticmethod
        def fill_diag(src: SparseTensor, fill_value: float,
                      k: int = 0) -> SparseTensor:
            raise ImportError("'fill_diag' requires 'torch-sparse'")

        @staticmethod
        def masked_select_nnz(src: SparseTensor, mask: Tensor,
                              layout: Optional[str] = None) -> SparseTensor:
            raise ImportError("'masked_select_nnz' requires 'torch-sparse'")


# Optional dependency `pytorch-frame`. When it is unavailable, `torch_frame`
# is bound to the `object` placeholder and `TensorFrame` to an empty class.
try:
    import torch_frame  # noqa
    WITH_TORCH_FRAME = True
    from torch_frame import TensorFrame
except Exception as e:
    # Consistent with the other optional dependencies: a plain `ImportError`
    # (package not installed) is silent, but any other failure is surfaced
    # as a warning instead of being swallowed.
    if not isinstance(e, ImportError):  # pragma: no cover
        warnings.warn(f"An issue occurred while importing 'pytorch-frame'. "
                      f"Disabling its usage. Stacktrace: {e}")
    torch_frame = object
    WITH_TORCH_FRAME = False

    class TensorFrame:  # type: ignore
        r"""Empty placeholder for :class:`torch_frame.TensorFrame`."""
        pass


# Optional dependency Intel Extension for PyTorch. Any failure to import it
# (missing or broken) silently sets `WITH_IPEX = False`.
try:
    import intel_extension_for_pytorch  # noqa
except Exception:
    WITH_IPEX = False
else:
    WITH_IPEX = True


class MockTorchCSCTensor:
    r"""Lightweight stand-in for a CSC sparse tensor that only supports
    transposition via :meth:`t`.

    The constructor arguments are stored unchanged as attributes; no sparse
    tensor is built until :meth:`t` is called.

    Args:
        edge_index (torch.Tensor): The edge indices (presumably of shape
            :obj:`[2, num_edges]` — rows are swapped in :meth:`t`).
        edge_attr (torch.Tensor, optional): Edge values. (default: :obj:`None`)
        size (int or (int, int), optional): The matrix size. (default:
            :obj:`None`)
    """
    def __init__(
        self,
        edge_index: Tensor,
        edge_attr: Optional[Tensor] = None,
        size: Optional[Union[int, Tuple[int, int]]] = None,
    ):
        self.edge_index, self.edge_attr, self.size = edge_index, edge_attr, size

    def t(self) -> Tensor:  # Only support accessing its transpose:
        r"""Returns the transpose, built by
        :meth:`torch_geometric.utils.to_torch_csr_tensor` from the flipped
        :obj:`edge_index` and the reversed :obj:`size` (when given as a
        tuple or list)."""
        from torch_geometric.utils import to_torch_csr_tensor

        stored_size = self.size
        if isinstance(stored_size, (tuple, list)):
            transposed_size = stored_size[::-1]
        else:  # An `int` or `None` is symmetric under transposition:
            transposed_size = stored_size

        flipped_index = self.edge_index.flip([0])
        return to_torch_csr_tensor(flipped_index, self.edge_attr,
                                   transposed_size)


# Types for accessing data ####################################################

# A node type is a single string, e.g., `data['paper']`:
NodeType = str

# An edge type is a `(source, relation, destination)` triplet of strings,
# e.g., `data[('author', 'writes', 'paper')]`:
EdgeType = Tuple[str, str, str]

# Either a node type or an edge type:
NodeOrEdgeType = Union[NodeType, EdgeType]

# Relation name inserted when an edge type is given as a `(src, dst)` pair:
DEFAULT_REL = 'to'
# Separator used to join the parts of an edge type into a single string:
EDGE_TYPE_STR_SPLIT = '__'


class EdgeTypeStr(str):
    r"""A helper class to construct serializable edge types by merging an edge
    type tuple into a single string.

    Supported inputs, passed either as separate arguments or wrapped in a
    single list/tuple:

    * an edge type string, which is used as-is,
    * a :obj:`(src, dst)` pair, which receives :obj:`DEFAULT_REL` as its
      relation,
    * a :obj:`(src, rel, dst)` triplet.

    The parts are joined by :obj:`EDGE_TYPE_STR_SPLIT`.

    Raises:
        ValueError: If the arguments match none of the forms above (including
            being called without arguments).
    """
    def __new__(cls, *args: Any) -> 'EdgeTypeStr':
        if len(args) == 1 and isinstance(args[0], (list, tuple)):
            # Unwrap `EdgeTypeStr((src, rel, dst))` and
            # `EdgeTypeStr((src, dst))`. Only a single wrapped argument is
            # unwrapped, so extra positional arguments are rejected below
            # instead of being silently dropped. Calling without arguments
            # also reaches the `ValueError` instead of raising `IndexError`.
            args = tuple(args[0])

        if len(args) == 1 and isinstance(args[0], str):
            arg = args[0]  # An edge type string was passed.

        elif len(args) == 2 and all(isinstance(arg, str) for arg in args):
            # A `(src, dst)` edge type was passed - add `DEFAULT_REL`:
            arg = EDGE_TYPE_STR_SPLIT.join((args[0], DEFAULT_REL, args[1]))

        elif len(args) == 3 and all(isinstance(arg, str) for arg in args):
            # A `(src, rel, dst)` edge type was passed:
            arg = EDGE_TYPE_STR_SPLIT.join(args)

        else:
            raise ValueError(f"Encountered invalid edge type '{args}'")

        return str.__new__(cls, arg)

    def to_tuple(self) -> EdgeType:
        r"""Returns the original edge type.

        Raises:
            ValueError: If splitting by :obj:`EDGE_TYPE_STR_SPLIT` does not
                yield exactly three parts, e.g., because a type name itself
                contains the separator.
        """
        out = tuple(self.split(EDGE_TYPE_STR_SPLIT))
        if len(out) != 3:
            raise ValueError(f"Cannot convert the edge type '{self}' to a "
                             f"tuple since it holds invalid characters")
        return out


# Shortcuts for querying edge types. Besides the full triplet, an edge type
# may be looked up by a partial key, provided the full triplet can be
# uniquely reconstructed from it, e.g.:
# * via str: `data['writes']`
# * via Tuple[str, str]: `data[('author', 'paper')]`
QueryType = Union[NodeType, EdgeType, str, Tuple[str, str]]

# The node types and edge types of a graph:
Metadata = Tuple[List[NodeType], List[EdgeType]]

# A feature tensor, given either as a PyTorch tensor or a NumPy array:
FeatureTensorType = Union[Tensor, np.ndarray]

# An edge index given as a pair of tensors, in one of the formats:
#   * COO: (row, col)
#   * CSC: (row, colptr)
#   * CSR: (rowptr, col)
EdgeTensorType = Tuple[Tensor, Tensor]

# Types for message passing ###################################################

# Adjacency: either a plain `torch.Tensor` or a `SparseTensor`:
Adj = Union[Tensor, SparseTensor]
OptTensor = Optional[Tensor]
PairTensor = Tuple[Tensor, Tensor]
OptPairTensor = Tuple[Tensor, OptTensor]
PairOptTensor = Tuple[OptTensor, OptTensor]
Size = Optional[Tuple[int, int]]
# NOTE: Despite its name, `NoneType` is an alias for an optional tensor.
NoneType = OptTensor

# Either a single value or a dictionary keyed by node/edge type:
MaybeHeteroNodeTensor = Union[Tensor, Dict[NodeType, Tensor]]
MaybeHeteroAdjTensor = Union[Tensor, Dict[EdgeType, Adj]]
MaybeHeteroEdgeTensor = Union[Tensor, Dict[EdgeType, Tensor]]

# Types for sampling ##########################################################

# Each input is an optional tensor, a bare type, or a `(type, tensor)` pair:
InputNodes = Union[OptTensor, NodeType, Tuple[NodeType, OptTensor]]
InputEdges = Union[OptTensor, EdgeType, Tuple[EdgeType, OptTensor]]

# Serialization ###############################################################

# On PyTorch >= 2.4, allowlist the classes below via
# `torch.serialization.add_safe_globals`, so that `torch.load(...,
# weights_only=True)` accepts them. If `torch-sparse` or `pytorch-frame` is
# unavailable, the local placeholder classes defined above are registered
# under those names instead.
if WITH_PT24:
    torch.serialization.add_safe_globals([
        SparseTensor,
        SparseStorage,
        TensorFrame,
        MockTorchCSCTensor,
        EdgeTypeStr,
    ])