Memory-Efficient Aggregations
=============================
The :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` interface of :pyg:`PyG` relies on a gather-scatter scheme to aggregate messages from neighboring nodes.
For example, consider the message passing layer

.. math::

    \mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i)} \textrm{MLP}(\mathbf{x}_j - \mathbf{x}_i),

that can be implemented as:

.. code-block:: python

    from torch_geometric.nn import MessagePassing

    x = ...           # Node features of shape [num_nodes, num_features]
    edge_index = ...  # Edge indices of shape [2, num_edges]

    class MyConv(MessagePassing):
        def __init__(self):
            super().__init__(aggr="add")

        def forward(self, x, edge_index):
            return self.propagate(edge_index, x=x)

        def message(self, x_i, x_j):
            return MLP(x_j - x_i)

Under the hood, the :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` implementation produces code that looks as follows:

.. code-block:: python

    from torch_geometric.utils import scatter

    x = ...           # Node features of shape [num_nodes, num_features]
    edge_index = ...  # Edge indices of shape [2, num_edges]

    x_j = x[edge_index[0]]  # Source node features [num_edges, num_features]
    x_i = x[edge_index[1]]  # Target node features [num_edges, num_features]

    msg = MLP(x_j - x_i)  # Compute message for each edge

    # Aggregate messages based on target node indices
    out = scatter(msg, edge_index[1], dim=0, dim_size=x.size(0), reduce='sum')

While the gather-scatter formulation generalizes to many useful GNN implementations, it has the disadvantage of explicitly materializing :obj:`x_j` and :obj:`x_i`, resulting in a high memory footprint on large and dense graphs.
Luckily, not all GNNs need to be implemented by explicitly materializing :obj:`x_j` and/or :obj:`x_i`.
In some cases, GNNs can also be implemented as a simple sparse matrix multiplication.
As a general rule of thumb, this holds true for GNNs that do not make use of the central node features :obj:`x_i` or multi-dimensional edge features when computing messages.
For example, the :class:`~torch_geometric.nn.conv.GINConv` layer

.. math::

    \mathbf{x}^{\prime}_i = \textrm{MLP} \left( (1 + \epsilon) \cdot \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \right),

is equivalent to computing

.. math::

    \mathbf{X}^{\prime} = \textrm{MLP} \left( (1 + \epsilon) \cdot \mathbf{X} + \mathbf{A}\mathbf{X} \right),

where :math:`\mathbf{A}` denotes a sparse adjacency matrix of shape :obj:`[num_nodes, num_nodes]`.
This formulation allows us to leverage dedicated and fast sparse-matrix multiplication implementations.
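
To see this equivalence concretely, consider the following minimal sketch (the toy graph and the use of plain :obj:`torch.sparse` tensors are illustrative assumptions), which checks that the scatter-based sum aggregation matches the sparse matrix product :math:`\mathbf{A}\mathbf{X}`:

.. code-block:: python

    import torch
    from torch_geometric.utils import scatter

    # Toy graph with 3 nodes and 4 directed edges (source -> target):
    edge_index = torch.tensor([[0, 1, 1, 2],
                               [1, 0, 2, 1]])
    x = torch.randn(3, 8)

    # Gather-scatter formulation: sum x_j over all edges pointing to node i.
    out1 = scatter(x[edge_index[0]], edge_index[1], dim=0, dim_size=3, reduce='sum')

    # Sparse formulation: A @ X with A[i, j] = 1 for every edge j -> i.
    A = torch.sparse_coo_tensor(edge_index.flip(0), torch.ones(4), size=(3, 3))
    out2 = torch.sparse.mm(A, x)

    assert torch.allclose(out1, out2)
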
In :pyg:`null` **PyG >= 1.6.0**, we officially introduce better support for sparse-matrix multiplication GNNs, resulting in a **lower memory footprint** and a **faster execution time**.
To this end, we introduce the :class:`SparseTensor` class (from the :obj:`torch_sparse` package), which implements fast forward and backward passes for sparse-matrix multiplication based on the `"Design Principles for Sparse Matrix Multiplication on the GPU" <https://arxiv.org/abs/1803.08601>`_ paper.
Using the :class:`SparseTensor` class is straightforward and similar to the way :obj:`scipy` treats sparse matrices:

.. code-block:: python

    from torch_sparse import SparseTensor

    adj = SparseTensor(row=edge_index[0], col=edge_index[1], value=...,
                       sparse_sizes=(num_nodes, num_nodes))
    # value is optional and can be None

    # Obtain different representations (COO, CSR, CSC):
    row, col, value = adj.coo()
    rowptr, col, value = adj.csr()
    colptr, row, value = adj.csc()

    adj = adj[:100, :100]  # Slicing, indexing and masking support
    adj = adj.set_diag()   # Add diagonal entries
    adj_t = adj.t()        # Transpose
    out = adj.matmul(x)    # Sparse-dense matrix multiplication
    adj = adj.matmul(adj)  # Sparse-sparse matrix multiplication

    # Creating SparseTensor instances:
    adj = SparseTensor.from_dense(mat)
    adj = SparseTensor.eye(100, 100)
    adj = SparseTensor.from_scipy(mat)

Our :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` interface can handle both :obj:`torch.Tensor` and :class:`SparseTensor` as input for propagating messages.
However, when holding a directed graph in :class:`SparseTensor`, you need to make sure to input the **transposed sparse matrix** to :func:`~torch_geometric.nn.conv.message_passing.MessagePassing.propagate`:

.. code-block:: python

    conv = GCNConv(16, 32)
    out1 = conv(x, edge_index)
    out2 = conv(x, adj.t())
    assert torch.allclose(out1, out2)

    conv = GINConv(nn=Sequential(Linear(16, 32), ReLU(), Linear(32, 32)))
    out1 = conv(x, edge_index)
    out2 = conv(x, adj.t())
    assert torch.allclose(out1, out2)

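For reference, the checks above assume a setup along the following lines (the concrete graph and shapes are illustrative assumptions); since :obj:`adj[j, i]` is non-zero for a directed edge :math:`j \rightarrow i`, passing the transposed matrix :obj:`adj.t()` lets row :math:`i` gather messages from all source nodes :math:`j` pointing to :math:`i`:

.. code-block:: python

    import torch
    from torch.nn import Linear, ReLU, Sequential
    from torch_sparse import SparseTensor
    from torch_geometric.nn import GCNConv, GINConv

    x = torch.randn(10, 16)
    edge_index = torch.tensor([[0, 1, 1, 2, 4, 7],
                               [1, 0, 2, 3, 5, 9]])

    # adj[j, i] is non-zero for each directed edge j -> i, so adj.t() aggregates
    # source node features into their target nodes.
    adj = SparseTensor(row=edge_index[0], col=edge_index[1],
                       sparse_sizes=(10, 10))
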
To leverage sparse-matrix multiplications, the :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` interface introduces the :func:`~torch_geometric.nn.conv.message_passing.message_and_aggregate` function, which fuses the :func:`~torch_geometric.nn.conv.message_passing.message` and :func:`~torch_geometric.nn.conv.message_passing.aggregate` calls into a single computation step.
It is invoked whenever it is implemented and a :class:`SparseTensor` is passed as :obj:`edge_index`.
With it, the :class:`~torch_geometric.nn.conv.GINConv` layer can now be implemented as follows:

.. code-block:: python

    import torch_sparse

    class GINConv(MessagePassing):
        def __init__(self):
            super().__init__(aggr="add")

        def forward(self, x, edge_index):
            out = self.propagate(edge_index, x=x)
            return MLP((1 + eps) * x + out)

        def message(self, x_j):
            return x_j

        def message_and_aggregate(self, adj_t, x):
            return torch_sparse.matmul(adj_t, x, reduce=self.aggr)

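As a quick sanity check, here is a minimal, self-contained sketch (the :obj:`SumConv` layer and the toy graph are illustrative assumptions, not part of the library) showing that the fused path and the regular gather-scatter path agree:

.. code-block:: python

    import torch
    import torch_sparse
    from torch_sparse import SparseTensor
    from torch_geometric.nn import MessagePassing

    class SumConv(MessagePassing):
        """Toy sum-aggregation layer that also implements the fused sparse path."""
        def __init__(self):
            super().__init__(aggr="add")

        def forward(self, x, edge_index):
            return self.propagate(edge_index, x=x)

        def message(self, x_j):
            return x_j

        def message_and_aggregate(self, adj_t, x):
            return torch_sparse.matmul(adj_t, x, reduce=self.aggr)

    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 2, 3, 0]])
    adj = SparseTensor(row=edge_index[0], col=edge_index[1], sparse_sizes=(4, 4))

    conv = SumConv()
    out1 = conv(x, edge_index)  # gather-scatter path: message() + aggregate()
    out2 = conv(x, adj.t())     # fused path: message_and_aggregate()
    assert torch.allclose(out1, out2, atol=1e-6)
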
Playing around with the new :class:`SparseTensor` format is straightforward since all of our GNNs work with it out-of-the-box.
To convert the :obj:`edge_index` format to the newly introduced :class:`SparseTensor` format, you can make use of the :class:`torch_geometric.transforms.ToSparseTensor` transform:

.. code-block:: python

    import torch
    import torch.nn.functional as F

    from torch_geometric.nn import GCNConv
    import torch_geometric.transforms as T
    from torch_geometric.datasets import Planetoid

    dataset = Planetoid("Planetoid", name="Cora", transform=T.ToSparseTensor())
    data = dataset[0]
    >>> Data(adj_t=[2708, 2708, nnz=10556], x=[2708, 1433], y=[2708], ...)

    class GNN(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = GCNConv(dataset.num_features, 16, cached=True)
            self.conv2 = GCNConv(16, dataset.num_classes, cached=True)

        def forward(self, x, adj_t):
            x = self.conv1(x, adj_t)
            x = F.relu(x)
            x = self.conv2(x, adj_t)
            return F.log_softmax(x, dim=1)

    model = GNN()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    def train(data):
        model.train()
        optimizer.zero_grad()
        out = model(data.x, data.adj_t)
        loss = F.nll_loss(out, data.y)
        loss.backward()
        optimizer.step()
        return float(loss)

    for epoch in range(1, 201):
        loss = train(data)

All code remains the same as before, except for the :obj:`data` transform via :obj:`T.ToSparseTensor()`.
As an additional advantage, :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` implementations that utilize the :class:`SparseTensor` class are deterministic on the GPU since aggregations no longer rely on atomic operations.
Notably, GNN layer execution changes slightly in case GNNs incorporate single- or multi-dimensional edge information, *i.e.*, :obj:`edge_weight` or :obj:`edge_attr`, respectively, into their message passing formulation.
In particular, it is now expected that these attributes are directly added as values to the :class:`SparseTensor` object.
Instead of calling the GNN as

.. code-block:: python

    conv = GMMConv(16, 32, dim=3)
    out = conv(x, edge_index, edge_attr)

we now execute our GNN operator as

.. code-block:: python

    conv = GMMConv(16, 32, dim=3)

    adj = SparseTensor(row=edge_index[0], col=edge_index[1], value=edge_attr)
    out = conv(x, adj.t())

.. note::

    Since this feature is still experimental, some operations, *e.g.*, graph pooling methods, may still require you to input the :obj:`edge_index` format.
    You can convert :obj:`adj_t` back to :obj:`(edge_index, edge_attr)` via:

    .. code-block:: python

        row, col, edge_attr = adj_t.t().coo()
        edge_index = torch.stack([row, col], dim=0)

Please let us know what you think of :class:`SparseTensor`, how we can improve it, and whether you encounter any unexpected behavior.