File: utils.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (415 lines) | stat: -rw-r--r-- 14,901 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
import copy
import logging
import math
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
import torch
from torch import Tensor

import torch_geometric.typing
from torch_geometric.data import (
    Data,
    FeatureStore,
    GraphStore,
    HeteroData,
    TensorAttr,
    remote_backend_utils,
)
from torch_geometric.data.storage import EdgeStorage, NodeStorage
from torch_geometric.typing import (
    EdgeType,
    FeatureTensorType,
    InputEdges,
    InputNodes,
    NodeType,
    OptTensor,
    SparseTensor,
    TensorFrame,
)


def index_select(
    value: FeatureTensorType,
    index: Tensor,
    dim: int = 0,
) -> Tensor:
    r"""Selects the entries of :obj:`value` at the positions in :obj:`index`
    along dimension :obj:`dim`.

    Args:
        value (torch.Tensor or TensorFrame or np.ndarray): The input to index.
        index (torch.Tensor): The 1-D tensor containing the indices to select.
        dim (int, optional): The dimension in which to index. Only :obj:`0`
            is supported for :class:`TensorFrame` inputs.
            (default: :obj:`0`)

    Returns:
        The selected entries. :class:`numpy.ndarray` inputs are converted to
        a :class:`torch.Tensor` via :meth:`torch.from_numpy`.

    Raises:
        ValueError: If :obj:`value` is of an unsupported type.

    .. warning::

        :obj:`index` is casted to a :obj:`torch.int64` tensor internally, as
        `PyTorch currently only supports indexing
        <https://github.com/pytorch/pytorch/issues/61819>`_ via
        :obj:`torch.int64`.
    """
    # PyTorch currently only supports indexing via `torch.int64`:
    # https://github.com/pytorch/pytorch/issues/61819
    index = index.to(torch.int64)

    if isinstance(value, Tensor):
        result: Optional[Tensor] = None
        if torch.utils.data.get_worker_info() is not None:
            # Running inside a data loader worker process: write the result
            # directly into a shared memory tensor to avoid an extra copy.
            out_shape = list(value.shape)
            out_shape[dim] = index.numel()
            total = math.prod(out_shape)
            if not torch_geometric.typing.WITH_PT20:
                shared = value.storage()._new_shared(total)
            else:
                # Untyped storages are sized in bytes, not elements:
                shared = value.untyped_storage()._new_shared(
                    total * value.element_size())
            result = value.new(shared).view(out_shape)
        return torch.index_select(value, dim, index, out=result)

    if isinstance(value, TensorFrame):
        assert dim == 0
        return value[index]

    if isinstance(value, np.ndarray):
        return torch.from_numpy(np.take(value, index, axis=dim))

    raise ValueError(f"Encountered invalid feature tensor type "
                     f"(got '{type(value)}')")


def filter_node_store_(store: NodeStorage, out_store: NodeStorage,
                       index: Tensor):
    r"""Writes into :obj:`out_store` the node-level attributes of
    :obj:`store`, restricted to the nodes in :obj:`index`.

    ``num_nodes`` is set to the number of selected nodes. Keys for which
    :meth:`is_node_attr` is false are not written.
    """
    for key, value in store.items():
        if key == 'num_nodes':
            out_store.num_nodes = index.numel()
            continue

        if not store.is_node_attr(key):
            continue

        # Move the index next to the data it selects from (the rebound index
        # is reused for subsequent keys):
        if isinstance(value, (Tensor, TensorFrame)):
            index = index.to(value.device)
        elif isinstance(value, np.ndarray):
            index = index.cpu()

        cat_dim = store._parent().__cat_dim__(key, value, store)
        out_store[key] = index_select(value, index, dim=cat_dim)


