Graph Transformer
=================

The `Transformer <https://arxiv.org/abs/1706.03762>`_ is an effective architecture in `natural language processing <https://arxiv.org/abs/1810.04805>`_ and `computer vision <https://arxiv.org/abs/2010.11929>`_.
Recently, several works (*e.g.*, `Grover <https://arxiv.org/abs/2007.02835>`_ and `GraphGPS <https://arxiv.org/abs/2205.12454>`_) have applied Transformers to graphs.
In this tutorial, we show how to build a graph transformer model with :pyg:`PyG`. See `our webinar <https://youtu.be/wAYryx3GjLw?si=2vB7imfenP5tUvqd>`_ for an in-depth discussion of this topic.

.. note::
    Click `here <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/graph_gps.py>`_ to download the full example code

Transformers on Graphs
------------------------------

Compared to Graph Transformers, MPNNs have several drawbacks: (1) Limited expressivity: standard MPNNs are at most as powerful as the 1-WL isomorphism test; (2) Over-smoothing: node features tend to
converge to the same value as the number of GNN layers increases; (3) Over-squashing: information is lost when messages from many neighbors are aggregated into a single fixed-size vector;
(4) Difficulty capturing long-range dependencies.

Feeding the whole graph into a Transformer brings its own pros and cons:

**Pros**

* Computation graph structure is decoupled from the input graph structure.
* Long-range connections can be handled as all nodes are connected to each other.

**Cons**

* Loss of the inductive bias that enables GNNs to work so well on graphs with pronounced locality, particularly on graphs where edges represent relatedness or closeness.
* Language input is sequential, whereas graphs are invariant to node ordering (permutation invariance).
* Quadratic computational complexity :math:`O(N^2)` in the number of nodes, whereas message passing GNNs are linear in the number of edges :math:`O(E)`. Graphs are often sparse, with :math:`E \approx N`.

Attention
+++++++++

.. math::
    Q = XW_Q, \quad K = XW_K, \quad V = XW_V
.. math::
    \mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V

In the Transformer, attention is typically multi-head: several attention heads with independent projection weights are computed in parallel, and their outputs are concatenated.
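
The following is a minimal, self-contained sketch of single-head scaled dot-product attention in PyTorch, matching the equations above; the helper name and the tensor sizes are illustrative only:

.. code-block:: python

    import math

    import torch

    def scaled_dot_product_attention(X, W_Q, W_K, W_V):
        # Project the inputs into queries, keys, and values:
        Q, K, V = X @ W_Q, X @ W_K, X @ W_V
        # Attention weights over all pairs of tokens/nodes:
        scores = Q @ K.transpose(-2, -1) / math.sqrt(K.size(-1))
        attn = torch.softmax(scores, dim=-1)
        return attn @ V

    X = torch.randn(10, 16)  # 10 nodes with 16 features each.
    W_Q, W_K, W_V = (torch.randn(16, 16) for _ in range(3))
    out = scaled_dot_product_attention(X, W_Q, W_K, W_V)  # [10, 16]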

Positional and Structural Encodings
+++++++++++++++++++++++++++++++++++

PE/SE can be organized into three categories based on their locality: (1) local, (2) global, and (3) relative.
Positional encodings (PE) provide a notion of the position of a given node within the graph: when two nodes are close to each other within a graph or subgraph, their PE should also be close.
Structural encodings (SE) provide an embedding of the structure of graphs or subgraphs to help increase the expressivity and the generalizability of GNNs:
when two nodes share similar subgraphs, or when two graphs are similar, their SE should also be close.

.. list-table::
   :widths: 10 20 20
   :header-rows: 1

   * - Encoding type
     - Positional encodings (PE)
     - Structural encodings (SE)
   * - Local (node)
     - (1) Distance to cluster center; (2) Sum of non-diagonal elements in an :math:`m`-step random walk.
     - (1) Node degree; (2) Random walk diagonals; (3) Enumeration of substructures (triangles, rings).
   * - Global (node)
     - (1) Eigenvectors of the adjacency/Laplacian or distance matrix; (2) Distance to the graph's centroid; (3) Unique ID for each node.
     - (1) Eigenvalues of the adjacency/Laplacian matrix; (2) Graph diameter, girth, degree, etc.
   * - Relative (edge)
     - (1) Pair-wise distances from heat kernels, random walks, graph geodesics, etc.; (2) Gradient of eigenvectors.
     - (1) Gradient of any local SE; (2) Gradient of substructure enumeration.
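
Several of these encodings are available in :pyg:`PyG` as data transforms. As a small sketch, the snippet below attaches a random-walk SE and a Laplacian-eigenvector PE to a toy graph; the attribute names are free to choose:

.. code-block:: python

    import torch

    import torch_geometric.transforms as T
    from torch_geometric.data import Data

    # A toy graph with 4 nodes and undirected edges stored in both directions:
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                               [1, 0, 2, 1, 3, 2]])
    data = Data(x=torch.randn(4, 8), edge_index=edge_index)

    transform = T.Compose([
        T.AddRandomWalkPE(walk_length=8, attr_name='random_walk_pe'),
        T.AddLaplacianEigenvectorPE(k=2, attr_name='laplacian_pe'),
    ])
    data = transform(data)

    print(data.random_walk_pe.shape)  # [4, 8]
    print(data.laplacian_pe.shape)    # [4, 2]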

GPS Layer and GraphGPS Model
--------------------------------------

First, we introduce the GPS layer, which combines a local MPNN with a global Transformer, followed by a two-layer MLP and skip connections.
The local MPNN provides the locality bias that is difficult or expensive to achieve with a Transformer.
In addition, edge features can be updated and encoded into the node features (`GatedGCN <https://arxiv.org/abs/1711.07553>`_, `GINE <https://arxiv.org/abs/1905.12265>`_).
The Transformer can utilize positional and structural encodings. Since it does not need to consider edge features, we can use existing linear Transformer architectures such as `Performer <https://arxiv.org/abs/2009.14794>`_ and `BigBird <https://arxiv.org/abs/2007.14062>`_ to reduce the time complexity from :math:`O(N^2)` to :math:`O(N + E)`.

.. warning::
    `BigBird <https://arxiv.org/abs/2007.14062>`_ is currently not supported and will be added in the future.

.. figure:: ../_figures/graphgps_layer.png
    :align: center
    :width: 100%

The update function of each layer is described by the equations below.

Local MPNN
++++++++++

.. math::
    \hat{X}_M^{l + 1}, E^{l + 1} = \mathrm{MPNN}_e^l(X^l, E^l, A)
.. math::
    X_M^{l + 1} = \mathrm{BatchNorm}(\mathrm{Dropout}(\hat{X}_M^{l + 1}) + X^l)

.. code-block:: python

    # Local message passing with the wrapped MPNN:
    h = self.conv(x, edge_index, **kwargs)
    h = F.dropout(h, p=self.dropout, training=self.training)
    h = h + x  # Residual connection.
    if self.norm1 is not None:
        if self.norm_with_batch:
            h = self.norm1(h, batch=batch)
        else:
            h = self.norm1(h)
    hs.append(h)

Global Attention
++++++++++++++++

.. math::
    \hat{X}_T^{l + 1} = \mathrm{GlobalAttn}^l(X^l)
.. math::
    X_T^{l + 1} = \mathrm{BatchNorm}(\mathrm{Dropout}(\hat{X}_T^{l + 1}) + X^l)

.. code-block:: python

    # Convert the sparse node representation into a dense one for attention:
    h, mask = to_dense_batch(x, batch)

    if isinstance(self.attn, torch.nn.MultiheadAttention):
        h, _ = self.attn(h, h, h, key_padding_mask=~mask,
                         need_weights=False)
    elif isinstance(self.attn, PerformerAttention):
        h = self.attn(h, mask=mask)

    h = h[mask]  # Convert back to the sparse representation.
    h = F.dropout(h, p=self.dropout, training=self.training)
    h = h + x  # Residual connection.
    if self.norm2 is not None:
        if self.norm_with_batch:
            h = self.norm2(h, batch=batch)
        else:
            h = self.norm2(h)
    hs.append(h)

Combine local and global outputs
++++++++++++++++++++++++++++++++

.. math::
    X^{l + 1} = \mathrm{MLP}^l(X_M^{l + 1} + X_T^{l + 1})

.. code-block:: python

    # Combine the outputs of the local and global branches:
    out = sum(hs)

    out = out + self.mlp(out)  # Residual connection around the MLP.
    if self.norm3 is not None:
        if self.norm_with_batch:
            out = self.norm3(out, batch=batch)
        else:
            out = self.norm3(out)
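
All three steps are implemented by a single :class:`~torch_geometric.nn.GPSConv` layer. As a minimal usage sketch (random features and edges, arbitrary channel sizes), one layer can be instantiated and applied as follows:

.. code-block:: python

    import torch
    from torch.nn import Linear, ReLU, Sequential

    from torch_geometric.nn import GINEConv, GPSConv

    channels = 64
    nn = Sequential(Linear(channels, channels), ReLU(),
                    Linear(channels, channels))
    conv = GPSConv(channels, GINEConv(nn), heads=4, attn_type='multihead')

    x = torch.randn(10, channels)               # 10 nodes.
    edge_index = torch.randint(0, 10, (2, 40))  # 40 random edges.
    edge_attr = torch.randn(40, channels)       # Edge features for GINEConv.
    batch = torch.zeros(10, dtype=torch.long)   # A single graph in the batch.

    out = conv(x, edge_index, batch, edge_attr=edge_attr)  # [10, 64]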

Next, we introduce the GraphGPS architecture. The difference between `GraphGPS <https://arxiv.org/abs/2205.12454>`_ and `GraphTrans <https://arxiv.org/abs/2201.08821>`_ lies in how the MPNN and Transformer layers are organized.
In GraphTrans, a few MPNN layers are applied before the Transformer, which may be limited by over-smoothing, over-squashing and low expressivity with respect to the WL test.
These early layers could irreparably fail to preserve some information. GraphGPS instead stacks hybrid MPNN + Transformer layers, which resolves
the local expressivity bottlenecks by allowing information to spread across the graph via full connectivity.

Train GraphGPS on graph-structured data
--------------------------------------------------

In this part, we'll show how to train a :class:`~torch_geometric.nn.GPSConv` GNN model on the :class:`~torch_geometric.datasets.ZINC` dataset.

Load dataset
++++++++++++

.. code-block:: python

    import torch_geometric.transforms as T
    from torch_geometric.datasets import ZINC
    from torch_geometric.loader import DataLoader

    # 'path' points to the root directory of the dataset.
    transform = T.AddRandomWalkPE(walk_length=20, attr_name='pe')
    train_dataset = ZINC(path, subset=True, split='train', pre_transform=transform)
    val_dataset = ZINC(path, subset=True, split='val', pre_transform=transform)
    test_dataset = ZINC(path, subset=True, split='test', pre_transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=64)
    test_loader = DataLoader(test_dataset, batch_size=64)
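
To confirm that the pre-transform attached the ``pe`` attribute, it can help to inspect a single mini-batch; the exact sizes below depend on the sampled graphs:

.. code-block:: python

    batch = next(iter(train_loader))
    print(batch)
    # e.g., DataBatch(x=[735, 1], edge_index=[2, 1590], edge_attr=[1590],
    #                 y=[32], pe=[735, 20], batch=[735], ptr=[33])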


Define model
++++++++++++

.. code-block:: python

    from typing import Any, Dict, Optional

    import torch
    from torch.nn import (BatchNorm1d, Embedding, Linear, ModuleList, ReLU,
                          Sequential)

    from torch_geometric.nn import GINEConv, GPSConv, global_add_pool
    from torch_geometric.nn.attention import PerformerAttention


    class RedrawProjection:
        def __init__(self, model: torch.nn.Module,
                     redraw_interval: Optional[int] = None):
            self.model = model
            self.redraw_interval = redraw_interval
            self.num_last_redraw = 0

        def redraw_projections(self):
            if not self.model.training or self.redraw_interval is None:
                return
            if self.num_last_redraw >= self.redraw_interval:
                fast_attentions = [
                    module for module in self.model.modules()
                    if isinstance(module, PerformerAttention)
                ]
                for fast_attention in fast_attentions:
                    fast_attention.redraw_projection_matrix()
                self.num_last_redraw = 0
                return
            self.num_last_redraw += 1

    class GPS(torch.nn.Module):
        def __init__(self, channels: int, pe_dim: int, num_layers: int,
                     attn_type: str, attn_kwargs: Dict[str, Any]):
            super().__init__()

            self.node_emb = Embedding(28, channels - pe_dim)
            self.pe_lin = Linear(20, pe_dim)
            self.pe_norm = BatchNorm1d(20)
            self.edge_emb = Embedding(4, channels)

            self.convs = ModuleList()
            for _ in range(num_layers):
                nn = Sequential(
                    Linear(channels, channels),
                    ReLU(),
                    Linear(channels, channels),
                )
                conv = GPSConv(channels, GINEConv(nn), heads=4,
                               attn_type=attn_type, attn_kwargs=attn_kwargs)
                self.convs.append(conv)

            self.mlp = Sequential(
                Linear(channels, channels // 2),
                ReLU(),
                Linear(channels // 2, channels // 4),
                ReLU(),
                Linear(channels // 4, 1),
            )
            self.redraw_projection = RedrawProjection(
                self.convs,
                redraw_interval=1000 if attn_type == 'performer' else None)

        def forward(self, x, pe, edge_index, edge_attr, batch):
            x_pe = self.pe_norm(pe)
            x = torch.cat((self.node_emb(x.squeeze(-1)), self.pe_lin(x_pe)), 1)
            edge_attr = self.edge_emb(edge_attr)

            for conv in self.convs:
                x = conv(x, edge_index, batch, edge_attr=edge_attr)
            x = global_add_pool(x, batch)
            return self.mlp(x)



Train and evaluate
+++++++++++++++++++

.. code-block:: python

    from torch.optim.lr_scheduler import ReduceLROnPlateau

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    attn_kwargs = {'dropout': 0.5}
    # 'args.attn_type' is parsed from the command line in the full example
    # (either 'multihead' or 'performer').
    model = GPS(channels=64, pe_dim=8, num_layers=10, attn_type=args.attn_type,
                attn_kwargs=attn_kwargs).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=20,
                                  min_lr=0.00001)


    def train():
        model.train()

        total_loss = 0
        for data in train_loader:
            data = data.to(device)
            optimizer.zero_grad()
            model.redraw_projection.redraw_projections()
            out = model(data.x, data.pe, data.edge_index, data.edge_attr,
                        data.batch)
            loss = (out.squeeze() - data.y).abs().mean()
            loss.backward()
            total_loss += loss.item() * data.num_graphs
            optimizer.step()
        return total_loss / len(train_loader.dataset)


    @torch.no_grad()
    def test(loader):
        model.eval()

        total_error = 0
        for data in loader:
            data = data.to(device)
            out = model(data.x, data.pe, data.edge_index, data.edge_attr,
                        data.batch)
            total_error += (out.squeeze() - data.y).abs().sum().item()
        return total_error / len(loader.dataset)


    for epoch in range(1, 101):
        loss = train()
        val_mae = test(val_loader)
        test_mae = test(test_loader)
        scheduler.step(val_mae)
        print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_mae:.4f}, '
              f'Test: {test_mae:.4f}')

.. code-block:: text

    Epoch: 01, Loss: 0.7216, Val: 0.5316, Test: 0.5454
    Epoch: 02, Loss: 0.5519, Val: 0.5895, Test: 0.6288
    Epoch: 03, Loss: 0.5009, Val: 0.5029, Test: 0.4924
    Epoch: 04, Loss: 0.4751, Val: 0.4801, Test: 0.4786
    Epoch: 05, Loss: 0.4363, Val: 0.4438, Test: 0.4352
    Epoch: 06, Loss: 0.4276, Val: 0.4931, Test: 0.4994
    Epoch: 07, Loss: 0.3956, Val: 0.3502, Test: 0.3439
    Epoch: 08, Loss: 0.4021, Val: 0.3143, Test: 0.3296
    Epoch: 09, Loss: 0.3761, Val: 0.4012, Test: 0.3858
    Epoch: 10, Loss: 0.3739, Val: 0.3343, Test: 0.3032
    Epoch: 11, Loss: 0.3532, Val: 0.3679, Test: 0.3334
    Epoch: 12, Loss: 0.3683, Val: 0.3094, Test: 0.2754
    Epoch: 13, Loss: 0.3457, Val: 0.4007, Test: 0.4023
    Epoch: 14, Loss: 0.3460, Val: 0.3986, Test: 0.3589
    Epoch: 15, Loss: 0.3369, Val: 0.3478, Test: 0.3124
    Epoch: 16, Loss: 0.3222, Val: 0.3043, Test: 0.2651
    Epoch: 17, Loss: 0.3190, Val: 0.4496, Test: 0.4070
    Epoch: 18, Loss: 0.3317, Val: 0.3803, Test: 0.3450
    Epoch: 19, Loss: 0.3179, Val: 0.2671, Test: 0.2408
    Epoch: 20, Loss: 0.3143, Val: 0.4168, Test: 0.3901
    Epoch: 21, Loss: 0.3238, Val: 0.3183, Test: 0.2926
    Epoch: 22, Loss: 0.3132, Val: 0.9534, Test: 1.0879
    Epoch: 23, Loss: 0.3088, Val: 0.3705, Test: 0.3360
    Epoch: 24, Loss: 0.3032, Val: 0.3051, Test: 0.2692
    Epoch: 25, Loss: 0.2968, Val: 0.2829, Test: 0.2571
    Epoch: 26, Loss: 0.2915, Val: 0.3145, Test: 0.2820
    Epoch: 27, Loss: 0.2871, Val: 0.3127, Test: 0.2965
    Epoch: 28, Loss: 0.2953, Val: 0.4415, Test: 0.4144
    Epoch: 29, Loss: 0.2916, Val: 0.3118, Test: 0.2733
    Epoch: 30, Loss: 0.3074, Val: 0.4497, Test: 0.4418