1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
|
import torch
from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor
def t(src: SparseTensor) -> SparseTensor:
    """Return the transpose of ``src`` without re-sorting its entries.

    The row-major (CSR) layout of the result is the column-major (CSC)
    layout of ``src``. Every cached structure of ``src`` is therefore
    reused with its row and column roles exchanged:

    * the result's CSR pointer/count come from the CSC pointer/count of ``src``;
    * the result's CSC pointer/count come from the CSR pointer/count of ``src``;
    * the two permutation tensors (``csr2csc`` / ``csc2csr``) trade places.

    Args:
        src: The sparse tensor to transpose.

    Returns:
        A new :class:`SparseTensor`, built with ``src.from_storage``, whose
        sparse sizes are those of ``src`` in reversed order.
    """
    old = src.storage
    # Permutation used to reorder ``src``'s CSR-ordered entries; after it
    # is applied, the entries are declared sorted for the transposed layout.
    perm = old.csr2csc()
    src_row, src_col, src_val = src.coo()
    new_val = None if src_val is None else src_val[perm]
    num_rows, num_cols = old.sparse_sizes()
    new_storage = SparseStorage(
        row=src_col[perm],
        rowptr=old._colptr,
        col=src_row[perm],
        value=new_val,
        sparse_sizes=(num_cols, num_rows),
        rowcount=old._colcount,
        colptr=old._rowptr,
        colcount=old._rowcount,
        # The permutations swap roles under transposition.
        csr2csc=old._csc2csr,
        csc2csr=perm,
        is_sorted=True,
    )
    return src.from_storage(new_storage)
# Expose ``t`` as a ``SparseTensor`` method so that ``tensor.t()`` works.
# The lambda wrapper looks up the module-level ``t`` at call time instead of
# storing the function object itself on the class.
SparseTensor.t = lambda self: t(self)
###############################################################################
def transpose(index, value, m, n, coalesced=True):
    """Swap dimensions 0 and 1 of a sparse matrix given in COO form.

    Args:
        index (:class:`LongTensor`): Row/column indices of the sparse matrix;
            unpacked as ``row, col = index``.
        value (:class:`Tensor`): Values of the non-zero entries (``None`` is
            passed through unchanged when ``coalesced`` is falsy).
        m (int): Size of the first dimension of the input matrix.
        n (int): Size of the second dimension of the input matrix.
        coalesced (bool, optional): If falsy, the swapped indices are returned
            as-is, without sorting or merging duplicates.
            (default: :obj:`True`)

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """
    src_row, src_col = index
    # Transposition only exchanges the roles of row and column indices.
    out_row, out_col, out_value = src_col, src_row, value
    if not coalesced:
        return torch.stack([out_row, out_col], dim=0), out_value

    # The transposed matrix has shape (n, m); let SparseStorage sort the
    # entries and merge duplicates.
    merged = SparseStorage(row=out_row, col=out_col, value=out_value,
                           sparse_sizes=(n, m), is_sorted=False).coalesce()
    out_row, out_col, out_value = merged.row(), merged.col(), merged.value()
    return torch.stack([out_row, out_col], dim=0), out_value
|