# File: masked_select.py
#
# Package info:
#   pytorch-sparse 0.6.18-3
#   links: PTS, VCS
#   area: main
#   in suites: forky, sid, trixie
#   size: 984 kB
#   sloc: python: 3,646; cpp: 2,444; sh: 54; makefile: 6
# File content: 96 lines | stat: -rw-r--r-- 3,088 bytes
from typing import Optional

import torch
from torch_sparse.storage import SparseStorage, get_layout
from torch_sparse.tensor import SparseTensor


def masked_select(src: SparseTensor, dim: int,
                  mask: torch.Tensor) -> SparseTensor:
    """Select the slices of :obj:`src` along :obj:`dim` picked by :obj:`mask`.

    Args:
        src: The sparse tensor to select from.
        dim: The dimension to select along. A negative value is offset by
            ``src.dim()``. ``0`` selects rows, ``1`` selects columns, and
            any other value selects along dimension ``dim - 1`` of the
            value tensor.
        mask: A one-dimensional tensor used to index along :obj:`dim`
            (assumed to be a boolean mask with one entry per slice of that
            dimension -- TODO confirm against callers).

    Returns:
        A new sparse tensor. For ``dim == 0`` / ``dim == 1`` the kept rows /
        columns are renumbered consecutively from zero and the corresponding
        sparse size becomes the number of kept rows / columns. For any other
        :obj:`dim` only the value tensor is replaced.

    Raises:
        AssertionError: If :obj:`mask` is not one-dimensional.
        ValueError: If a dense dimension is requested but :obj:`src` holds
            no value tensor.
    """
    dim = src.dim() + dim if dim < 0 else dim

    assert mask.dim() == 1

    if dim == 0:
        row, col, value = src.coo()
        rowcount = src.storage.rowcount()

        # Per-row entry counts of the rows that survive the mask.
        rowcount = rowcount[mask]

        # Expand the per-row mask to a per-entry mask.
        mask = mask[row]
        # Rebuild the row indices from the counts, which renumbers the kept
        # rows as 0, 1, 2, ...
        row = torch.arange(rowcount.size(0),
                           device=row.device).repeat_interleave(rowcount)

        col = col[mask]

        if value is not None:
            value = value[mask]

        sparse_sizes = (rowcount.size(0), src.sparse_size(1))

        storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
                                sparse_sizes=sparse_sizes, rowcount=rowcount,
                                colcount=None, colptr=None, csr2csc=None,
                                csc2csr=None, is_sorted=True)
        return src.from_storage(storage)

    elif dim == 1:
        row, col, value = src.coo()
        # Reorder the entries with the `csr2csc` permutation so that
        # `colcount` can be used to rebuild the column indices below.
        csr2csc = src.storage.csr2csc()
        row = row[csr2csc]
        col = col[csr2csc]
        colcount = src.storage.colcount()

        # Per-column entry counts of the columns that survive the mask.
        colcount = colcount[mask]

        # Expand the per-column mask to a per-entry mask.
        mask = mask[col]
        # Rebuild the column indices from the counts, which renumbers the
        # kept columns as 0, 1, 2, ...
        col = torch.arange(colcount.size(0),
                           device=col.device).repeat_interleave(colcount)
        row = row[mask]
        # Sort by the linearised index `num_cols * row + col` to bring the
        # remaining entries back into row-major order.
        csc2csr = (colcount.size(0) * row + col).argsort()
        row, col = row[csc2csr], col[csc2csr]

        if value is not None:
            # Apply the same three steps to the values: column-major
            # reorder, mask, then back to row-major.
            value = value[csr2csc][mask][csc2csr]

        sparse_sizes = (src.sparse_size(0), colcount.size(0))

        storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
                                sparse_sizes=sparse_sizes, rowcount=None,
                                colcount=colcount, colptr=None, csr2csc=None,
                                csc2csr=csc2csr, is_sorted=True)
        return src.from_storage(storage)

    else:
        value = src.storage.value()
        if value is not None:
            # The value tensor is indexed at `dim - 1`, one less than the
            # requested tensor dimension.
            idx = mask.nonzero().flatten()
            return src.set_value(value.index_select(dim - 1, idx),
                                 layout='coo')
        else:
            raise ValueError(
                "Cannot apply 'masked_select' along a dense dimension of a "
                "'SparseTensor' that does not hold any values")


def masked_select_nnz(src: SparseTensor, mask: torch.Tensor,
                      layout: Optional[str] = None) -> SparseTensor:
    """Filter the stored entries of :obj:`src` with a per-entry mask.

    Args:
        src: The sparse tensor whose stored entries are filtered.
        mask: A one-dimensional tensor used to index the row, column and
            value tensors of :obj:`src` (assumed to be a boolean mask with
            one entry per stored element -- TODO confirm against callers).
        layout: Layout in which :obj:`mask` is ordered, resolved through
            :func:`get_layout`. When it resolves to ``'csc'`` the mask is
            permuted with the ``csc2csr`` permutation before use.

    Returns:
        A new sparse tensor with the sparse sizes of :obj:`src` that holds
        only the selected entries.

    Raises:
        AssertionError: If :obj:`mask` is not one-dimensional.
    """
    assert mask.dim() == 1

    keep = mask
    if get_layout(layout) == 'csc':
        # Permute the mask so it lines up with the order of `src.coo()`.
        perm = src.storage.csc2csr()
        keep = mask[perm]

    rows, cols, vals = src.coo()
    rows = rows[keep]
    cols = cols[keep]

    if vals is not None:
        vals = vals[keep]

    sizes = src.sparse_sizes()
    return SparseTensor(row=rows, rowptr=None, col=cols, value=vals,
                        sparse_sizes=sizes, is_sorted=True)


# Expose both functions as methods on `SparseTensor` by assigning thin lambda
# wrappers as class attributes, so that `src.masked_select(dim, mask)`
# forwards to `masked_select(src, dim, mask)`.
SparseTensor.masked_select = lambda self, dim, mask: masked_select(
    self, dim, mask)
# The second wrapper is bound to the temporary module-level name `tmp` before
# being attached to the class (presumably to keep the lines short, with `noqa`
# silencing the linter's lambda-assignment warning -- NOTE(review): confirm).
tmp = lambda self, mask, layout=None: masked_select_nnz(  # noqa
    self, mask, layout)
SparseTensor.masked_select_nnz = tmp