File: cat.py

package info (click to toggle)
pytorch-sparse 0.6.18-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 984 kB
  • sloc: python: 3,646; cpp: 2,444; sh: 54; makefile: 6
file content (261 lines) | stat: -rw-r--r-- 8,242 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
from typing import Optional, List, Tuple  # noqa

import torch
from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor


# TorchScript overload declaration: ``dim`` given as a single int.
# The ``# type:`` comment below carries the signature; the body is
# intentionally just ``pass`` -- the real implementation is the
# undecorated ``cat`` further down in this file.
@torch.jit._overload  # noqa: F811
def cat(tensors, dim):  # noqa: F811
    # type: (List[SparseTensor], int) -> SparseTensor
    pass


# TorchScript overload declaration: ``dim`` given as a 2-tuple of ints.
# The implementation only accepts a pair whose sorted contents are
# ``[0, 1]`` and performs block-diagonal concatenation for it.
# The body is intentionally just ``pass``; see the undecorated ``cat``.
@torch.jit._overload  # noqa: F811
def cat(tensors, dim):  # noqa: F811
    # type: (List[SparseTensor], Tuple[int, int]) -> SparseTensor
    pass


# TorchScript overload declaration: ``dim`` given as a list of ints.
# Same contract as the tuple form: the implementation requires exactly two
# entries that sort to ``[0, 1]`` (block-diagonal concatenation).
# The body is intentionally just ``pass``; see the undecorated ``cat``.
@torch.jit._overload  # noqa: F811
def cat(tensors, dim):  # noqa: F811
    # type: (List[SparseTensor], List[int]) -> SparseTensor
    pass


def cat(tensors, dim):  # noqa: F811
    """Concatenate a non-empty list of sparse tensors.

    ``dim`` selects the kind of concatenation:

    * ``0``: stack along the first sparse dimension (``cat_first``).
    * ``1``: stack along the second sparse dimension (``cat_second``).
    * an int ``d`` with ``1 < d < tensors[0].dim()``: concatenate the value
      tensors along ``d - 1`` and attach the result to the sparsity
      structure of ``tensors[0]``.
    * a tuple/list of two ints that sort to ``[0, 1]``: block-diagonal
      concatenation (``cat_diag``).

    A negative int ``dim`` is interpreted as ``dim + tensors[0].dim()``.

    Raises:
        IndexError: if an int ``dim`` is out of range. The message reports
            the ``dim`` the caller passed, not its normalized value.
        AssertionError: if ``tensors`` is empty, an input has no values in
            the dense-dim case, or a tuple/list ``dim`` is not a permutation
            of ``(0, 1)``.
    """
    assert len(tensors) > 0

    if isinstance(dim, int):
        ndim = tensors[0].dim()
        # Normalize into a separate local so ``dim`` keeps the caller's value
        # for the error message below.
        d = ndim + dim if dim < 0 else dim

        if d == 0:
            return cat_first(tensors)

        elif d == 1:
            return cat_second(tensors)

        elif d > 1 and d < ndim:
            # Tensor dims >= 2 live in the value tensor, shifted by one
            # (hence ``d - 1``).
            # NOTE(review): presumably all inputs share tensors[0]'s sparsity
            # pattern; only tensors[0]'s indices are reused and this is not
            # checked here.
            values: List[torch.Tensor] = []
            for tensor in tensors:
                value = tensor.storage.value()
                assert value is not None
                values.append(value)
            merged = torch.cat(values, dim=d - 1)
            return tensors[0].set_value(merged, layout='coo')

        else:
            raise IndexError(
                (f'Dimension out of range: Expected to be in range of '
                 f'[{-ndim}, {ndim - 1}], but got '
                 f'{dim}.'))
    else:
        assert isinstance(dim, (tuple, list))
        assert len(dim) == 2
        assert sorted(dim) == [0, 1]
        return cat_diag(tensors)


