Creating Message Passing Networks
=================================

Generalizing the convolution operator to irregular domains is typically expressed as a *neighborhood aggregation* or *message passing* scheme.
With :math:`\mathbf{x}^{(k-1)}_i \in \mathbb{R}^F` denoting node features of node :math:`i` in layer :math:`(k-1)` and :math:`\mathbf{e}_{j,i} \in \mathbb{R}^D` denoting (optional) edge features from node :math:`j` to node :math:`i`, message passing graph neural networks can be described as

.. math::
  \mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \bigoplus_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}\right) \right),

where :math:`\bigoplus` denotes a differentiable, permutation-invariant function, *e.g.*, sum, mean or max, and :math:`\gamma` and :math:`\phi` denote differentiable functions such as MLPs (multi-layer perceptrons).
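
Before turning to the :pyg:`PyG` abstractions, the following hand-rolled sketch illustrates the scheme above for a sum aggregation, where both :math:`\phi` and :math:`\gamma` are plain concatenations (the graph, shapes and values are assumptions for illustration only):

.. code-block:: python

    import torch

    num_nodes, F, D = 4, 8, 2
    x = torch.randn(num_nodes, F)                    # node features x^(k-1)
    edge_index = torch.tensor([[0, 1, 1, 2],         # source nodes j
                               [1, 0, 2, 1]])        # target nodes i
    edge_attr = torch.randn(edge_index.size(1), D)   # edge features e_{j,i}

    j, i = edge_index
    msg = torch.cat([x[i], x[j], edge_attr], dim=-1)  # phi(x_i, x_j, e_{j,i}): [E, 2F + D]
    aggr = torch.zeros(num_nodes, msg.size(-1))
    aggr.index_add_(0, i, msg)                        # sum over the neighborhood N(i)
    out = torch.cat([x, aggr], dim=-1)                # gamma(x_i, aggregated messages)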

.. contents::
    :local:

The "MessagePassing" Base Class
-------------------------------

:pyg:`PyG` provides the :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` base class, which helps in creating such kinds of message passing graph neural networks by automatically taking care of message propagation.
The user only has to define the functions :math:`\phi`, *i.e.* :meth:`~torch_geometric.nn.conv.message_passing.MessagePassing.message`, and :math:`\gamma`, *i.e.* :meth:`~torch_geometric.nn.conv.message_passing.MessagePassing.update`, as well as the aggregation scheme to use, *i.e.* :obj:`aggr="add"`, :obj:`aggr="mean"` or :obj:`aggr="max"`.

This is done with the help of the following methods (a minimal skeleton combining them is sketched after the list):

* :obj:`MessagePassing(aggr="add", flow="source_to_target", node_dim=-2)`: Defines the aggregation scheme to use (:obj:`"add"`, :obj:`"mean"` or :obj:`"max"`) and the flow direction of message passing (either :obj:`"source_to_target"` or :obj:`"target_to_source"`).
  Furthermore, the :obj:`node_dim` attribute indicates along which axis to propagate.
* :obj:`MessagePassing.propagate(edge_index, size=None, **kwargs)`:
  The initial call to start propagating messages.
  Takes in the edge indices and all additional data which is needed to construct messages and to update node embeddings.
  Note that :func:`~torch_geometric.nn.conv.message_passing.MessagePassing.propagate` is not limited to exchanging messages in square adjacency matrices of shape :obj:`[N, N]`, but can also exchange messages in general sparse assignment matrices, *e.g.*, bipartite graphs, of shape :obj:`[N, M]` by passing :obj:`size=(N, M)` as an additional argument.
  If set to :obj:`None`, the assignment matrix is assumed to be square.
  For bipartite graphs with two independent sets of nodes and indices, where each set holds its own information, this split can be marked by passing the information as a tuple, *e.g.* :obj:`x=(x_N, x_M)`.
* :obj:`MessagePassing.message(...)`: Constructs messages to node :math:`i` in analogy to :math:`\phi` for each edge :math:`(j,i) \in \mathcal{E}` if :obj:`flow="source_to_target"` and :math:`(i,j) \in \mathcal{E}` if :obj:`flow="target_to_source"`.
  Can take any argument which was initially passed to :meth:`propagate`.
  In addition, tensors passed to :meth:`propagate` can be mapped to the respective nodes :math:`i` and :math:`j` by appending :obj:`_i` or :obj:`_j` to the variable name, *e.g.* :obj:`x_i` and :obj:`x_j`.
  Note that we generally refer to :math:`i` as the central node that aggregates information, and to :math:`j` as the neighboring nodes, since this is the most common notation.
* :obj:`MessagePassing.update(aggr_out, ...)`: Updates node embeddings in analogy to :math:`\gamma` for each node :math:`i \in \mathcal{V}`.
  Takes in the output of aggregation as its first argument and any argument which was initially passed to :func:`~torch_geometric.nn.conv.message_passing.MessagePassing.propagate`.
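
A minimal skeleton combining these methods could look as follows (the concrete message and update rules are illustrative assumptions, not an existing :pyg:`PyG` layer):

.. code-block:: python

    from torch_geometric.nn import MessagePassing

    class MyConv(MessagePassing):
        def __init__(self):
            super().__init__(aggr='mean', flow='source_to_target')

        def forward(self, x, edge_index):
            # Start propagating messages; all keyword arguments are made
            # available to message() and update().
            return self.propagate(edge_index, x=x)

        def message(self, x_j, x_i):
            # x_j and x_i are automatically lifted to shape [E, F].
            return x_j - x_i

        def update(self, aggr_out, x):
            # aggr_out holds the aggregated messages of shape [N, F].
            return aggr_out + x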

Let us verify this by re-implementing two popular GNN variants, the `GCN layer from Kipf and Welling <https://arxiv.org/abs/1609.02907>`_ and the `EdgeConv layer from Wang et al. <https://arxiv.org/abs/1801.07829>`_.

Implementing the GCN Layer
--------------------------

The `GCN layer <https://arxiv.org/abs/1609.02907>`_ is mathematically defined as

.. math::

    \mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{W}^{\top} \cdot \mathbf{x}_j^{(k-1)} \right) + \mathbf{b},

where neighboring node features are first transformed by a weight matrix :math:`\mathbf{W}`, normalized by their degree, and finally summed up.
Lastly, we apply the bias vector :math:`\mathbf{b}` to the aggregated output.
This formula can be divided into the following steps:

1. Add self-loops to the adjacency matrix.
2. Linearly transform node feature matrix.
3. Compute normalization coefficients.
4. Normalize node features in :math:`\phi`.
5. Sum up neighboring node features (:obj:`"add"` aggregation).
6. Apply a final bias vector.

Steps 1-3 are typically computed before message passing takes place.
Steps 4-5 can be easily processed using the :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` base class.
The full layer implementation is shown below:

.. code-block:: python

    import torch
    from torch.nn import Linear, Parameter
    from torch_geometric.nn import MessagePassing
    from torch_geometric.utils import add_self_loops, degree

    class GCNConv(MessagePassing):
        def __init__(self, in_channels, out_channels):
            super().__init__(aggr='add')  # "Add" aggregation (Step 5).
            self.lin = Linear(in_channels, out_channels, bias=False)
            self.bias = Parameter(torch.empty(out_channels))

            self.reset_parameters()

        def reset_parameters(self):
            self.lin.reset_parameters()
            self.bias.data.zero_()

        def forward(self, x, edge_index):
            # x has shape [N, in_channels]
            # edge_index has shape [2, E]

            # Step 1: Add self-loops to the adjacency matrix.
            edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

            # Step 2: Linearly transform node feature matrix.
            x = self.lin(x)

            # Step 3: Compute normalization.
            row, col = edge_index
            deg = degree(col, x.size(0), dtype=x.dtype)
            deg_inv_sqrt = deg.pow(-0.5)
            deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
            norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

            # Step 4-5: Start propagating messages.
            out = self.propagate(edge_index, x=x, norm=norm)

            # Step 6: Apply a final bias vector.
            out = out + self.bias

            return out

        def message(self, x_j, norm):
            # x_j has shape [E, out_channels]

            # Step 4: Normalize node features.
            return norm.view(-1, 1) * x_j

:class:`~torch_geometric.nn.conv.GCNConv` inherits from :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` with :obj:`"add"` aggregation.
All the logic of the layer takes place in its :meth:`forward` method.
Here, we first add self-loops to our edge indices using the :meth:`torch_geometric.utils.add_self_loops` function (step 1), and linearly transform the node features by calling the :class:`torch.nn.Linear` instance (step 2).

The normalization coefficients are derived from the node degrees :math:`\deg(i)` for each node :math:`i`, which are transformed to :math:`1/(\sqrt{\deg(i)} \cdot \sqrt{\deg(j)})` for each edge :math:`(j,i) \in \mathcal{E}`.
The result is saved in the tensor :obj:`norm` of shape :obj:`[num_edges, ]` (step 3).
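For example, assuming node degrees (after adding self-loops) of :math:`\deg(i) = 2` and :math:`\deg(j) = 3`, the edge :math:`(j,i)` receives the coefficient :math:`1/(\sqrt{2} \cdot \sqrt{3}) = 1/\sqrt{6} \approx 0.41`.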

We then call :meth:`~torch_geometric.nn.conv.message_passing.MessagePassing.propagate`, which internally calls :meth:`~torch_geometric.nn.conv.message_passing.MessagePassing.message`, :meth:`~torch_geometric.nn.conv.message_passing.MessagePassing.aggregate` and :meth:`~torch_geometric.nn.conv.message_passing.MessagePassing.update`.
We pass the node embeddings :obj:`x` and the normalization coefficients :obj:`norm` as additional arguments for message propagation.

In the :meth:`~torch_geometric.nn.conv.message_passing.MessagePassing.message` function, we need to normalize the neighboring node features :obj:`x_j` by :obj:`norm`.
Here, :obj:`x_j` denotes a *lifted* tensor, which contains the source node features of each edge, *i.e.*, the neighbors of each node.
Node features can be automatically lifted by appending :obj:`_i` or :obj:`_j` to the variable name.
In fact, any tensor can be converted this way, as long as it holds source or destination node features.
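
As an illustration, the lifting can be emulated on a toy graph as follows (this indexing is roughly what :meth:`propagate` does internally, not something you need to write yourself; the values are assumptions for the example only):

.. code-block:: python

    import torch

    x = torch.tensor([[1.0], [2.0], [3.0]])  # [N, F] node features
    edge_index = torch.tensor([[0, 1, 2],    # source nodes j
                               [1, 2, 0]])   # target nodes i

    row, col = edge_index
    x_j = x[row]  # [[1.], [2.], [3.]] -- source node features of each edge
    x_i = x[col]  # [[2.], [3.], [1.]] -- target node features of each edge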

That is all that it takes to create a simple message passing layer.
You can use this layer as a building block for deep architectures.
Initializing and calling it is straightforward:

.. code-block:: python

    conv = GCNConv(16, 32)
    x = conv(x, edge_index)

Implementing the Edge Convolution
---------------------------------

The `edge convolutional layer <https://arxiv.org/abs/1801.07829>`_ processes graphs or point clouds and is mathematically defined as

.. math::

    \mathbf{x}_i^{(k)} = \max_{j \in \mathcal{N}(i)} h_{\mathbf{\Theta}} \left( \mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)} - \mathbf{x}_i^{(k-1)} \right),

where :math:`h_{\mathbf{\Theta}}` denotes an MLP.
In analogy to the GCN layer, we can use the :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` class to implement this layer, this time using the :obj:`"max"` aggregation:

.. code-block:: python

    import torch
    from torch.nn import Sequential as Seq, Linear, ReLU
    from torch_geometric.nn import MessagePassing

    class EdgeConv(MessagePassing):
        def __init__(self, in_channels, out_channels):
            super().__init__(aggr='max')  # "Max" aggregation.
            self.mlp = Seq(Linear(2 * in_channels, out_channels),
                           ReLU(),
                           Linear(out_channels, out_channels))

        def forward(self, x, edge_index):
            # x has shape [N, in_channels]
            # edge_index has shape [2, E]

            return self.propagate(edge_index, x=x)

        def message(self, x_i, x_j):
            # x_i has shape [E, in_channels]
            # x_j has shape [E, in_channels]

            tmp = torch.cat([x_i, x_j - x_i], dim=1)  # tmp has shape [E, 2 * in_channels]
            return self.mlp(tmp)

Inside the :meth:`~torch_geometric.nn.conv.message_passing.MessagePassing.message` function, we use :obj:`self.mlp` to transform both the target node features :obj:`x_i` and the relative source node features :obj:`x_j - x_i` for each edge :math:`(j,i) \in \mathcal{E}`.

The edge convolution is actually a dynamic convolution, which recomputes the graph for each layer using nearest neighbors in the feature space.
Luckily, :pyg:`PyG` comes with a GPU-accelerated, batch-wise k-NN graph generation method named :meth:`torch_geometric.nn.pool.knn_graph`:

.. code-block:: python

    from torch_geometric.nn import knn_graph

    class DynamicEdgeConv(EdgeConv):
        def __init__(self, in_channels, out_channels, k=6):
            super().__init__(in_channels, out_channels)
            self.k = k

        def forward(self, x, batch=None):
            edge_index = knn_graph(x, self.k, batch, loop=False, flow=self.flow)
            return super().forward(x, edge_index)

Here, :meth:`~torch_geometric.nn.pool.knn_graph` computes a nearest neighbor graph, which is further used to call the :meth:`forward` method of :class:`~torch_geometric.nn.conv.EdgeConv`.

This leaves us with a clean interface for initializing and calling this layer:

.. code-block:: python

    conv = DynamicEdgeConv(3, 128, k=6)
    x = conv(x, batch)
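
For instance, a toy point cloud batch could be constructed as follows (the concrete values and the choice of :obj:`k` are assumptions for this example only):

.. code-block:: python

    import torch

    x = torch.rand(8, 3)                            # 8 points with 3 coordinates each
    batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])  # first 4 points belong to graph 0

    conv = DynamicEdgeConv(3, 128, k=3)  # k=3 since each cloud only holds 4 points
    out = conv(x, batch)                 # out has shape [8, 128]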

Exercises
---------

Imagine we are given the following :class:`~torch_geometric.data.Data` object:

.. code-block:: python

    import torch
    from torch_geometric.data import Data

    edge_index = torch.tensor([[0, 1],
                               [1, 0],
                               [1, 2],
                               [2, 1]], dtype=torch.long)
    x = torch.tensor([[-1], [0], [1]], dtype=torch.float)

    data = Data(x=x, edge_index=edge_index.t().contiguous())

Try to answer the following questions related to :class:`~torch_geometric.nn.conv.GCNConv`:

1. What information do :obj:`row` and :obj:`col` hold?

2. What does :meth:`~torch_geometric.utils.degree` do?

3. Why do we use :obj:`degree(col, ...)` rather than :obj:`degree(row, ...)`?

4. What do :obj:`deg_inv_sqrt[col]` and :obj:`deg_inv_sqrt[row]` do?

5. What information does :obj:`x_j` hold in the :meth:`~torch_geometric.nn.conv.MessagePassing.message` function? If :obj:`self.lin` denotes the identity function, what is the exact content of :obj:`x_j`?

6. Add an :meth:`~torch_geometric.nn.conv.MessagePassing.update` function to :class:`~torch_geometric.nn.conv.GCNConv` that adds transformed central node features to the aggregated output.

Try to answer the following questions related to :class:`~torch_geometric.nn.conv.EdgeConv`:

1. What are :obj:`x_i` and :obj:`x_j - x_i`?

2. What does :obj:`torch.cat([x_i, x_j - x_i], dim=1)` do? Why :obj:`dim = 1`?