File: half_hop.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (100 lines) | stat: -rw-r--r-- 4,102 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import torch

from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform


@functional_transform('half_hop')
class HalfHop(BaseTransform):
    r"""The graph upsampling augmentation from the
    `"Half-Hop: A Graph Upsampling Approach for Slowing Down Message Passing"
    <https://openreview.net/forum?id=lXczFIwQkv>`_ paper.
    The graph is augmented by adding artificial slow nodes between neighbors
    to slow down message propagation. (functional name: :obj:`half_hop`).

    Every half-hopped edge :obj:`(src, dst)` is replaced by the three edges
    :obj:`(src, slow)`, :obj:`(slow, dst)` and :obj:`(dst, slow)`, where
    :obj:`slow` is a newly added node. Self-loops are never half-hopped.
    The transform stores a boolean :obj:`slow_node_mask` on :obj:`data` that
    is :obj:`True` for the added slow nodes.

    .. note::
        :class:`HalfHop` augmentation is not supported if :obj:`data` has
        :attr:`edge_weight` or :attr:`edge_attr`.

    Args:
        alpha (float, optional): The interpolation factor
            used to compute slow node features
            :math:`x = \alpha*x_src + (1-\alpha)*x_dst` (default: :obj:`0.5`)
        p (float, optional): The probability of selecting a node for
            half-hopping. Sampling is done per node, and all non-self-loop
            edges pointing to a selected node are half-hopped.
            (default: :obj:`1.0`)

    Raises:
        ValueError: If :obj:`alpha` or :obj:`p` lies outside of
            :math:`[0, 1]`.

    .. code-block:: python

        import torch_geometric.transforms as T

        transform = T.HalfHop(alpha=0.5)
        data = transform(data)  # Apply transformation.
        out = model(data.x, data.edge_index)  # Feed-forward.
        out = out[~data.slow_node_mask]  # Get rid of slow nodes.
    """
    def __init__(self, alpha: float = 0.5, p: float = 1.0) -> None:
        if alpha < 0. or alpha > 1.:
            raise ValueError("Interpolation factor has to be between 0 and 1 "
                             f"(got '{alpha}')")
        if p < 0. or p > 1.:
            raise ValueError("Ratio of half-hopped edges has to be between "
                             f"0 and 1 (got '{p}')")

        self.p = p
        self.alpha = alpha

    def forward(self, data: Data) -> Data:
        r"""Half-hops the edges of :obj:`data` and returns it.

        :obj:`data` is modified in place: :obj:`data.x` and
        :obj:`data.edge_index` are overwritten and :obj:`data.slow_node_mask`
        is added.

        Args:
            data (Data): The graph to augment. Its :obj:`x`,
                :obj:`edge_index` and :obj:`num_nodes` must not be
                :obj:`None`.

        Returns:
            Data: The same :obj:`data` object, holding the augmented graph.

        Raises:
            ValueError: If :obj:`data` contains :obj:`edge_weight` or
                :obj:`edge_attr`.
        """
        if data.edge_weight is not None or data.edge_attr is not None:
            raise ValueError("'HalfHop' augmentation is not supported if "
                             "'data' contains 'edge_weight' or 'edge_attr'")

        assert data.x is not None
        assert data.edge_index is not None
        x, edge_index = data.x, data.edge_index
        num_nodes = data.num_nodes
        assert num_nodes is not None

        # isolate self loops which are not half-hopped
        self_loop_mask = edge_index[0] == edge_index[1]
        edge_index_self_loop = edge_index[:, self_loop_mask]
        edge_index = edge_index[:, ~self_loop_mask]

        # randomly sample nodes (each with probability `p`) and half-hop
        # every remaining edge whose `edge_index[1]` endpoint was sampled
        node_mask = torch.rand(num_nodes, device=x.device) < self.p
        edge_mask = node_mask[edge_index[1]]
        edge_index_to_halfhop = edge_index[:, edge_mask]
        edge_index_to_keep = edge_index[:, ~edge_mask]

        # add new slow nodes (one per half-hopped edge) of which features are
        # initialized by linear interpolation
        # NOTE: slow node ids start at `num_nodes` while their features are
        # appended after the rows of `x`; both only line up if
        # `num_nodes == x.size(0)`.
        num_halfhop_edges = edge_index_to_halfhop.size(1)
        slow_node_ids = torch.arange(num_halfhop_edges,
                                     device=x.device) + num_nodes
        x_src = x[edge_index_to_halfhop[0]]
        x_dst = x[edge_index_to_halfhop[1]]
        x_slow_node = self.alpha * x_src + (1 - self.alpha) * x_dst
        new_x = torch.cat([x, x_slow_node], dim=0)

        # add new edges between slow nodes and the original nodes:
        # src -> slow, slow -> dst and dst -> slow
        edge_index_slow = [
            torch.stack([edge_index_to_halfhop[0], slow_node_ids]),
            torch.stack([slow_node_ids, edge_index_to_halfhop[1]]),
            torch.stack([edge_index_to_halfhop[1], slow_node_ids])
        ]
        new_edge_index = torch.cat(
            [edge_index_to_keep, edge_index_self_loop, *edge_index_slow],
            dim=1)

        # prepare a mask that distinguishes between original nodes & slow
        # nodes: the rows appended after the original `x` are the slow nodes
        slow_node_mask = torch.arange(new_x.size(0),
                                      device=x.device) >= x.size(0)

        data.x, data.edge_index = new_x, new_edge_index
        data.slow_node_mask = slow_node_mask

        return data

    def __repr__(self) -> str:
        """Returns the class name together with :obj:`alpha` and :obj:`p`."""
        return f'{self.__class__.__name__}(alpha={self.alpha}, p={self.p})'