File: shadow.py

package info (click to toggle)
pytorch-geometric 2.7.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 14,172 kB
  • sloc: python: 144,911; sh: 247; cpp: 27; makefile: 18; javascript: 16
file content (106 lines) | stat: -rw-r--r-- 4,173 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import copy
from typing import Optional

import torch
from torch import Tensor

from torch_geometric.data import Batch, Data
from torch_geometric.typing import WITH_TORCH_SPARSE, SparseTensor


class ShaDowKHopSampler(torch.utils.data.DataLoader):
    r"""The ShaDow :math:`k`-hop sampler from the `"Decoupling the Depth and
    Scope of Graph Neural Networks" <https://arxiv.org/abs/2201.07858>`_ paper.
    Given a graph in a :obj:`data` object, the sampler will create shallow,
    localized subgraphs.
    A deep GNN on this local graph then smooths the informative local signals.

    .. note::

        For an example of using :class:`ShaDowKHopSampler`, see
        `examples/shadow.py <https://github.com/pyg-team/
        pytorch_geometric/blob/master/examples/shadow.py>`_.

    Args:
        data (torch_geometric.data.Data): The graph data object.
        depth (int): The depth/number of hops of the localized subgraph.
        num_neighbors (int): The number of neighbors to sample for each node in
            each hop.
        node_idx (LongTensor or BoolTensor, optional): The nodes that should be
            considered for creating mini-batches.
            If set to :obj:`None`, all nodes will be
            considered.
        replace (bool, optional): If set to :obj:`True`, will sample neighbors
            with replacement. (default: :obj:`False`)
        **kwargs (optional): Additional arguments of
            :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size` or
            :obj:`num_workers`.
    """
    def __init__(self, data: Data, depth: int, num_neighbors: int,
                 node_idx: Optional[Tensor] = None, replace: bool = False,
                 **kwargs):

        if not WITH_TORCH_SPARSE:
            raise ImportError(
                f"'{self.__class__.__name__}' requires 'torch-sparse'")

        self.data = copy.copy(data)
        self.depth = depth
        self.num_neighbors = num_neighbors
        self.replace = replace

        if data.edge_index is not None:
            self.is_sparse_tensor = False
            row, col = data.edge_index.cpu()
            self.adj_t = SparseTensor(
                row=row, col=col, value=torch.arange(col.size(0)),
                sparse_sizes=(data.num_nodes, data.num_nodes)).t()
        else:
            self.is_sparse_tensor = True
            self.adj_t = data.adj_t.cpu()

        if node_idx is None:
            node_idx = torch.arange(self.adj_t.sparse_size(0))
        elif node_idx.dtype == torch.bool:
            node_idx = node_idx.nonzero(as_tuple=False).view(-1)
        self.node_idx = node_idx

        super().__init__(node_idx.tolist(), collate_fn=self.__collate__,
                         **kwargs)

    def __collate__(self, n_id):
        n_id = torch.tensor(n_id)

        rowptr, col, value = self.adj_t.csr()
        out = torch.ops.torch_sparse.ego_k_hop_sample_adj(
            rowptr, col, n_id, self.depth, self.num_neighbors, self.replace)
        rowptr, col, n_id, e_id, ptr, root_n_id = out

        adj_t = SparseTensor(rowptr=rowptr, col=col,
                             value=value[e_id] if value is not None else None,
                             sparse_sizes=(n_id.numel(), n_id.numel()),
                             is_sorted=True, trust_data=True)

        batch = Batch(batch=torch.ops.torch_sparse.ptr2ind(ptr, n_id.numel()),
                      ptr=ptr)
        batch.root_n_id = root_n_id

        if self.is_sparse_tensor:
            batch.adj_t = adj_t
        else:
            row, col, e_id = adj_t.t().coo()
            batch.edge_index = torch.stack([row, col], dim=0)

        for k, v in self.data:
            if k in ['edge_index', 'adj_t', 'num_nodes', 'batch', 'ptr']:
                continue
            if k == 'y' and v.size(0) == self.data.num_nodes:
                batch[k] = v[n_id][root_n_id]
            elif isinstance(v, Tensor) and v.size(0) == self.data.num_nodes:
                batch[k] = v[n_id]
            elif isinstance(v, Tensor) and v.size(0) == self.data.num_edges:
                batch[k] = v[e_id]
            else:
                batch[k] = v

        return batch