File: to_sparse_tensor.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (146 lines) | stat: -rw-r--r-- 5,580 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
from typing import Optional, Union

import torch
from torch import Tensor

import torch_geometric.typing
from torch_geometric.data import Data, HeteroData
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform
from torch_geometric.typing import SparseTensor
from torch_geometric.utils import (
    sort_edge_index,
    to_torch_coo_tensor,
    to_torch_csr_tensor,
)


@functional_transform('to_sparse_tensor')
class ToSparseTensor(BaseTransform):
    r"""Replaces the :obj:`edge_index` of each edge store in a homogeneous or
    heterogeneous data object by a **transposed** sparse adjacency matrix
    that is stored under the key :obj:`adj_t`
    (functional name: :obj:`to_sparse_tensor`).
    The matrix is either a :class:`torch_sparse.SparseTensor` or a
    :pytorch:`PyTorch` :class:`torch.sparse.Tensor` object.

    .. note::

        Some transforms can only operate on :obj:`data.edge_index` for now.
        When composing multiple transforms, it is therefore best to apply
        :class:`ToSparseTensor` as late as possible.

    Args:
        attr (str, optional): The name of the attribute whose values are
            added to the :class:`~torch_sparse.SparseTensor` or
            :class:`torch.sparse.Tensor` object (if present).
            (default: :obj:`edge_weight`)
        remove_edge_index (bool, optional): If set to :obj:`False`, the
            :obj:`edge_index` tensor is kept. Otherwise, :obj:`edge_index`
            and the attribute named by :obj:`attr` (if present) are removed
            after the conversion. (default: :obj:`True`)
        fill_cache (bool, optional): If set to :obj:`True`, will fill the
            cache of the underlying :class:`torch_sparse.SparseTensor`
            (if one is created). (default: :obj:`True`)
        layout (torch.layout, optional): The layout of the returned sparse
            tensor (:obj:`None`, :obj:`torch.sparse_coo` or
            :obj:`torch.sparse_csr`).
            With :obj:`None` and the :obj:`torch_sparse` dependency
            installed, :obj:`edge_index` is converted into a
            :class:`torch_sparse.SparseTensor` object.
            With :obj:`None` and the :obj:`torch_sparse` dependency not
            installed, :obj:`edge_index` is converted into a
            :class:`torch.sparse.Tensor` object with layout
            :obj:`torch.sparse_csr`. (default: :obj:`None`)

    Raises:
        ValueError: If :obj:`layout` is not one of the supported values.
    """
    def __init__(
        self,
        attr: Optional[str] = 'edge_weight',
        remove_edge_index: bool = True,
        fill_cache: bool = True,
        layout: Optional[int] = None,
    ) -> None:
        supported_layouts = {None, torch.sparse_coo, torch.sparse_csr}
        if layout not in supported_layouts:
            raise ValueError(
                f"Unexpected sparse tensor layout (got '{layout}')")

        self.attr = attr
        self.remove_edge_index = remove_edge_index
        self.fill_cache = fill_cache
        self.layout = layout

    def forward(
        self,
        data: Union[Data, HeteroData],
    ) -> Union[Data, HeteroData]:
        # Keys that are never permuted alongside `edge_index`.
        excluded = {'edge_index', 'edge_label', 'edge_label_index'}

        for store in data.edge_stores:
            if 'edge_index' not in store:
                continue

            # Gather all edge-level attributes, since they need to follow
            # the same permutation as `edge_index` when it gets sorted.
            edge_attrs = [
                (name, attr) for name, attr in store.items()
                if name not in excluded and store.is_edge_attr(name)
            ]

            # Sort by column (`sort_by_row=False`): the transposed matrix
            # built below uses `edge_index[1]` as its rows and is created
            # with `is_sorted=True`.
            store.edge_index, permuted = sort_edge_index(
                store.edge_index,
                [attr for _, attr in edge_attrs],
                sort_by_row=False,
            )

            for (name, _), attr in zip(edge_attrs, permuted):
                store[name] = attr

            sparse_layout = self.layout
            # Reverse the size because the adjacency is stored transposed.
            shape = store.size()[::-1]

            weight: Optional[Tensor] = None
            if self.attr is not None and self.attr in store:
                weight = store[self.attr]

            if (sparse_layout is None
                    and torch_geometric.typing.WITH_TORCH_SPARSE):
                store.adj_t = SparseTensor(
                    row=store.edge_index[1],
                    col=store.edge_index[0],
                    value=weight,
                    sparse_sizes=shape,
                    is_sorted=True,
                    trust_data=True,
                )

            # TODO Multi-dimensional edge attributes only supported for COO.
            elif ((weight is not None and weight.dim() > 1)
                  or sparse_layout == torch.sparse_coo):
                assert shape[0] is not None and shape[1] is not None
                transposed_index = store.edge_index.flip([0])
                store.adj_t = to_torch_coo_tensor(
                    transposed_index,
                    edge_attr=weight,
                    size=shape,
                )

            elif sparse_layout is None or sparse_layout == torch.sparse_csr:
                assert shape[0] is not None and shape[1] is not None
                transposed_index = store.edge_index.flip([0])
                store.adj_t = to_torch_csr_tensor(
                    transposed_index,
                    edge_attr=weight,
                    size=shape,
                )

            if self.remove_edge_index:
                # Drop `edge_index` together with the attribute named by
                # `self.attr`, if the store holds it.
                del store['edge_index']
                if self.attr is not None and self.attr in store:
                    del store[self.attr]

            if not (self.fill_cache
                    and isinstance(store.adj_t, SparseTensor)):
                continue

            # Pre-process some important attributes.
            storage = store.adj_t.storage
            storage.rowptr()
            storage.csr2csc()

        return data

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return f'{cls_name}(attr={self.attr}, layout={self.layout})'