1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136
|
from typing import Tuple
from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor
def narrow(src: SparseTensor, dim: int, start: int,
           length: int) -> SparseTensor:
    r"""Return a new :class:`SparseTensor` restricted to a window of ``dim``.

    Mirrors :meth:`torch.Tensor.narrow`: only the entries whose index along
    ``dim`` lies in ``[start, start + length)`` are kept, and their indices
    along ``dim`` are shifted so that ``start`` becomes ``0``.

    Args:
        src: The sparse tensor to narrow.
        dim: The dimension to narrow. ``0`` (rows) and ``1`` (columns) are
            the sparse dimensions; ``dim >= 2`` narrows the dense part of the
            stored values. Negative values count from the last dimension.
        start: First index to keep along ``dim``. Negative values count from
            the end of that dimension.
        length: Number of indices to keep along ``dim``.

    Returns:
        A :class:`SparseTensor` built via ``src.from_storage`` (sparse dims)
        or ``src.set_value`` (dense dims).

    Raises:
        IndexError: If ``dim`` is out of range, or ``start`` is out of range
            for a sparse dimension.
        RuntimeError: If ``length`` is negative or ``start + length`` exceeds
            the size of a sparse dimension.
        ValueError: If a dense dimension is requested but ``src`` stores no
            values.
    """
    num_dims = src.dim()
    orig_dim = dim
    if dim < 0:
        dim = num_dims + dim
    # Without this check a too-negative `dim` stays negative after wrapping,
    # skips both sparse branches and silently narrows the wrong value axis.
    if not 0 <= dim < num_dims:
        raise IndexError(
            f'narrow(): dimension out of range (expected to be in range of '
            f'[{-num_dims}, {num_dims - 1}], but got {orig_dim})')

    size = src.size(dim)
    if start < 0:
        start = size + start

    if dim < 2:
        # The column branch filters `col` with a mask and only calls
        # `torch.narrow` when a `colptr` is cached, so an invalid window would
        # otherwise yield inconsistent `sparse_sizes` instead of an error.
        if not 0 <= start <= size:
            raise IndexError(
                f'narrow(): start ({start}) out of range for dimension '
                f'{dim} of size {size}')
        if length < 0:
            raise RuntimeError(
                f'narrow(): length must be non-negative (got {length})')
        if start + length > size:
            raise RuntimeError(
                f'narrow(): start ({start}) + length ({length}) exceeds '
                f'dimension size ({size})')

    if dim == 0:
        rowptr, col, value = src.csr()

        # `length` rows need `length + 1` row pointers.
        rowptr = rowptr.narrow(0, start=start, length=length + 1)
        # Offset and count of the non-zeros belonging to the kept rows, as
        # plain ints so they can be used directly as narrow() arguments.
        row_start = int(rowptr[0])
        rowptr = rowptr - row_start
        row_length = int(rowptr[-1])

        row = src.storage._row
        if row is not None:
            row = row.narrow(0, row_start, row_length) - start

        col = col.narrow(0, row_start, row_length)

        if value is not None:
            value = value.narrow(0, row_start, row_length)

        sparse_sizes = (length, src.sparse_size(1))

        rowcount = src.storage._rowcount
        if rowcount is not None:
            rowcount = rowcount.narrow(0, start=start, length=length)

        # Column-oriented caches no longer match the row slice; drop them.
        storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                                sparse_sizes=sparse_sizes, rowcount=rowcount,
                                colptr=None, colcount=None, csr2csc=None,
                                csc2csr=None, is_sorted=True)
        return src.from_storage(storage)

    if dim == 1:
        # This is faster than accessing `csc()` contrary to the `dim=0` case.
        row, col, value = src.coo()
        mask = (col >= start) & (col < start + length)
        # Boolean masking keeps the relative order of surviving entries, so
        # the sort order of `src.coo()` carries over (`is_sorted=True`).
        row = row[mask]
        col = col[mask] - start
        if value is not None:
            value = value[mask]

        sparse_sizes = (src.sparse_size(0), length)

        colptr = src.storage._colptr
        if colptr is not None:
            colptr = colptr.narrow(0, start=start, length=length + 1)
            colptr = colptr - colptr[0]

        colcount = src.storage._colcount
        if colcount is not None:
            colcount = colcount.narrow(0, start=start, length=length)

        # Row-oriented caches change when columns are dropped; drop them.
        storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
                                sparse_sizes=sparse_sizes, rowcount=None,
                                colptr=colptr, colcount=colcount,
                                csr2csc=None, csc2csr=None, is_sorted=True)
        return src.from_storage(storage)

    # Dense dimension: axis 0 of the value tensor indexes the non-zeros, so
    # tensor dimension `dim` is value axis `dim - 1`.
    value = src.storage.value()
    if value is None:
        raise ValueError(
            f'narrow(): cannot narrow dense dimension {dim} of a '
            f'SparseTensor that stores no values')
    return src.set_value(value.narrow(dim - 1, start, length), layout='coo')