def cat_first(tensors: List[SparseTensor]) -> SparseTensor:
    """Stack sparse tensors along the first sparse (row) dimension.

    Row indices of each input are shifted by the total number of rows of the
    inputs before it; column indices are kept unchanged and the column count
    of the result is the maximum over the inputs. The result is built with
    ``is_sorted=True``, since stacking blocks in row order preserves the
    ordering of (presumably already sorted) inputs.

    Cached ``row``, ``rowptr``, ``value`` and ``rowcount`` tensors are
    concatenated only when every input has them. If neither ``row`` nor
    ``rowptr`` is cached for every input (e.g. a mix of COO-only and
    CSR-only inputs), ``row`` is materialized for all inputs so that the
    storage still receives a row index.

    Args:
        tensors: non-empty list of sparse tensors.

    Returns:
        A tensor created via ``tensors[0].from_storage``.
    """
    rows: List[torch.Tensor] = []
    rowptrs: List[torch.Tensor] = []
    cols: List[torch.Tensor] = []
    values: List[torch.Tensor] = []
    sparse_sizes: List[int] = [0, 0]
    rowcounts: List[torch.Tensor] = []

    nnz: int = 0
    for tensor in tensors:
        row = tensor.storage._row
        if row is not None:
            # Shift below all rows contributed by earlier inputs.
            rows.append(row + sparse_sizes[0])

        rowptr = tensor.storage._rowptr
        if rowptr is not None:
            # The first rowptr is kept whole; later ones skip their first
            # entry and are offset by the entries already emitted.
            rowptrs.append(rowptr[1:] + nnz if len(rowptrs) > 0 else rowptr)

        cols.append(tensor.storage._col)

        value = tensor.storage._value
        if value is not None:
            values.append(value)

        rowcount = tensor.storage._rowcount
        if rowcount is not None:
            rowcounts.append(rowcount)

        sparse_sizes[0] += tensor.sparse_size(0)
        sparse_sizes[1] = max(sparse_sizes[1], tensor.sparse_size(1))
        nnz += tensor.nnz()

    row: Optional[torch.Tensor] = None
    if len(rows) == len(tensors):
        row = torch.cat(rows, dim=0)

    rowptr: Optional[torch.Tensor] = None
    if len(rowptrs) == len(tensors):
        rowptr = torch.cat(rowptrs, dim=0)

    if row is None and rowptr is None:
        # Neither index format is cached for every input (mixed COO/CSR
        # inputs), so build ``row`` for all of them instead of handing the
        # storage no row information at all.
        full_rows: List[torch.Tensor] = []
        row_offset: int = 0
        for tensor in tensors:
            # NOTE(review): ``storage.row()`` presumably returns the cached
            # row or derives it from rowptr -- confirm in SparseStorage.
            full_rows.append(tensor.storage.row() + row_offset)
            row_offset += tensor.sparse_size(0)
        row = torch.cat(full_rows, dim=0)

    col = torch.cat(cols, dim=0)

    value: Optional[torch.Tensor] = None
    if len(values) == len(tensors):
        value = torch.cat(values, dim=0)

    rowcount: Optional[torch.Tensor] = None
    if len(rowcounts) == len(tensors):
        rowcount = torch.cat(rowcounts, dim=0)

    storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                            sparse_sizes=(sparse_sizes[0], sparse_sizes[1]),
                            rowcount=rowcount, colptr=None, colcount=None,
                            csr2csc=None, csc2csr=None, is_sorted=True)
    return tensors[0].from_storage(storage)


def cat_second(tensors: List[SparseTensor]) -> SparseTensor:
    """Place sparse tensors side by side along the second sparse dimension.

    Each input's column indices are shifted by the total column count of the
    inputs before it. The row count of the result is the largest row count
    among the inputs, and its column count is the sum of theirs. The merged
    COO entries are handed to ``SparseStorage`` with ``is_sorted=False``.
    Values, ``colptr`` and ``colcount`` survive only when every input
    provides them.

    Args:
        tensors: sparse tensors to concatenate.

    Returns:
        A tensor created via ``tensors[0].from_storage``.
    """
    row_parts: List[torch.Tensor] = []
    col_parts: List[torch.Tensor] = []
    value_parts: List[torch.Tensor] = []
    colptr_parts: List[torch.Tensor] = []
    colcount_parts: List[torch.Tensor] = []

    num_rows: int = 0
    num_cols: int = 0
    nnz_so_far: int = 0
    for t in tensors:
        t_row, t_col, t_value = t.coo()
        row_parts.append(t_row)
        # Move this block to the right of every column seen so far.
        col_parts.append(t.storage._col + num_cols)

        if t_value is not None:
            value_parts.append(t_value)

        t_colptr = t.storage._colptr
        if t_colptr is not None:
            if len(colptr_parts) == 0:
                colptr_parts.append(t_colptr)
            else:
                # Skip the first entry and offset by entries already emitted.
                colptr_parts.append(t_colptr[1:] + nnz_so_far)

        t_colcount = t.storage._colcount
        if t_colcount is not None:
            colcount_parts.append(t_colcount)

        num_rows = max(num_rows, t.sparse_size(0))
        num_cols += t.sparse_size(1)
        nnz_so_far += t.nnz()

    count = len(tensors)
    merged_row = torch.cat(row_parts, dim=0)
    merged_col = torch.cat(col_parts, dim=0)

    merged_value: Optional[torch.Tensor] = None
    if len(value_parts) == count:
        merged_value = torch.cat(value_parts, dim=0)

    merged_colptr: Optional[torch.Tensor] = None
    if len(colptr_parts) == count:
        merged_colptr = torch.cat(colptr_parts, dim=0)

    merged_colcount: Optional[torch.Tensor] = None
    if len(colcount_parts) == count:
        merged_colcount = torch.cat(colcount_parts, dim=0)

    storage = SparseStorage(row=merged_row, rowptr=None, col=merged_col,
                            value=merged_value,
                            sparse_sizes=(num_rows, num_cols),
                            rowcount=None, colptr=merged_colptr,
                            colcount=merged_colcount, csr2csc=None,
                            csc2csr=None, is_sorted=False)
    return tensors[0].from_storage(storage)


