1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171
|
from typing import Optional, Tuple
import torch
from torch import Tensor
from torch_sparse.tensor import SparseTensor
def spmm_sum(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
    """Sparse-dense product with sum reduction.

    Thin Python front-end for the ``torch_sparse.spmm_sum`` custom operator.

    Args:
        src: Sparse left-hand operand.
        other: Dense right-hand operand. The values of ``src`` (when it has
            any) are converted to ``other.dtype`` before the operator runs.

    Returns:
        The result of ``torch.ops.torch_sparse.spmm_sum``.
    """
    rowptr, col, value = src.csr()
    storage = src.storage

    # Start from whatever the storage already holds in its private fields.
    # NOTE(review): these look like lazily-filled caches that may still be
    # None -- confirm against the storage class.
    row = storage._row
    csr2csc = storage._csr2csc
    colptr = storage._colptr

    if value is not None:
        value = value.to(other.dtype)
        # Gradients w.r.t. the sparse values need explicit row indices.
        if value.requires_grad:
            row = storage.row()

    # Gradients w.r.t. the dense operand additionally need the CSC view.
    if other.requires_grad:
        row = storage.row()
        csr2csc = storage.csr2csc()
        colptr = storage.colptr()

    return torch.ops.torch_sparse.spmm_sum(
        row, rowptr, col, value, colptr, csr2csc, other)
def spmm_add(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
    """Alias of :func:`spmm_sum`; both arguments are forwarded unchanged."""
    result = spmm_sum(src, other)
    return result
def spmm_mean(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
    """Sparse-dense product with mean reduction.

    Thin Python front-end for the ``torch_sparse.spmm_mean`` custom operator.

    Args:
        src: Sparse left-hand operand.
        other: Dense right-hand operand. The values of ``src`` (when it has
            any) are converted to ``other.dtype`` before the operator runs.

    Returns:
        The result of ``torch.ops.torch_sparse.spmm_mean``.
    """
    rowptr, col, value = src.csr()
    storage = src.storage

    # Start from whatever the storage already holds in its private fields.
    # NOTE(review): these look like lazily-filled caches that may still be
    # None -- confirm against the storage class.
    row = storage._row
    rowcount = storage._rowcount
    csr2csc = storage._csr2csc
    colptr = storage._colptr

    if value is not None:
        value = value.to(other.dtype)
        # Gradients w.r.t. the sparse values need explicit row indices.
        if value.requires_grad:
            row = storage.row()

    # Gradients w.r.t. the dense operand additionally need the per-row
    # counts and the CSC view.
    if other.requires_grad:
        row = storage.row()
        rowcount = storage.rowcount()
        csr2csc = storage.csr2csc()
        colptr = storage.colptr()

    return torch.ops.torch_sparse.spmm_mean(
        row, rowptr, col, value, rowcount, colptr, csr2csc, other)
def spmm_min(
    src: SparseTensor,
    other: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sparse-dense product with min reduction.

    Args:
        src: Sparse left-hand operand.
        other: Dense right-hand operand. The values of ``src`` (when it has
            any) are converted to ``other.dtype`` first.

    Returns:
        The pair of tensors produced by ``torch.ops.torch_sparse.spmm_min``.
        NOTE(review): the second element is presumably the arg-min index
        tensor -- confirm against the operator.
    """
    rowptr, col, weight = src.csr()

    # A sparse tensor without values keeps ``weight`` as None.
    if weight is not None:
        weight = weight.to(other.dtype)

    return torch.ops.torch_sparse.spmm_min(rowptr, col, weight, other)
def spmm_max(
    src: SparseTensor,
    other: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sparse-dense product with max reduction.

    Args:
        src: Sparse left-hand operand.
        other: Dense right-hand operand. The values of ``src`` (when it has
            any) are converted to ``other.dtype`` first.

    Returns:
        The pair of tensors produced by ``torch.ops.torch_sparse.spmm_max``.
        NOTE(review): the second element is presumably the arg-max index
        tensor -- confirm against the operator.
    """
    rowptr, col, weight = src.csr()

    # A sparse tensor without values keeps ``weight`` as None.
    if weight is not None:
        weight = weight.to(other.dtype)

    return torch.ops.torch_sparse.spmm_max(rowptr, col, weight, other)
def spmm(src: SparseTensor,
         other: torch.Tensor,
         reduce: str = "sum") -> torch.Tensor:
    """Multiply a sparse matrix with a dense matrix using a chosen reduction.

    Args:
        src: Sparse left-hand operand.
        other: Dense right-hand operand.
        reduce: Name of the reduction: ``"sum"`` (alias ``"add"``),
            ``"mean"``, ``"min"`` or ``"max"``. Defaults to ``"sum"``.

    Returns:
        The dense result. For ``"min"`` and ``"max"`` only the first element
        of the pair returned by the underlying helper is exposed.

    Raises:
        ValueError: If ``reduce`` is not one of the supported names.
    """
    if reduce == 'sum' or reduce == 'add':
        return spmm_sum(src, other)
    elif reduce == 'mean':
        return spmm_mean(src, other)
    elif reduce == 'min':
        # The helper returns a 2-tuple; callers of spmm only get item 0.
        return spmm_min(src, other)[0]
    elif reduce == 'max':
        return spmm_max(src, other)[0]
    else:
        # A constant message keeps the raise as simple as the original
        # bare `raise ValueError` while still telling the caller what
        # went wrong.
        raise ValueError(
            "reduce must be one of 'sum', 'add', 'mean', 'min' or 'max'")
def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
    """Sparse-sparse matrix product with sum reduction.

    Both operands are converted to native PyTorch sparse COO tensors,
    multiplied with ``torch.sparse.mm`` and the product is wrapped back into
    a ``SparseTensor``.

    Args:
        src: Sparse left-hand operand.
        other: Sparse right-hand operand.

    Returns:
        The product. It carries values only when at least one of the two
        operands has values; otherwise ``value`` is left as ``None``.
    """
    lhs = src.to_torch_sparse_coo_tensor()
    rhs = other.to_torch_sparse_coo_tensor()
    product = torch.sparse.mm(lhs, rhs)

    indices = product._indices()
    row = indices[0]
    col = indices[1]

    value: Optional[Tensor] = None
    if src.has_value() or other.has_value():
        value = product._values()

    num_rows = product.size(0)
    num_cols = product.size(1)

    # The indices are handed over as already sorted and trusted.
    return SparseTensor(
        row=row,
        col=col,
        value=value,
        sparse_sizes=(num_rows, num_cols),
        is_sorted=True,
        trust_data=True,
    )
def spspmm_add(src: SparseTensor, other: SparseTensor) -> SparseTensor:
    """Alias of :func:`spspmm_sum`; both arguments are forwarded unchanged."""
    result = spspmm_sum(src, other)
    return result
def spspmm(src: SparseTensor,
           other: SparseTensor,
           reduce: str = "sum") -> SparseTensor:
    """Multiply two sparse matrices using a chosen reduction.

    Args:
        src: Sparse left-hand operand.
        other: Sparse right-hand operand.
        reduce: Name of the reduction. Only ``"sum"`` (alias ``"add"``) is
            implemented. Defaults to ``"sum"``.

    Returns:
        The sparse product.

    Raises:
        NotImplementedError: If ``reduce`` is ``"mean"``, ``"min"`` or
            ``"max"``; these names are recognised but not implemented.
        ValueError: If ``reduce`` is any other string.
    """
    if reduce == 'sum' or reduce == 'add':
        return spspmm_sum(src, other)
    elif reduce == 'mean' or reduce == 'min' or reduce == 'max':
        raise NotImplementedError(
            "spspmm only implements the 'sum'/'add' reduction; "
            "'mean', 'min' and 'max' are not supported")
    else:
        raise ValueError(
            "reduce must be one of 'sum', 'add', 'mean', 'min' or 'max'")
@torch.jit._overload  # noqa: F811
def matmul(src, other, reduce):  # noqa: F811
    # type: (SparseTensor, torch.Tensor, str) -> torch.Tensor
    pass


@torch.jit._overload  # noqa: F811
def matmul(src, other, reduce):  # noqa: F811
    # type: (SparseTensor, SparseTensor, str) -> SparseTensor
    pass


def matmul(src, other, reduce="sum"):  # noqa: F811
    """Matrix product of a sparse tensor with either another sparse tensor or a
    dense tensor. The sparse tensor represents an adjacency matrix and is
    stored as a list of edges. This method multiplies elements along the rows
    of the adjacency matrix with the column of the other matrix. In regular
    matrix multiplication, the products are then summed together, but this
    method allows us to use other aggregation functions as well.

    Args:
        src (:class:`SparseTensor`): The sparse tensor.
        other (:class:`Tensor` or :class:`SparseTensor`): The second matrix.
        reduce (string, optional): The function to reduce along the rows of
            :obj:`src` and columns of :obj:`other`. Can be :obj:`"sum"`,
            :obj:`"mean"`, :obj:`"min"` or :obj:`"max"`. When :obj:`other`
            is a :class:`SparseTensor`, only :obj:`"sum"` is implemented.
            (default: :obj:`"sum"`)

    :rtype: :class:`Tensor` if :obj:`other` is a :class:`Tensor`,
        :class:`SparseTensor` if :obj:`other` is a :class:`SparseTensor`.

    Raises:
        ValueError: If :obj:`other` is neither a :class:`Tensor` nor a
            :class:`SparseTensor`.
    """
    # Dispatch on the type of the right-hand operand; the two overload
    # declarations above mirror these two branches.
    if isinstance(other, torch.Tensor):
        return spmm(src, other, reduce)
    elif isinstance(other, SparseTensor):
        return spspmm(src, other, reduce)
    raise ValueError("other must be a torch.Tensor or a SparseTensor")
# Expose the module-level functions as methods on SparseTensor, each with
# the same default reduction ("sum") as its functional counterpart.
SparseTensor.spmm = lambda self, other, reduce="sum": spmm(self, other, reduce)
SparseTensor.spspmm = lambda self, other, reduce="sum": spspmm(
self, other, reduce)
SparseTensor.matmul = lambda self, other, reduce="sum": matmul(
self, other, reduce)
# Backs the ``@`` operator, which always uses the 'sum' reduction.
SparseTensor.__matmul__ = lambda self, other: matmul(self, other, 'sum')
|