def filter_edge_store_(store: EdgeStorage, out_store: EdgeStorage, row: Tensor,
                       col: Tensor, index: OptTensor, perm: OptTensor = None):
    r"""Filters an edge storage object to only hold the edges in
    :obj:`index`, writing the result into :obj:`out_store`.

    The new graph connectivity is given by :obj:`(row, col)`:

    * ``edge_index`` becomes ``stack([row, col])`` on the device of the
      original ``edge_index``.
    * ``adj_t`` is rebuilt as a :class:`SparseTensor` from the transposed
      pair ``(col, row)``, using the reversed ``out_store.size()`` as its
      sparse sizes. Its values are selected at :obj:`index` (without
      applying :obj:`perm`).
    * Every other edge attribute is selected at :obj:`index`, or at
      ``perm[index]`` if :obj:`perm` is given.

    If :obj:`index` is :obj:`None`, edge attributes (and ``adj_t`` values)
    are set to :obj:`None`.

    Args:
        store: The edge storage to read from.
        out_store: The edge storage to write the filtered values into.
        row: The first row of the new ``edge_index``.
        col: The second row of the new ``edge_index``.
        index: Positions of the kept edges, or :obj:`None`.
        perm: Optional permutation applied to :obj:`index` before selecting
            edge attributes. NOTE(review): presumably maps positions in the
            sampler's sorted edge order back to the original edge order —
            confirm against the caller.
    """
    for key, value in store.items():
        if key == 'edge_index':
            edge_index = torch.stack([row, col], dim=0).to(value.device)
            # TODO Integrate `EdgeIndex` into `custom_store`.
            # edge_index = EdgeIndex(
            #     torch.stack([row, col], dim=0).to(value.device),
            #     sparse_size=out_store.size(),
            #     sort_order='col',
            #     # TODO Support `is_undirected`.
            # )
            out_store.edge_index = edge_index

        elif key == 'adj_t':
            # NOTE: We expect `(row, col)` to be sorted by `col` (CSC layout).
            # `row`/`col` are rebound here, so keys visited after `adj_t` see
            # the moved tensors. `value` is presumably a `SparseTensor`
            # (`device()` is called as a method) — confirm.
            row = row.to(value.device())
            col = col.to(value.device())
            edge_attr = value.storage.value()
            if edge_attr is not None:
                if index is not None:
                    # NOTE(review): `perm` is intentionally not applied here,
                    # presumably because the stored values are already in
                    # sorted order — confirm.
                    index = index.to(edge_attr.device)
                    edge_attr = index_select(edge_attr, index, dim=0)
                else:
                    edge_attr = None
            # `adj_t` is the transpose, hence the reversed sizes and the
            # swapped `row=col, col=row` below:
            sparse_sizes = out_store.size()[::-1]
            # TODO Currently, we set `is_sorted=False`, see:
            # https://github.com/pyg-team/pytorch_geometric/issues/4346
            out_store.adj_t = SparseTensor(row=col, col=row, value=edge_attr,
                                           sparse_sizes=sparse_sizes,
                                           is_sorted=False, trust_data=True)

        elif store.is_edge_attr(key):
            # Without an edge index, edge attributes cannot be filtered and
            # are dropped:
            if index is None:
                out_store[key] = None
                continue

            dim = store._parent().__cat_dim__(key, value, store)
            # Move `index` (and `perm` below) next to the data it selects
            # from; the rebinding persists for later keys:
            if isinstance(value, (Tensor, TensorFrame)):
                index = index.to(value.device)
            elif isinstance(value, np.ndarray):
                index = index.cpu()
            if perm is None:
                out_store[key] = index_select(value, index, dim=dim)
            else:
                if isinstance(value, (Tensor, TensorFrame)):
                    perm = perm.to(value.device)
                elif isinstance(value, np.ndarray):
                    perm = perm.cpu()
                out_store[key] = index_select(
                    value,
                    perm[index.to(torch.int64)],
                    dim=dim,
                )


def filter_data(data: Data, node: Tensor, row: Tensor, col: Tensor,
                edge: OptTensor, perm: OptTensor = None) -> Data:
    r"""Returns a copy (via :func:`copy.copy`) of :obj:`data` that only holds
    the nodes in :obj:`node` and the edges in :obj:`edge`, with connectivity
    given by :obj:`(row, col)`.

    :obj:`perm` is forwarded unchanged to the edge filtering step.
    """
    filtered = copy.copy(data)
    src_store, dst_store = data._store, filtered._store
    filter_node_store_(src_store, dst_store, node)
    filter_edge_store_(src_store, dst_store, row, col, edge, perm)
    return filtered