def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
    """Concatenate sparse tensors block-diagonally.

    Input ``i`` is placed so that its rows follow the rows of inputs
    ``0..i-1`` and its columns follow their columns. Both sparse sizes of
    the result are therefore sums over the inputs. Blocks are laid out in
    order along both axes, and the result is built with ``is_sorted=True``.

    Every cached structure (``row``, ``rowptr``, ``value``, ``rowcount``,
    ``colptr``, ``colcount``, ``csr2csc``, ``csc2csr``) is carried over only
    when all inputs have it. Pointer arrays and the ``csr2csc``/``csc2csr``
    permutations are offset by the number of entries of the preceding
    inputs. If neither ``row`` nor ``rowptr`` is cached for every input
    (e.g. a mix of COO-only and CSR-only inputs), ``row`` is materialized
    for all inputs so the storage still receives a row index.

    Args:
        tensors: non-empty list of sparse tensors.

    Returns:
        A tensor created via ``tensors[0].from_storage``.
    """
    assert len(tensors) > 0

    rows: List[torch.Tensor] = []
    rowptrs: List[torch.Tensor] = []
    cols: List[torch.Tensor] = []
    values: List[torch.Tensor] = []
    sparse_sizes: List[int] = [0, 0]
    rowcounts: List[torch.Tensor] = []
    colptrs: List[torch.Tensor] = []
    colcounts: List[torch.Tensor] = []
    csr2cscs: List[torch.Tensor] = []
    csc2csrs: List[torch.Tensor] = []

    nnz: int = 0
    for tensor in tensors:
        row = tensor.storage._row
        if row is not None:
            rows.append(row + sparse_sizes[0])

        rowptr = tensor.storage._rowptr
        if rowptr is not None:
            # First pointer array kept whole; later ones skip their first
            # entry and are offset by the entries already emitted.
            rowptrs.append(rowptr[1:] + nnz if len(rowptrs) > 0 else rowptr)

        cols.append(tensor.storage._col + sparse_sizes[1])

        value = tensor.storage._value
        if value is not None:
            values.append(value)

        rowcount = tensor.storage._rowcount
        if rowcount is not None:
            rowcounts.append(rowcount)

        colptr = tensor.storage._colptr
        if colptr is not None:
            colptrs.append(colptr[1:] + nnz if len(colptrs) > 0 else colptr)

        colcount = tensor.storage._colcount
        if colcount is not None:
            colcounts.append(colcount)

        # The permutations index into this block's entries; shift them to
        # the block's position within the concatenated entry list.
        csr2csc = tensor.storage._csr2csc
        if csr2csc is not None:
            csr2cscs.append(csr2csc + nnz)

        csc2csr = tensor.storage._csc2csr
        if csc2csr is not None:
            csc2csrs.append(csc2csr + nnz)

        sparse_sizes[0] += tensor.sparse_size(0)
        sparse_sizes[1] += tensor.sparse_size(1)
        nnz += tensor.nnz()

    row: Optional[torch.Tensor] = None
    if len(rows) == len(tensors):
        row = torch.cat(rows, dim=0)

    rowptr: Optional[torch.Tensor] = None
    if len(rowptrs) == len(tensors):
        rowptr = torch.cat(rowptrs, dim=0)

    if row is None and rowptr is None:
        # Neither index format is cached for every input (mixed COO/CSR
        # inputs), so build ``row`` for all of them instead of handing the
        # storage no row information at all.
        full_rows: List[torch.Tensor] = []
        row_offset: int = 0
        for tensor in tensors:
            # NOTE(review): ``storage.row()`` presumably returns the cached
            # row or derives it from rowptr -- confirm in SparseStorage.
            full_rows.append(tensor.storage.row() + row_offset)
            row_offset += tensor.sparse_size(0)
        row = torch.cat(full_rows, dim=0)

    col = torch.cat(cols, dim=0)

    value: Optional[torch.Tensor] = None
    if len(values) == len(tensors):
        value = torch.cat(values, dim=0)

    rowcount: Optional[torch.Tensor] = None
    if len(rowcounts) == len(tensors):
        rowcount = torch.cat(rowcounts, dim=0)

    colptr: Optional[torch.Tensor] = None
    if len(colptrs) == len(tensors):
        colptr = torch.cat(colptrs, dim=0)

    colcount: Optional[torch.Tensor] = None
    if len(colcounts) == len(tensors):
        colcount = torch.cat(colcounts, dim=0)

    csr2csc: Optional[torch.Tensor] = None
    if len(csr2cscs) == len(tensors):
        csr2csc = torch.cat(csr2cscs, dim=0)

    csc2csr: Optional[torch.Tensor] = None
    if len(csc2csrs) == len(tensors):
        csc2csr = torch.cat(csc2csrs, dim=0)

    storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                            sparse_sizes=(sparse_sizes[0], sparse_sizes[1]),
                            rowcount=rowcount, colptr=colptr,
                            colcount=colcount, csr2csc=csr2csc,
                            csc2csr=csc2csr, is_sorted=True)
    return tensors[0].from_storage(storage)