from typing import Optional

import torch
from torch import Tensor
from torch_scatter import gather_csr
from torch_sparse.tensor import SparseTensor


# TorchScript overload stubs for `mul`; the actual implementation follows.
@torch.jit._overload  # noqa: F811
def mul(src, other):  # noqa: F811
    # type: (SparseTensor, Tensor) -> SparseTensor
    pass


@torch.jit._overload  # noqa: F811
def mul(src, other):  # noqa: F811
    # type: (SparseTensor, SparseTensor) -> SparseTensor
    pass


def mul(src, other):  # noqa: F811
    if isinstance(other, Tensor):
        rowptr, col, value = src.csr()
        if other.size(0) == src.size(0) and other.size(1) == 1:
            # Row-wise: broadcast one scalar per row over that row's non-zeros.
            other = gather_csr(other.squeeze(1), rowptr)
        elif other.size(0) == 1 and other.size(1) == src.size(1):
            # Col-wise: pick one scalar per non-zero based on its column index.
            other = other.squeeze(0)[col]
        else:
            raise ValueError(
                f'Size mismatch: Expected size ({src.size(0)}, 1, ...) or '
                f'(1, {src.size(1)}, ...), but got size {other.size()}.')

        if value is not None:
            value = other.to(value.dtype).mul_(value)
        else:
            value = other
        return src.set_value(value, layout='coo')
    # Sparse-sparse case: element-wise multiplication of two coalesced sparse
    # tensors. Only entries present in *both* inputs survive.
    assert isinstance(other, SparseTensor)

    if not src.is_coalesced():
        raise ValueError("The `src` tensor is not coalesced")
    if not other.is_coalesced():
        raise ValueError("The `other` tensor is not coalesced")

    rowA, colA, valueA = src.coo()
    rowB, colB, valueB = other.coo()

    row = torch.cat([rowA, rowB], dim=0)
    col = torch.cat([colA, colB], dim=0)
    if valueA is not None and valueB is not None:
        value = torch.cat([valueA, valueB], dim=0)
    else:
        raise ValueError('Both sparse tensors must contain values')

    M = max(src.size(0), other.size(0))
    N = max(src.size(1), other.size(1))
    sparse_sizes = (M, N)

    # Sort the concatenated indices by their linearized (row, col) position so
    # that entries shared by `src` and `other` become adjacent:
    idx = col.new_full((col.numel() + 1, ), -1)
    idx[1:] = row * sparse_sizes[1] + col
    perm = idx[1:].argsort()
    idx[1:] = idx[1:][perm]

    row, col, value = row[perm], col[perm], value[perm]

    # An index is kept only if it occurs twice in a row, i.e. it exists in
    # both `src` and `other`; its value is the product of the two duplicates.
    valid_mask = idx[1:] == idx[:-1]
    valid_idx = valid_mask.nonzero().view(-1)

    return SparseTensor(
        row=row[valid_mask],
        col=col[valid_mask],
        value=value[valid_idx - 1] * value[valid_idx],
        sparse_sizes=sparse_sizes,
    )
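
# Illustrative usage sketch (comments only, not executed): the dense branch of
# `mul` broadcasts a (num_rows, 1) or (1, num_cols) tensor over the non-zeros,
# while the sparse branch keeps only entries present in both inputs. It relies
# on the existing `SparseTensor.from_dense`/`to_dense` helpers; the numbers
# are made up for the example.
#
#   sp = SparseTensor.from_dense(torch.tensor([[1., 0.], [2., 3.]]))
#
#   row_scale = torch.tensor([[10.], [100.]])     # shape (2, 1) -> row-wise
#   mul(sp, row_scale).to_dense()
#   # tensor([[ 10.,   0.],
#   #         [200., 300.]])
#
#   other = SparseTensor.from_dense(torch.tensor([[0., 5.], [4., 0.]]))
#   mul(sp, other).to_dense()
#   # tensor([[0., 0.],
#   #         [8., 0.]])   # only the shared entry (1, 0) survives: 2 * 4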


def mul_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
    # In-place variant of the dense branch of `mul`.
    rowptr, col, value = src.csr()
    if other.size(0) == src.size(0) and other.size(1) == 1:  # Row-wise.
        other = gather_csr(other.squeeze(1), rowptr)
    elif other.size(0) == 1 and other.size(1) == src.size(1):  # Col-wise.
        other = other.squeeze(0)[col]
    else:
        raise ValueError(
            f'Size mismatch: Expected size ({src.size(0)}, 1, ...) or '
            f'(1, {src.size(1)}, ...), but got size {other.size()}.')

    if value is not None:
        value = value.mul_(other.to(value.dtype))
    else:
        value = other
    return src.set_value_(value, layout='coo')


def mul_nnz(
    src: SparseTensor,
    other: torch.Tensor,
    layout: Optional[str] = None,
) -> SparseTensor:
    # Multiply the stored non-zero values element-wise with `other`, which has
    # to match the values in the given `layout` (one entry per non-zero).
    value = src.storage.value()
    if value is not None:
        value = value.mul(other.to(value.dtype))
    else:
        value = other
    return src.set_value(value, layout=layout)


def mul_nnz_(
    src: SparseTensor,
    other: torch.Tensor,
    layout: Optional[str] = None,
) -> SparseTensor:
    # In-place variant of `mul_nnz`.
    value = src.storage.value()
    if value is not None:
        value = value.mul_(other.to(value.dtype))
    else:
        value = other
    return src.set_value_(value, layout=layout)
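
# Illustrative sketch (comments only, not executed): unlike `mul`, `mul_nnz`
# does no row/column broadcasting; `other` must provide one factor per stored
# non-zero in the requested `layout`.
#
#   sp = SparseTensor.from_dense(torch.tensor([[1., 0.], [2., 3.]]))
#   mul_nnz(sp, torch.tensor([2., 3., 4.]), layout='coo').to_dense()
#   # tensor([[ 2.,  0.],
#   #         [ 6., 12.]])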


SparseTensor.mul = lambda self, other: mul(self, other)
SparseTensor.mul_ = lambda self, other: mul_(self, other)
SparseTensor.mul_nnz = lambda self, other, layout=None: mul_nnz(
    self, other, layout)
SparseTensor.mul_nnz_ = lambda self, other, layout=None: mul_nnz_(
    self, other, layout)
SparseTensor.__mul__ = SparseTensor.mul
SparseTensor.__rmul__ = SparseTensor.mul
SparseTensor.__imul__ = SparseTensor.mul_
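

if __name__ == '__main__':
    # Minimal smoke test of the operator bindings above; an illustrative
    # sketch, not part of the library. It assumes the existing
    # `SparseTensor.from_dense`/`to_dense` helpers.
    sp = SparseTensor.from_dense(torch.tensor([[1., 0.], [2., 3.]]))

    col_scale = torch.tensor([[2., 10.]])  # shape (1, 2) -> col-wise broadcast
    print((sp * col_scale).to_dense())
    # Expected:
    # tensor([[ 2.,  0.],
    #         [ 4., 30.]])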