import inspect
from collections.abc import Sequence
from typing import Any, List, Optional, Type, Union
import numpy as np
import torch
from torch import Tensor
from typing_extensions import Self
from torch_geometric.data.collate import collate
from torch_geometric.data.data import BaseData, Data
from torch_geometric.data.dataset import IndexType
from torch_geometric.data.separate import separate
class DynamicInheritance(type):
    # A meta class that sets the base class of a `Batch` object, e.g.:
    # * `Batch(Data)` in case `Data` objects are batched together
    # * `Batch(HeteroData)` in case `HeteroData` objects are batched together
    def __call__(cls, *args: Any, **kwargs: Any) -> Any:
        # `_base_cls` is injected by `DynamicInheritanceGetter` (used when
        # unpickling, see `Batch.__reduce__`); plain `Batch(...)` construction
        # defaults to `Data`:
        base_cls = kwargs.pop('_base_cls', Data)
        if issubclass(base_cls, Batch):
            # Already a dynamically-created `*Batch` class - reuse it as-is:
            new_cls = base_cls
        else:
            # e.g., `Data` + `Batch` -> 'DataBatch':
            name = f'{base_cls.__name__}{cls.__name__}'
            # NOTE `MetaResolver` is necessary to resolve metaclass conflict
            # problems between `DynamicInheritance` and the metaclass of
            # `base_cls`. In particular, it creates a new common metaclass
            # from the defined metaclasses.
            class MetaResolver(type(cls), type(base_cls)): # type: ignore
                pass
            # Cache the dynamically-created class in `globals()` so repeated
            # calls reuse the same class object and so it can be looked up by
            # name (e.g., during unpickling):
            if name not in globals():
                globals()[name] = MetaResolver(name, (cls, base_cls), {})
            new_cls = globals()[name]
        # Fill every *required* parameter of `base_cls.__init__` that the
        # caller did not provide with `None`, so instantiation never fails on
        # missing arguments. `params[1:]` skips `self`:
        params = list(inspect.signature(base_cls.__init__).parameters.items())
        for i, (k, v) in enumerate(params[1:]):
            if k == 'args' or k == 'kwargs':
                continue
            # Skip parameters already supplied positionally or by keyword:
            if i < len(args) or k in kwargs:
                continue
            # Skip parameters that carry a default value:
            if v.default is not inspect.Parameter.empty:
                continue
            kwargs[k] = None
        return super(DynamicInheritance, new_cls).__call__(*args, **kwargs)
