from typing import Optional, Union
import torch
from torch import Tensor
import torch_geometric.typing
from torch_geometric.data import Data, HeteroData
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform
from torch_geometric.typing import SparseTensor
from torch_geometric.utils import (
sort_edge_index,
to_torch_coo_tensor,
to_torch_csr_tensor,
)
@functional_transform('to_sparse_tensor')
class ToSparseTensor(BaseTransform):
    r"""Converts the :obj:`edge_index` attributes of a homogeneous or
    heterogeneous data object into a **transposed**
    :class:`torch_sparse.SparseTensor` or :pytorch:`PyTorch`
    :class:`torch.sparse.Tensor` object with key :obj:`adj_t`
    (functional name: :obj:`to_sparse_tensor`).

    .. note::

        In case of composing multiple transforms, it is best to convert the
        :obj:`data` object via :class:`ToSparseTensor` as late as possible,
        since there exist some transforms that are only able to operate on
        :obj:`data.edge_index` for now.

    Args:
        attr (str, optional): The name of the attribute to add as a value to
            the :class:`~torch_sparse.SparseTensor` or
            :class:`torch.sparse.Tensor` object (if present).
            (default: :obj:`edge_weight`)
        remove_edge_index (bool, optional): If set to :obj:`False`, the
            :obj:`edge_index` tensor will not be removed.
            (default: :obj:`True`)
        fill_cache (bool, optional): If set to :obj:`True`, will fill the
            underlying :class:`torch_sparse.SparseTensor` cache (if used).
            (default: :obj:`True`)
        layout (torch.layout, optional): Specifies the layout of the returned
            sparse tensor (:obj:`None`, :obj:`torch.sparse_coo` or
            :obj:`torch.sparse_csr`).
            If set to :obj:`None` and the :obj:`torch_sparse` dependency is
            installed, will convert :obj:`edge_index` into a
            :class:`torch_sparse.SparseTensor` object.
            If set to :obj:`None` and the :obj:`torch_sparse` dependency is
            not installed, will convert :obj:`edge_index` into a
            :class:`torch.sparse.Tensor` object with layout
            :obj:`torch.sparse_csr`. (default: :obj:`None`)
    """
    def __init__(
        self,
        attr: Optional[str] = 'edge_weight',
        remove_edge_index: bool = True,
        fill_cache: bool = True,
        # Fix: the valid values are `torch.layout` objects (`torch.sparse_coo`
        # / `torch.sparse_csr`) or `None`, so the annotation must not be `int`.
        layout: Optional[torch.layout] = None,
    ) -> None:
        if layout not in {None, torch.sparse_coo, torch.sparse_csr}:
            raise ValueError(f"Unexpected sparse tensor layout "
                             f"(got '{layout}')")
        self.attr = attr
        self.remove_edge_index = remove_edge_index
        self.fill_cache = fill_cache
        self.layout = layout

    def forward(
        self,
        data: Union[Data, HeteroData],
    ) -> Union[Data, HeteroData]:
        """Convert :obj:`edge_index` of every edge store into :obj:`adj_t`."""
        for store in data.edge_stores:
            if 'edge_index' not in store:
                continue

            # Collect every edge-level attribute (excluding labels and the
            # index itself) so that they get permuted alongside `edge_index`
            # when sorting below.
            keys, values = [], []
            for key, value in store.items():
                if key in {'edge_index', 'edge_label', 'edge_label_index'}:
                    continue

                if store.is_edge_attr(key):
                    keys.append(key)
                    values.append(value)

            # Sort by column (destination node) since `adj_t` holds the
            # *transposed* adjacency and its builders expect sorted input.
            store.edge_index, values = sort_edge_index(
                store.edge_index,
                values,
                sort_by_row=False,
            )

            for key, value in zip(keys, values):
                store[key] = value

            layout = self.layout
            size = store.size()[::-1]  # Transposed shape: (num_dst, num_src).

            edge_weight: Optional[Tensor] = None
            if self.attr is not None and self.attr in store:
                edge_weight = store[self.attr]

            if layout is None and torch_geometric.typing.WITH_TORCH_SPARSE:
                # Preferred path: `torch_sparse.SparseTensor`. Input is
                # already sorted/validated, hence `is_sorted`/`trust_data`.
                store.adj_t = SparseTensor(
                    row=store.edge_index[1],
                    col=store.edge_index[0],
                    value=edge_weight,
                    sparse_sizes=size,
                    is_sorted=True,
                    trust_data=True,
                )

            # TODO Multi-dimensional edge attributes only supported for COO.
            elif ((edge_weight is not None and edge_weight.dim() > 1)
                  or layout == torch.sparse_coo):
                assert size[0] is not None and size[1] is not None
                store.adj_t = to_torch_coo_tensor(
                    store.edge_index.flip([0]),
                    edge_attr=edge_weight,
                    size=size,
                )

            elif layout is None or layout == torch.sparse_csr:
                assert size[0] is not None and size[1] is not None
                store.adj_t = to_torch_csr_tensor(
                    store.edge_index.flip([0]),
                    edge_attr=edge_weight,
                    size=size,
                )

            if self.remove_edge_index:
                del store['edge_index']
                if self.attr is not None and self.attr in store:
                    del store[self.attr]

            if self.fill_cache and isinstance(store.adj_t, SparseTensor):
                # Pre-process some important attributes.
                store.adj_t.storage.rowptr()
                store.adj_t.storage.csr2csc()

        return data

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(attr={self.attr}, '
                f'layout={self.layout})')