def filter_hetero_data(
    data: HeteroData,
    node_dict: Dict[NodeType, Tensor],
    row_dict: Dict[EdgeType, Tensor],
    col_dict: Dict[EdgeType, Tensor],
    edge_dict: Dict[EdgeType, OptTensor],
    perm_dict: Optional[Dict[EdgeType, OptTensor]] = None,
) -> HeteroData:
    r"""Returns a copy of :obj:`data` that only holds the nodes in
    :obj:`node_dict` (per node type) and the edges in :obj:`edge_dict` (per
    edge type), with connectivity given by :obj:`row_dict` and
    :obj:`col_dict`.

    .. note::

        Types of :obj:`data` that are missing from the input dictionaries
        (*e.g.*, in disconnected graph sampling) are filtered with an empty
        :obj:`torch.long` index. That empty index is also inserted into the
        caller's :obj:`node_dict`, :obj:`row_dict`, :obj:`col_dict` and
        :obj:`edge_dict` in place.
    """
    filtered = copy.copy(data)

    for node_type in filtered.node_types:
        # Handle the case of disconnected graph sampling:
        if node_type not in node_dict:
            node_dict[node_type] = torch.empty(0, dtype=torch.long)
        filter_node_store_(data[node_type], filtered[node_type],
                           node_dict[node_type])

    for edge_type in filtered.edge_types:
        # Handle the case of disconnected graph sampling:
        for mapping in (row_dict, col_dict, edge_dict):
            if edge_type not in mapping:
                mapping[edge_type] = torch.empty(0, dtype=torch.long)

        edge_perm = perm_dict.get(edge_type, None) if perm_dict else None
        filter_edge_store_(
            data[edge_type],
            filtered[edge_type],
            row_dict[edge_type],
            col_dict[edge_type],
            edge_dict[edge_type],
            edge_perm,
        )

    return filtered


def filter_custom_store(
    feature_store: FeatureStore,
    graph_store: GraphStore,
    node: Tensor,
    row: Tensor,
    col: Tensor,
    edge: OptTensor,
    custom_cls: Optional[Data] = None,
) -> Data:
    r"""Constructs a :class:`~torch_geometric.data.Data` object from a feature
    store and graph store instance.

    The result is an instance of :obj:`custom_cls` if given, otherwise a
    :class:`~torch_geometric.data.Data` object. It holds:

    * ``edge_index = stack([row, col])``;
    * every tensor attribute of :obj:`feature_store`, restricted to the
      nodes in :obj:`node`.

    ``num_nodes`` is only set when the feature store exposes at least one
    tensor attribute.

    .. note::

        :obj:`graph_store` and :obj:`edge` are currently unused; edge
        features are not supported yet.
    """
    data = Data() if custom_cls is None else custom_cls()
    data.edge_index = torch.stack([row, col], dim=0)

    # Restrict every tensor attribute of the feature store to the sampled
    # nodes. TODO Support edge features.
    required_attrs = list(feature_store.get_all_tensor_attrs())
    for tensor_attr in required_attrs:
        tensor_attr.index = node
        data.num_nodes = tensor_attr.index.size(0)

    # NOTE: A single synchronous `multi_get_tensor` call hands the feature
    # store full control over how it fetches and returns all requested
    # features.
    fetched = feature_store.multi_get_tensor(required_attrs)
    for i, tensor_attr in enumerate(required_attrs):
        data[tensor_attr.attr_name] = fetched[i]

    return data


def filter_custom_hetero_store(
    feature_store: FeatureStore,
    graph_store: GraphStore,
    node_dict: Dict[str, Tensor],
    row_dict: Dict[str, Tensor],
    col_dict: Dict[str, Tensor],
    edge_dict: Dict[str, OptTensor],
    custom_cls: Optional[HeteroData] = None,
) -> HeteroData:
    r"""Constructs a :class:`~torch_geometric.data.HeteroData` object from a
    feature store and graph store instance.

    The result is an instance of :obj:`custom_cls` if given, otherwise a
    :class:`~torch_geometric.data.HeteroData` object. It holds:

    * an ``edge_index`` for every edge type of :obj:`graph_store` that is
      present in both :obj:`row_dict` and :obj:`col_dict`;
    * every tensor attribute of :obj:`feature_store` whose group name is
      in :obj:`node_dict`, restricted to those nodes (``num_nodes`` is set
      for these node types).

    .. note::

        :obj:`edge_dict` is currently unused; edge attributes are not
        supported yet.
    """
    data = HeteroData() if custom_cls is None else custom_cls()

    # Filter edge storage. TODO support edge attributes.
    for graph_attr in graph_store.get_all_edge_attrs():
        edge_type = graph_attr.edge_type
        if edge_type not in row_dict or edge_type not in col_dict:
            continue
        data[edge_type].edge_index = torch.stack(
            [row_dict[edge_type], col_dict[edge_type]], dim=0)

    # Filter node storage: only node types that were sampled are requested.
    required_attrs = []
    for tensor_attr in feature_store.get_all_tensor_attrs():
        node_type = tensor_attr.group_name
        if node_type not in node_dict:
            continue
        tensor_attr.index = node_dict[node_type]
        required_attrs.append(tensor_attr)
        data[node_type].num_nodes = tensor_attr.index.size(0)

    # NOTE: A single synchronous `multi_get_tensor` call hands the feature
    # store full control over how it fetches and returns all requested
    # features.
    fetched = feature_store.multi_get_tensor(required_attrs)
    for i, tensor_attr in enumerate(required_attrs):
        data[tensor_attr.group_name][tensor_attr.attr_name] = fetched[i]

    return data


