Advanced Mini-Batching
======================
The creation of mini-batches is crucial for letting the training of a deep learning model scale to huge amounts of data.
Instead of processing examples one-by-one, a mini-batch groups a set of examples into a unified representation that can be processed efficiently in parallel.
In the image or language domain, this procedure is typically achieved by rescaling or padding each example to an equally-sized shape, and examples are then grouped in an additional dimension.
The length of this dimension is then equal to the number of examples grouped in a mini-batch and is typically referred to as the :obj:`batch_size`.
Since graphs are one of the most general data structures that can hold *any* number of nodes or edges, the two approaches described above are either not feasible or may result in a lot of unnecessary memory consumption.
In :pyg:`PyG`, we opt for another approach to achieve parallelization across a number of examples.
Here, adjacency matrices are stacked in a diagonal fashion (creating a giant graph that holds multiple isolated subgraphs), and node and target features are simply concatenated in the node dimension, *i.e.*
.. math::

    \mathbf{A} = \begin{bmatrix} \mathbf{A}_1 & & \\ & \ddots & \\ & & \mathbf{A}_n \end{bmatrix}, \qquad \mathbf{X} = \begin{bmatrix} \mathbf{X}_1 \\ \vdots \\ \mathbf{X}_n \end{bmatrix}, \qquad \mathbf{Y} = \begin{bmatrix} \mathbf{Y}_1 \\ \vdots \\ \mathbf{Y}_n \end{bmatrix}.
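To make this diagonal stacking concrete, the following is a small, purely illustrative sketch that stacks two tiny, hypothetical graphs by hand, using dense adjacency matrices for exposition only (:pyg:`PyG` never materializes these dense matrices):

.. code-block:: python

    import torch

    # Two tiny example graphs, given as dense adjacency matrices and node features:
    A_1 = torch.tensor([[0., 1.], [1., 0.]])                        # 2 nodes.
    A_2 = torch.tensor([[0., 1., 0.], [1., 0., 1.], [0., 1., 0.]])  # 3 nodes.
    X_1, X_2 = torch.randn(2, 16), torch.randn(3, 16)

    A = torch.block_diag(A_1, A_2)    # Block-diagonal adjacency of shape [5, 5].
    X = torch.cat([X_1, X_2], dim=0)  # Concatenated node features of shape [5, 16].
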
This procedure has some crucial advantages over other batching procedures:
1. GNN operators that rely on a message passing scheme do not need to be modified since messages are not exchanged between two nodes that belong to different graphs.
2. There is no computational or memory overhead.
   For example, this batching procedure works completely without any padding of node or edge features.
   Note that there is no additional memory overhead for adjacency matrices since they are saved in a sparse fashion holding only non-zero entries, *i.e.*, the edges.
:pyg:`PyG` automatically takes care of batching multiple graphs into a single giant graph with the help of the :class:`torch_geometric.loader.DataLoader` class.
Internally, :class:`~torch_geometric.loader.DataLoader` is just a regular :pytorch:`PyTorch` :class:`torch.utils.data.DataLoader` that overwrites its :func:`collate` functionality, *i.e.*, the definition of how a list of examples should be grouped together.
Therefore, all arguments that can be passed to a :pytorch:`PyTorch` :class:`~torch.utils.data.DataLoader` can also be passed to a :pyg:`PyG` :class:`~torch_geometric.loader.DataLoader`, *e.g.*, the number of workers :obj:`num_workers`.
In its most general form, the :pyg:`PyG` :class:`~torch_geometric.loader.DataLoader` will automatically increment the :obj:`edge_index` tensor by the cumulative number of nodes of all graphs that got collated before the currently processed graph, and will concatenate :obj:`edge_index` tensors (that are of shape :obj:`[2, num_edges]`) in the second dimension.
The same is true for :obj:`face` tensors, *i.e.*, face indices in meshes.
All other tensors will just get concatenated in the first dimension without any further increment of their values.
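As a minimal sketch of this default behavior, consider collating two copies of a hypothetical three-node chain graph:

.. code-block:: python

    import torch
    from torch_geometric.data import Data
    from torch_geometric.loader import DataLoader

    edge_index = torch.tensor([
        [0, 1],
        [1, 2],
    ])
    data = Data(x=torch.randn(3, 16), edge_index=edge_index)

    loader = DataLoader([data, data], batch_size=2)
    batch = next(iter(loader))

    # The second copy is shifted by the 3 nodes collated before it:
    print(batch.edge_index)
    >>> tensor([[0, 1, 3, 4],
                [1, 2, 4, 5]])

    # Node features are simply concatenated in the first dimension:
    print(batch.x.size())
    >>> torch.Size([6, 16])
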
However, there are a few special use-cases (as outlined below) where the user actively wants to modify this behavior to their own needs.
:pyg:`PyG` allows modification of the underlying batching procedure by overwriting the :meth:`torch_geometric.data.Data.__inc__` and :meth:`torch_geometric.data.Data.__cat_dim__` functionalities.
Without any modifications, these are defined as follows in the :class:`~torch_geometric.data.Data` class:
.. code-block:: python

    def __inc__(self, key, value, *args, **kwargs):
        if 'index' in key:
            return self.num_nodes
        else:
            return 0

    def __cat_dim__(self, key, value, *args, **kwargs):
        if 'index' in key:
            return 1
        else:
            return 0
We can see that :meth:`~torch_geometric.data.Data.__inc__` defines the incremental count between two consecutive graph attributes.
By default, :pyg:`PyG` increments attributes by the number of nodes whenever their attribute names contain the substring :obj:`index` (for historical reasons), which comes in handy for attributes such as :obj:`edge_index` or :obj:`node_index`.
However, note that this may lead to unexpected behavior for attributes whose names contain the substring :obj:`index` but should not be incremented.
To be safe, it is best practice to always double-check the output of batching.
Furthermore, :meth:`~torch_geometric.data.Data.__cat_dim__` defines in which dimension graph tensors of the same attribute should be concatenated together.
Both functions are called for each attribute stored in the :class:`~torch_geometric.data.Data` class, and get passed their specific :obj:`key` and :obj:`value` as arguments.
In what follows, we present a few use-cases where the modification of :meth:`~torch_geometric.data.Data.__inc__` and :meth:`~torch_geometric.data.Data.__cat_dim__` might be absolutely necessary.
Pairs of Graphs
---------------
In case you want to store multiple graphs in a single :class:`~torch_geometric.data.Data` object, *e.g.*, for applications such as graph matching, you need to ensure correct batching behavior across all those graphs.
For example, consider storing two graphs, a source graph :math:`\mathcal{G}_s` and a target graph :math:`\mathcal{G}_t` in a :class:`~torch_geometric.data.Data`, *e.g.*:
.. code-block:: python

    from torch_geometric.data import Data


    class PairData(Data):
        pass


    data = PairData(x_s=x_s, edge_index_s=edge_index_s,  # Source graph.
                    x_t=x_t, edge_index_t=edge_index_t)  # Target graph.
In this case, :obj:`edge_index_s` should be increased by the number of nodes in the source graph :math:`\mathcal{G}_s`, *e.g.*, :obj:`x_s.size(0)`, and :obj:`edge_index_t` should be increased by the number of nodes in the target graph :math:`\mathcal{G}_t`, *e.g.*, :obj:`x_t.size(0)`:
.. code-block:: python

    class PairData(Data):
        def __inc__(self, key, value, *args, **kwargs):
            if key == 'edge_index_s':
                return self.x_s.size(0)
            if key == 'edge_index_t':
                return self.x_t.size(0)
            return super().__inc__(key, value, *args, **kwargs)
We can test our :class:`PairData` batching behavior by setting up a simple test script:
.. code-block:: python

    import torch
    from torch_geometric.loader import DataLoader

    x_s = torch.randn(5, 16)  # 5 nodes.
    edge_index_s = torch.tensor([
        [0, 0, 0, 0],
        [1, 2, 3, 4],
    ])

    x_t = torch.randn(4, 16)  # 4 nodes.
    edge_index_t = torch.tensor([
        [0, 0, 0],
        [1, 2, 3],
    ])

    data = PairData(x_s=x_s, edge_index_s=edge_index_s,
                    x_t=x_t, edge_index_t=edge_index_t)

    data_list = [data, data]
    loader = DataLoader(data_list, batch_size=2)
    batch = next(iter(loader))

    print(batch)
    >>> PairDataBatch(x_s=[10, 16], edge_index_s=[2, 8],
                      x_t=[8, 16], edge_index_t=[2, 6])

    print(batch.edge_index_s)
    >>> tensor([[0, 0, 0, 0, 5, 5, 5, 5],
                [1, 2, 3, 4, 6, 7, 8, 9]])

    print(batch.edge_index_t)
    >>> tensor([[0, 0, 0, 4, 4, 4],
                [1, 2, 3, 5, 6, 7]])
Everything looks good so far!
:obj:`edge_index_s` and :obj:`edge_index_t` get correctly batched together, even when using different numbers of nodes for :math:`\mathcal{G}_s` and :math:`\mathcal{G}_t`.
However, the :obj:`batch` attribute (that maps each node to its respective graph) is missing since :pyg:`PyG` fails to identify the actual graph in the :class:`PairData` object.
That is where the :obj:`follow_batch` argument of the :class:`~torch_geometric.loader.DataLoader` comes into play.
Here, we can specify for which attributes we want to maintain the batch information:
.. code-block:: python

    loader = DataLoader(data_list, batch_size=2, follow_batch=['x_s', 'x_t'])
    batch = next(iter(loader))

    print(batch)
    >>> PairDataBatch(x_s=[10, 16], edge_index_s=[2, 8], x_s_batch=[10],
                      x_t=[8, 16], edge_index_t=[2, 6], x_t_batch=[8])

    print(batch.x_s_batch)
    >>> tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

    print(batch.x_t_batch)
    >>> tensor([0, 0, 0, 0, 1, 1, 1, 1])
As one can see, :obj:`follow_batch=['x_s', 'x_t']` now successfully creates assignment vectors :obj:`x_s_batch` and :obj:`x_t_batch` for the node features :obj:`x_s` and :obj:`x_t`, respectively.
That information can now be used to perform reduce operations, *e.g.*, global pooling, on multiple graphs in a single :class:`Batch` object.
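For example, a sketch of such a reduce operation on the batch obtained above might look as follows, using :obj:`global_mean_pool` to average node features per graph:

.. code-block:: python

    from torch_geometric.nn import global_mean_pool

    # Pool source and target node features independently for each graph:
    out_s = global_mean_pool(batch.x_s, batch.x_s_batch)  # Shape: [2, 16].
    out_t = global_mean_pool(batch.x_t, batch.x_t_batch)  # Shape: [2, 16].
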
Bipartite Graphs
----------------
The adjacency matrix of a bipartite graph defines the relationship between nodes of two different node types.
In general, the number of nodes for each node type does not need to match, resulting in a potentially non-square adjacency matrix :math:`\mathbf{A} \in \{ 0, 1 \}^{N \times M}` with :math:`N \neq M`.
In a mini-batching procedure of bipartite graphs, the source nodes of edges in :obj:`edge_index` should be incremented differently than the target nodes of edges in :obj:`edge_index`.
To achieve this, consider a bipartite graph between two node types with corresponding node features :obj:`x_s` and :obj:`x_t`, respectively:
.. code-block:: python

    from torch_geometric.data import Data


    class BipartiteData(Data):
        pass


    data = BipartiteData(x_s=x_s, x_t=x_t, edge_index=edge_index)
For a correct mini-batching procedure in bipartite graphs, we need to tell :pyg:`PyG` that it should increment source and target nodes of edges in :obj:`edge_index` independently:
.. code-block:: python

    class BipartiteData(Data):
        def __inc__(self, key, value, *args, **kwargs):
            if key == 'edge_index':
                return torch.tensor([[self.x_s.size(0)], [self.x_t.size(0)]])
            return super().__inc__(key, value, *args, **kwargs)
Here, :obj:`edge_index[0]` (the source nodes of edges) gets incremented by :obj:`x_s.size(0)` while :obj:`edge_index[1]` (the target nodes of edges) gets incremented by :obj:`x_t.size(0)`.
We can again test our implementation by running a simple test script:
.. code-block:: python

    import torch
    from torch_geometric.loader import DataLoader

    x_s = torch.randn(2, 16)  # 2 nodes.
    x_t = torch.randn(3, 16)  # 3 nodes.
    edge_index = torch.tensor([
        [0, 0, 1, 1],
        [0, 1, 1, 2],
    ])

    data = BipartiteData(x_s=x_s, x_t=x_t, edge_index=edge_index)

    data_list = [data, data]
    loader = DataLoader(data_list, batch_size=2)
    batch = next(iter(loader))

    print(batch)
    >>> BipartiteDataBatch(x_s=[4, 16], x_t=[6, 16], edge_index=[2, 8])

    print(batch.edge_index)
    >>> tensor([[0, 0, 1, 1, 2, 2, 3, 3],
                [0, 1, 1, 2, 3, 4, 4, 5]])
Again, this is exactly the behavior we aimed for!
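As a sketch of how such a batched bipartite graph might then be consumed, bipartite-capable operators such as :class:`~torch_geometric.nn.SAGEConv` accept a tuple of source and target node features (the channel sizes below are chosen to match the toy example above):

.. code-block:: python

    from torch_geometric.nn import SAGEConv

    conv = SAGEConv((16, 16), out_channels=32)

    # Message passing from source to target nodes; returns target node embeddings:
    out = conv((batch.x_s, batch.x_t), batch.edge_index)  # Shape: [6, 32].
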
Batching Along New Dimensions
-----------------------------
Sometimes, attributes of :obj:`data` objects should be batched by gaining a new batch dimension (as in classical mini-batching), *e.g.*, for graph-level properties or targets.
Specifically, a list of attributes of shape :obj:`[num_features]` should be returned as :obj:`[num_examples, num_features]` rather than :obj:`[num_examples * num_features]`.
:pyg:`PyG` achieves this by returning a concatenation dimension of :obj:`None` in :meth:`~torch_geometric.data.Data.__cat_dim__`:
.. code-block:: python

    import torch
    from torch_geometric.data import Data
    from torch_geometric.loader import DataLoader


    class MyData(Data):
        def __cat_dim__(self, key, value, *args, **kwargs):
            if key == 'foo':
                return None
            return super().__cat_dim__(key, value, *args, **kwargs)


    edge_index = torch.tensor([
        [0, 1, 1, 2],
        [1, 0, 2, 1],
    ])
    foo = torch.randn(16)

    data = MyData(num_nodes=3, edge_index=edge_index, foo=foo)

    data_list = [data, data]
    loader = DataLoader(data_list, batch_size=2)
    batch = next(iter(loader))

    print(batch)
    >>> MyDataBatch(num_nodes=6, edge_index=[2, 8], foo=[2, 16])
As desired, :obj:`batch.foo` is now described by two dimensions: the batch dimension and the feature dimension.
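Such a graph-level attribute can then be used directly as a per-graph target, *e.g.*, in a loss computation (a purely illustrative sketch in which :obj:`pred` stands in for a model output):

.. code-block:: python

    import torch.nn.functional as F

    pred = torch.randn(batch.num_graphs, 16)  # Hypothetical model output.
    loss = F.mse_loss(pred, batch.foo)
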