class DynamicInheritanceGetter:
def __call__(self, cls: Type, base_cls: Type) -> Self:
return cls(_base_cls=base_cls)
class Batch(metaclass=DynamicInheritance):
    r"""A data object describing a batch of graphs as one big (disconnected)
    graph.

    Inherits from :class:`torch_geometric.data.Data` or
    :class:`torch_geometric.data.HeteroData`.
    In addition, single graphs can be identified via the assignment vector
    :obj:`batch`, which maps each node to its respective graph identifier.

    :pyg:`PyG` allows modification to the underlying batching procedure by
    overwriting the :meth:`~Data.__inc__` and :meth:`~Data.__cat_dim__`
    functionalities.
    The :meth:`~Data.__inc__` method defines the incremental count between two
    consecutive graph attributes.
    By default, :pyg:`PyG` increments attributes by the number of nodes
    whenever their attribute names contain the substring :obj:`index`
    (for historical reasons), which comes in handy for attributes such as
    :obj:`edge_index` or :obj:`node_index`.
    However, note that this may lead to unexpected behavior for attributes
    whose names contain the substring :obj:`index` but should not be
    incremented.
    To make sure, it is best practice to always double-check the output of
    batching.
    Furthermore, :meth:`~Data.__cat_dim__` defines in which dimension graph
    tensors of the same attribute should be concatenated together.
    """
    @classmethod
    def from_data_list(
        cls,
        data_list: List[BaseData],
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None,
    ) -> Self:
        r"""Constructs a :class:`~torch_geometric.data.Batch` object from a
        list of :class:`~torch_geometric.data.Data` or
        :class:`~torch_geometric.data.HeteroData` objects.
        The assignment vector :obj:`batch` is created on the fly.
        In addition, creates assignment vectors for each key in
        :obj:`follow_batch`.
        Will exclude any keys given in :obj:`exclude_keys`.

        Raises:
            ValueError: If :obj:`data_list` is empty.
        """
        # Fail early with a clear message instead of an opaque `IndexError`
        # on `data_list[0]` below:
        if len(data_list) == 0:
            raise ValueError(
                "Cannot create a 'Batch' object from an empty list of "
                "'Data' or 'HeteroData' objects")

        batch, slice_dict, inc_dict = collate(
            cls,
            data_list=data_list,
            increment=True,
            # Pre-batched inputs already carry a `batch` vector; do not add
            # another one:
            add_batch=not isinstance(data_list[0], Batch),
            follow_batch=follow_batch,
            exclude_keys=exclude_keys,
        )

        # Book-keeping needed by `get_example` to reconstruct single graphs:
        batch._num_graphs = len(data_list)  # type: ignore
        batch._slice_dict = slice_dict  # type: ignore
        batch._inc_dict = inc_dict  # type: ignore

        return batch

    def get_example(self, idx: int) -> BaseData:
        r"""Gets the :class:`~torch_geometric.data.Data` or
        :class:`~torch_geometric.data.HeteroData` object at index :obj:`idx`.
        The :class:`~torch_geometric.data.Batch` object must have been created
        via :meth:`from_data_list` in order to be able to reconstruct the
        initial object.
        """
        if not hasattr(self, '_slice_dict'):
            raise RuntimeError(
                "Cannot reconstruct 'Data' object from 'Batch' because "
                "'Batch' was not created via 'Batch.from_data_list()'")

        # `__bases__[-1]` is the original `Data`/`HeteroData` class that
        # `DynamicInheritance` appended when creating this `Batch` subclass:
        data = separate(
            cls=self.__class__.__bases__[-1],
            batch=self,
            idx=idx,
            slice_dict=self._slice_dict,
            inc_dict=self._inc_dict,
            decrement=True,
        )

        return data

    def index_select(self, idx: IndexType) -> List[BaseData]:
        r"""Creates a subset of :class:`~torch_geometric.data.Data` or
        :class:`~torch_geometric.data.HeteroData` objects from specified
        indices :obj:`idx`.
        Indices :obj:`idx` can be a slicing object, *e.g.*, :obj:`[2:5]`, a
        list, a tuple, or a :obj:`torch.Tensor` or :obj:`np.ndarray` of any
        integer type or of type bool.
        The :class:`~torch_geometric.data.Batch` object must have been created
        via :meth:`from_data_list` in order to be able to reconstruct the
        initial objects.
        """
        index: Sequence[int]
        if isinstance(idx, slice):
            index = list(range(self.num_graphs)[idx])
        elif isinstance(idx, Tensor) and idx.dtype == torch.bool:
            # Boolean mask -> positions of `True` entries:
            index = idx.flatten().nonzero(as_tuple=False).flatten().tolist()
        elif isinstance(idx, Tensor) and idx.dtype in (
                torch.int16, torch.int32, torch.int64):
            # Accept any signed integer dtype, not only `torch.long`:
            index = idx.flatten().tolist()
        elif isinstance(idx, np.ndarray) and idx.dtype == bool:
            # Boolean mask -> positions of `True` entries:
            index = idx.flatten().nonzero()[0].flatten().tolist()
        elif isinstance(idx, np.ndarray) and np.issubdtype(
                idx.dtype, np.integer):
            # Accept any integer dtype. Checking for `np.int64` only would
            # reject `np.arange(...)` on platforms whose default integer is
            # `int32` (e.g., Windows):
            index = idx.flatten().tolist()
        elif isinstance(idx, Sequence) and not isinstance(idx, str):
            index = idx
        else:
            raise IndexError(
                f"Only slices (':'), list, tuples, torch.tensor and "
                f"np.ndarray of dtype integer or bool are valid indices (got "
                f"'{type(idx).__name__}')")

        return [self.get_example(i) for i in index]

    def __getitem__(self, idx: Union[int, np.integer, str, IndexType]) -> Any:
        # Scalar index -> a single reconstructed graph:
        if (isinstance(idx, (int, np.integer))
                or (isinstance(idx, Tensor) and idx.dim() == 0)
                or (isinstance(idx, np.ndarray) and np.isscalar(idx))):
            return self.get_example(idx)  # type: ignore
        elif isinstance(idx, str) or (isinstance(idx, tuple)
                                      and isinstance(idx[0], str)):
            # Accessing attributes or node/edge types:
            return super().__getitem__(idx)  # type: ignore
        else:
            # Everything else (slices, masks, index collections) selects a
            # list of graphs:
            return self.index_select(idx)

    def to_data_list(self) -> List[BaseData]:
        r"""Reconstructs the list of :class:`~torch_geometric.data.Data` or
        :class:`~torch_geometric.data.HeteroData` objects from the
        :class:`~torch_geometric.data.Batch` object.
        The :class:`~torch_geometric.data.Batch` object must have been created
        via :meth:`from_data_list` in order to be able to reconstruct the
        initial objects.
        """
        return [self.get_example(i) for i in range(self.num_graphs)]

    @property
    def num_graphs(self) -> int:
        """Returns the number of graphs in the batch."""
        if hasattr(self, '_num_graphs'):
            # Set by `from_data_list` - the authoritative count:
            return self._num_graphs
        elif hasattr(self, 'ptr'):
            # `ptr` holds cumulative node counts with a leading zero:
            return self.ptr.numel() - 1
        elif hasattr(self, 'batch'):
            # Fall back to the maximum graph identifier in the assignment
            # vector:
            return int(self.batch.max()) + 1
        else:
            raise ValueError("Can not infer the number of graphs")

    @property
    def batch_size(self) -> int:
        r"""Alias for :obj:`num_graphs`."""
        return self.num_graphs

    def __len__(self) -> int:
        return self.num_graphs

    def __reduce__(self) -> Any:
        # Pickle support: dynamically-created classes cannot be pickled by
        # name, so re-create the class on load via `DynamicInheritanceGetter`
        # and restore the instance `__dict__` as state.
        state = self.__dict__.copy()
        return DynamicInheritanceGetter(), self.__class__.__bases__, state