# Input Utilities #############################################################


def get_input_nodes(
    data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]],
    input_nodes: Union[InputNodes, TensorAttr],
    input_id: Optional[Tensor] = None,
) -> Tuple[Optional[str], Tensor, Optional[Tensor]]:
    r"""Normalizes the seed nodes passed to a node-level loader.

    Args:
        data: A :class:`~torch_geometric.data.Data` or
            :class:`~torch_geometric.data.HeteroData` object, or a
            ``(feature_store, graph_store)`` tuple.
        input_nodes: The seed nodes. Accepted forms:

            * a boolean mask;
            * an index tensor;
            * any sequence accepted by :meth:`torch.tensor`;
            * :obj:`None` for all nodes (:class:`Data` only).

            For heterogeneous inputs, pass either a node type string (all
            nodes of that type) or a ``(node_type, nodes)`` pair, where
            ``nodes`` may be :obj:`None` to select all nodes of that type.
        input_id: Optional identifiers of the seed nodes. If given, they
            must have as many elements as the resolved seed nodes.

    Returns:
        A ``(node_type, nodes, input_id)`` tuple, where ``node_type`` is
        :obj:`None` for homogeneous inputs.

        * When :obj:`input_nodes` is a boolean mask and :obj:`input_id` is
          :obj:`None`, ``input_id`` is set to the resolved node indices.
        * When all nodes are selected, ``input_id`` is :obj:`None`.

    Raises:
        ValueError: If :obj:`input_nodes` has an unsupported format for a
            ``(feature_store, graph_store)`` backend.
    """
    def to_index(nodes, input_id) -> Tuple[Tensor, Optional[Tensor]]:
        # Boolean masks are converted to the indices of their `True` entries,
        # which also serve as `input_id` if none was given:
        if isinstance(nodes, Tensor) and nodes.dtype == torch.bool:
            nodes = nodes.nonzero(as_tuple=False).view(-1)
            if input_id is not None:
                assert input_id.numel() == nodes.numel()
            else:
                input_id = nodes
            return nodes, input_id

        if not isinstance(nodes, Tensor):
            nodes = torch.tensor(nodes, dtype=torch.long)

        if input_id is not None:
            assert input_id.numel() == nodes.numel()

        return nodes, input_id

    if isinstance(data, Data):
        if input_nodes is None:
            return None, torch.arange(data.num_nodes), None
        return None, *to_index(input_nodes, input_id)

    elif isinstance(data, HeteroData):
        assert input_nodes is not None

        if isinstance(input_nodes, str):
            return input_nodes, torch.arange(data[input_nodes].num_nodes), None

        assert isinstance(input_nodes, (list, tuple))
        assert len(input_nodes) == 2
        assert isinstance(input_nodes[0], str)

        node_type, input_nodes = input_nodes
        if input_nodes is None:
            return node_type, torch.arange(data[node_type].num_nodes), None
        return node_type, *to_index(input_nodes, input_id)

    else:  # Tuple[FeatureStore, GraphStore]
        feature_store, graph_store = data
        assert input_nodes is not None

        if isinstance(input_nodes, Tensor):
            return None, *to_index(input_nodes, input_id)

        if isinstance(input_nodes, str):
            num_nodes = remote_backend_utils.num_nodes(  #
                feature_store, graph_store, input_nodes)
            return input_nodes, torch.arange(num_nodes), None

        if isinstance(input_nodes, (list, tuple)):
            assert len(input_nodes) == 2
            assert isinstance(input_nodes[0], str)

            node_type, input_nodes = input_nodes
            if input_nodes is None:
                # Query the size of `node_type` itself; `input_nodes` is
                # `None` at this point and is not a valid node type:
                num_nodes = remote_backend_utils.num_nodes(  #
                    feature_store, graph_store, node_type)
                return node_type, torch.arange(num_nodes), None

            return node_type, *to_index(input_nodes, input_id)

        # Previously this fell through and implicitly returned `None`, which
        # only failed later when callers unpacked the result:
        raise ValueError(f"Encountered invalid 'input_nodes' type "
                         f"(got '{type(input_nodes)}')")