def __narrow_diag__(src: SparseTensor, start: Tuple[int, int],
                    length: Tuple[int, int]) -> SparseTensor:
    """Extract one diagonal block from a diagonally stacked sparse matrix.

    This is the inverse of ``cat_diag``. It relies on the non-zeros of the
    requested rows being exactly the non-zeros of the requested columns. That
    only holds for *diagonally stacked* matrices, which is why the function is
    private.

    Args:
        src: The diagonally stacked sparse matrix.
        start: ``(first_row, first_col)`` of the block.
        length: ``(num_rows, num_cols)`` of the block; passed unchanged as
            the ``sparse_sizes`` of the result.

    Returns:
        A :class:`SparseTensor` for the block. Every cached structure present
        on ``src.storage`` (row, rowcount, colptr, colcount, csr2csc,
        csc2csr) is sliced to the block and re-based to start at zero.
    """
    row_offset, col_offset = start[0], start[1]
    num_rows, num_cols = length[0], length[1]

    rowptr, col, value = src.csr()
    rowptr = rowptr.narrow(0, start=row_offset, length=num_rows + 1)
    # Position of the block's first non-zero, and the block's non-zero count.
    nnz_offset = int(rowptr[0])
    rowptr = rowptr - nnz_offset
    nnz = int(rowptr[-1])

    storage = src.storage

    row = storage._row
    if row is not None:
        row = row.narrow(0, nnz_offset, nnz) - row_offset

    col = col.narrow(0, nnz_offset, nnz) - col_offset

    if value is not None:
        value = value.narrow(0, nnz_offset, nnz)

    rowcount = storage._rowcount
    if rowcount is not None:
        rowcount = rowcount.narrow(0, row_offset, num_rows)

    colptr = storage._colptr
    if colptr is not None:
        colptr = colptr.narrow(0, col_offset, num_cols + 1)
        # In a block-diagonal layout this first entry equals `nnz_offset`.
        colptr = colptr - int(colptr[0])

    colcount = storage._colcount
    if colcount is not None:
        colcount = colcount.narrow(0, col_offset, num_cols)

    # Both permutations are assumed to map the block's non-zero window onto
    # itself, so slicing the window and shifting it back to zero suffices.
    csr2csc = storage._csr2csc
    if csr2csc is not None:
        csr2csc = csr2csc.narrow(0, nnz_offset, nnz) - nnz_offset

    csc2csr = storage._csc2csr
    if csc2csr is not None:
        csc2csr = csc2csr.narrow(0, nnz_offset, nnz) - nnz_offset

    return src.from_storage(
        SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
                      sparse_sizes=length, rowcount=rowcount,
                      colptr=colptr, colcount=colcount, csr2csc=csr2csc,
                      csc2csr=csc2csr, is_sorted=True))
def _sparse_tensor_narrow(self, dim, start, length):
    """Method form of the module-level ``narrow`` function."""
    return narrow(self, dim, start, length)


def _sparse_tensor_narrow_diag(self, start, length):
    """Method form of the module-level ``__narrow_diag__`` function."""
    return __narrow_diag__(self, start, length)


# Attach the functions above to `SparseTensor` as methods.
SparseTensor.narrow = _sparse_tensor_narrow
SparseTensor.__narrow_diag__ = _sparse_tensor_narrow_diag
|