Memory-Efficient Aggregations
=============================

The :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` interface of :pyg:`PyG` relies on a gather-scatter scheme to aggregate messages from neighboring nodes.
For example, consider the message passing layer

.. math::

    \mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i)} \textrm{MLP}(\mathbf{x}_j - \mathbf{x}_i),

that can be implemented as:

.. code-block:: python

    from torch_geometric.nn import MessagePassing

    x = ...           # Node features of shape [num_nodes, num_features]
    edge_index = ...  # Edge indices of shape [2, num_edges]

    class MyConv(MessagePassing):
        def __init__(self):
            super().__init__(aggr="add")

        def forward(self, x, edge_index):
            return self.propagate(edge_index, x=x)

        def message(self, x_i, x_j):
            return MLP(x_j - x_i)

Under the hood, the :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` implementation produces code that looks as follows:

.. code-block:: python

    from torch_geometric.utils import scatter

    x = ...           # Node features of shape [num_nodes, num_features]
    edge_index = ...  # Edge indices of shape [2, num_edges]

    x_j = x[edge_index[0]]  # Source node features [num_edges, num_features]
    x_i = x[edge_index[1]]  # Target node features [num_edges, num_features]

    msg = MLP(x_j - x_i)  # Compute message for each edge

    # Aggregate messages based on target node indices
    out = scatter(msg, edge_index[1], dim=0, dim_size=x.size(0), reduce='sum')

While the gather-scatter formulation generalizes to a wide range of useful GNN implementations, it has the disadvantage of explicitly materializing :obj:`x_j` and :obj:`x_i`, resulting in a high memory footprint on large and dense graphs.
For example, for a graph with 10 million edges and 256-dimensional node features, :obj:`x_j` alone occupies roughly 10 GB in single precision.

Luckily, not all GNNs need to be implemented by explicitly materializing :obj:`x_j` and/or :obj:`x_i`.
In some cases, GNNs can instead be implemented as a simple sparse-matrix multiplication.
As a general rule of thumb, this holds for GNNs that do not make use of the central node features :obj:`x_i` or multi-dimensional edge features when computing messages.
For example, the :class:`~torch_geometric.nn.conv.GINConv` layer

.. math::

    \mathbf{x}^{\prime}_i = \textrm{MLP} \left( (1 + \epsilon) \cdot \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \right),

is equivalent to computing

.. math::

    \mathbf{X}^{\prime} = \textrm{MLP} \left( (1 + \epsilon) \cdot \mathbf{X} + \mathbf{A}\mathbf{X} \right),

where :math:`\mathbf{A}` denotes a sparse adjacency matrix of shape :obj:`[num_nodes, num_nodes]`.
This formulation allows us to leverage dedicated and fast sparse-matrix multiplication implementations.
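
To make this equivalence concrete, here is a minimal sketch on a small random graph, using :func:`~torch_geometric.utils.to_dense_adj` purely for illustration (in practice, the product is carried out in sparse form):

.. code-block:: python

    import torch
    from torch_geometric.utils import scatter, to_dense_adj

    num_nodes, num_features = 4, 8
    x = torch.randn(num_nodes, num_features)
    edge_index = torch.tensor([[0, 1, 1, 2, 3],
                               [1, 0, 2, 3, 2]])

    # Neighborhood sum via gather-scatter:
    out1 = scatter(x[edge_index[0]], edge_index[1], dim=0,
                   dim_size=num_nodes, reduce='sum')

    # The same sum as a matrix product (note the transpose: A[src, dst]
    # follows the edge direction, while aggregation runs over incoming edges):
    A = to_dense_adj(edge_index, max_num_nodes=num_nodes)[0]
    out2 = A.t() @ x

    assert torch.allclose(out1, out2)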

In :pyg:`null` **PyG >= 1.6.0**, we officially introduce better support for sparse-matrix multiplication GNNs, resulting in a **lower memory footprint** and a **faster execution time**.
To this end, we provide the :class:`SparseTensor` class (defined in the :obj:`torch_sparse` package), which implements fast forward and backward passes for sparse-matrix multiplication based on the `"Design Principles for Sparse Matrix Multiplication on the GPU" <https://arxiv.org/abs/1803.08601>`_ paper.

Using the :class:`SparseTensor` class is straightforward and similar to the way :obj:`scipy` treats sparse matrices:

.. code-block:: python

    from torch_sparse import SparseTensor

    adj = SparseTensor(row=edge_index[0], col=edge_index[1], value=...,
                       sparse_sizes=(num_nodes, num_nodes))
    # value is optional and can be None

    # Obtain different representations (COO, CSR, CSC):
    row,    col, value = adj.coo()
    rowptr, col, value = adj.csr()
    colptr, row, value = adj.csc()

    adj = adj[:100, :100]  # Slicing, indexing and masking support
    adj = adj.set_diag()   # Add diagonal entries
    adj_t = adj.t()        # Transpose
    out = adj.matmul(x)    # Sparse-dense matrix multiplication
    adj = adj.matmul(adj)  # Sparse-sparse matrix multiplication

    # Creating SparseTensor instances:
    adj = SparseTensor.from_dense(mat)
    adj = SparseTensor.eye(100, 100)
    adj = SparseTensor.from_scipy(mat)

Our :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` interface can handle both :obj:`torch.Tensor` and :class:`SparseTensor` as input for propagating messages.
However, when holding a directed graph in a :class:`SparseTensor`, you need to make sure to pass the **transposed sparse matrix** to :func:`~torch_geometric.nn.conv.message_passing.MessagePassing.propagate`:

.. code-block:: python

    conv = GCNConv(16, 32)
    out1 = conv(x, edge_index)  # Gather-scatter based propagation
    out2 = conv(x, adj.t())     # Sparse-matrix multiplication based propagation
    assert torch.allclose(out1, out2)

    conv = GINConv(nn=Sequential(Linear(16, 32), ReLU(), Linear(32, 32)))
    out1 = conv(x, edge_index)
    out2 = conv(x, adj.t())
    assert torch.allclose(out1, out2)

To leverage sparse-matrix multiplications, the :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` interface introduces the :func:`~torch_geometric.nn.conv.message_passing.message_and_aggregate` function, which fuses the :func:`~torch_geometric.nn.conv.message_passing.message` and :func:`~torch_geometric.nn.conv.message_passing.aggregate` functions into a single computation step.
It is called whenever it is implemented and a :class:`SparseTensor` is passed as the :obj:`edge_index` argument.
With it, the :class:`~torch_geometric.nn.conv.GINConv` layer can now be implemented as follows:

.. code-block:: python

    import torch_sparse

    class GINConv(MessagePassing):
        def __init__(self):
            super().__init__(aggr="add")

        def forward(self, x, edge_index):
            out = self.propagate(edge_index, x=x)
            return MLP((1 + eps) * x + out)

        def message(self, x_j):
            return x_j

        def message_and_aggregate(self, adj_t, x):
            return torch_sparse.matmul(adj_t, x, reduce=self.aggr)
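
As a usage sketch (assuming :obj:`x`, :obj:`edge_index`, :obj:`num_nodes`, :obj:`MLP` and :obj:`eps` are defined as in the snippets above), passing a :class:`SparseTensor` dispatches to the fused :func:`~torch_geometric.nn.conv.message_passing.message_and_aggregate` path, while passing :obj:`edge_index` falls back to the gather-scatter path:

.. code-block:: python

    from torch_sparse import SparseTensor

    conv = GINConv()
    adj = SparseTensor(row=edge_index[0], col=edge_index[1],
                       sparse_sizes=(num_nodes, num_nodes))

    out1 = conv(x, edge_index)  # Uses message() + scatter-based aggregation
    out2 = conv(x, adj.t())     # Uses the fused message_and_aggregate()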

Playing around with the new :class:`SparseTensor` format is straightforward since all of our GNNs work with it out-of-the-box.
To convert the :obj:`edge_index` format to the newly introduced :class:`SparseTensor` format, you can make use of the :class:`torch_geometric.transforms.ToSparseTensor` transform:

.. code-block:: python

    import torch
    import torch.nn.functional as F

    from torch_geometric.nn import GCNConv
    import torch_geometric.transforms as T
    from torch_geometric.datasets import Planetoid

    dataset = Planetoid("Planetoid", name="Cora", transform=T.ToSparseTensor())
    data = dataset[0]
    >>> Data(adj_t=[2708, 2708, nnz=10556], x=[2708, 1433], y=[2708], ...)


    class GNN(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = GCNConv(dataset.num_features, 16, cached=True)
            self.conv2 = GCNConv(16, dataset.num_classes, cached=True)

        def forward(self, x, adj_t):
            x = self.conv1(x, adj_t)
            x = F.relu(x)
            x = self.conv2(x, adj_t)
            return F.log_softmax(x, dim=1)

    model = GNN()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    def train(data):
        model.train()
        optimizer.zero_grad()
        out = model(data.x, data.adj_t)
        loss = F.nll_loss(out, data.y)
        loss.backward()
        optimizer.step()
        return float(loss)

    for epoch in range(1, 201):
        loss = train(data)
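
For completeness, evaluation works unchanged with :obj:`adj_t` as well; here is a minimal sketch that relies on the :obj:`test_mask` shipped with the :class:`~torch_geometric.datasets.Planetoid` dataset:

.. code-block:: python

    @torch.no_grad()
    def test(data):
        model.eval()
        pred = model(data.x, data.adj_t).argmax(dim=-1)
        correct = int((pred[data.test_mask] == data.y[data.test_mask]).sum())
        return correct / int(data.test_mask.sum())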

All code remains the same as before, except for the :obj:`data` transform via :obj:`T.ToSparseTensor()`.
As an additional advantage, :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` implementations that utilize the :class:`SparseTensor` class are deterministic on the GPU since aggregations no longer rely on atomic operations.

Notably, the way a GNN layer is called changes slightly when it incorporates one-dimensional edge weights (:obj:`edge_weight`) or multi-dimensional edge features (:obj:`edge_attr`) into its message passing formulation.
In particular, these attributes are now expected to be passed directly as the values of the :class:`SparseTensor` object.
Instead of calling the GNN as

.. code-block:: python

    conv = GMMConv(16, 32, dim=3)
    out = conv(x, edge_index, edge_attr)

we now execute our GNN operator as

.. code-block:: python

    conv = GMMConv(16, 32, dim=3)
    adj = SparseTensor(row=edge_index[0], col=edge_index[1], value=edge_attr)
    out = conv(x, adj.t())

.. note::

    Since this feature is still experimental, some operations, *e.g.*, graph pooling methods, may still require you to input the :obj:`edge_index` format.
    You can convert :obj:`adj_t` back to :obj:`(edge_index, edge_attr)` via:

    .. code-block:: python

        row, col, edge_attr = adj_t.t().coo()
        edge_index = torch.stack([row, col], dim=0)

Please let us know what you think of :class:`SparseTensor`, how we can improve it, and whether you encounter any unexpected behavior.