def get_edge_label_index(
    data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]],
    edge_label_index: InputEdges,
) -> Tuple[Optional[str], Tensor]:
    r"""Normalizes the supervision edges passed to a link-level loader.

    Args:
        data: A :class:`~torch_geometric.data.Data` or
            :class:`~torch_geometric.data.HeteroData` object, or a
            ``(feature_store, graph_store)`` tuple.
        edge_label_index: For :class:`Data`, an edge index or :obj:`None`
            (use ``data.edge_index``). Otherwise, either an edge type tuple
            (whose first entry is a string), or an
            ``(edge_type, edge_label_index)`` pair whose second entry may be
            :obj:`None` to use all edges of that type.

    Returns:
        An ``(edge_type, edge_label_index)`` tuple.

        * For :class:`Data`, ``edge_type`` is :obj:`None`.
        * For :class:`HeteroData`, ``edge_type`` is converted to its
          canonical form.
        * For a ``(feature_store, graph_store)`` tuple, missing edge indices
          are fetched via ``graph_store.coo`` and stacked as ``(row, col)``
          along a new first dimension.
    """
    if isinstance(data, Data):
        # Homogeneous graphs fall back to all edges when nothing is given:
        return None, (data.edge_index
                      if edge_label_index is None else edge_label_index)

    assert edge_label_index is not None
    assert isinstance(edge_label_index, (list, tuple))

    if isinstance(data, HeteroData):
        if isinstance(edge_label_index[0], str):
            # A bare edge type: supervise on all edges of that type.
            canonical = data._to_canonical(*edge_label_index)
            assert canonical in data.edge_types
            return canonical, data[canonical].edge_index

        assert len(edge_label_index) == 2
        raw_type, edge_label_index = edge_label_index
        canonical = data._to_canonical(*raw_type)
        if edge_label_index is None:
            edge_label_index = data[canonical].edge_index
        return canonical, edge_label_index

    # Tuple[FeatureStore, GraphStore]:
    _, graph_store = data

    def _coo_edge_index(edge_type):
        # `LinkNeighborLoader` needs the edge index in COO format:
        row_dict, col_dict, _ = graph_store.coo([edge_type])
        first_row = list(row_dict.values())[0]
        first_col = list(col_dict.values())[0]
        return torch.stack((first_row, first_col), dim=0)

    if isinstance(edge_label_index[0], str):
        return edge_label_index, _coo_edge_index(edge_label_index)

    assert len(edge_label_index) == 2
    edge_type, edge_label_index = edge_label_index
    if edge_label_index is None:
        edge_label_index = _coo_edge_index(edge_type)
    return edge_type, edge_label_index


def infer_filter_per_worker(data: Any) -> bool:
    r"""Infers the default ``filter_per_worker`` option of a data loader.

    Returns :obj:`False` if :obj:`data` is a
    :class:`~torch_geometric.data.Data` or
    :class:`~torch_geometric.data.HeteroData` object whose ``is_cuda`` is
    true, and :obj:`True` otherwise. The decision is logged at ``DEBUG``
    level.

    NOTE(review): presumably this avoids filtering CUDA-resident data inside
    worker processes — confirm against the loader documentation.
    """
    out = not (isinstance(data, (Data, HeteroData)) and data.is_cuda)
    # Lazy %-style arguments: the message is only formatted when DEBUG
    # logging is actually enabled.
    logging.debug(
        "Inferred 'filter_per_worker=%s' option for feature "
        "fetching routines of the data loader", out)
    return out