File: diag.py

package info (click to toggle)
pytorch-sparse 0.6.18-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 984 kB
  • sloc: python: 3,646; cpp: 2,444; sh: 54; makefile: 6
file content (118 lines) | stat: -rw-r--r-- 3,937 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from typing import Optional

import torch
from torch import Tensor

from torch_sparse.storage import SparseStorage
from torch_sparse.tensor import SparseTensor


def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor:
    """Return a copy of ``src`` with every entry of its ``k``-th diagonal
    removed.

    An entry ``(r, c)`` belongs to the ``k``-th diagonal when ``r == c - k``.
    Cached per-row / per-column non-zero counts are decremented accordingly.
    All other cached layouts (rowptr, colptr, csr2csc, csc2csr) are dropped
    and left to be recomputed.

    Args:
        src: The sparse matrix to filter.
        k: Diagonal offset (0 = main diagonal).

    Returns:
        A new ``SparseTensor`` with the same sparse sizes.
    """
    row, col, value = src.coo()

    # True for entries sitting on the k-th diagonal; those are discarded.
    on_diag = row == col if k == 0 else row == (col - k)
    keep = ~on_diag

    kept_row = row[keep]
    kept_col = col[keep]
    kept_value = value if value is None else value[keep]

    # Each removed entry lowers the count of its own row / column by one.
    # NOTE(review): the indexed in-place decrement assumes no duplicate
    # diagonal entries (i.e. a coalesced ``src``) — confirm.
    row_counts = src.storage._rowcount
    if row_counts is not None:
        row_counts = row_counts.clone()
        row_counts[row[on_diag]] -= 1

    col_counts = src.storage._colcount
    if col_counts is not None:
        col_counts = col_counts.clone()
        col_counts[col[on_diag]] -= 1

    # Boolean filtering keeps the relative order of entries, so the result is
    # flagged sorted (assumes ``src.coo()`` is itself sorted — TODO confirm).
    return src.from_storage(
        SparseStorage(
            row=kept_row,
            rowptr=None,
            col=kept_col,
            value=kept_value,
            sparse_sizes=src.sparse_sizes(),
            rowcount=row_counts,
            colptr=None,
            colcount=col_counts,
            csr2csc=None,
            csc2csr=None,
            is_sorted=True,
        ))


def set_diag(src: SparseTensor, values: Optional[Tensor] = None,
             k: int = 0) -> SparseTensor:
    """Return a copy of ``src`` whose ``k``-th diagonal is fully populated.

    Existing entries on the ``k``-th diagonal are first dropped with
    ``remove_diag``. One new entry is then inserted for every position
    ``(i, i + k)`` of that diagonal.

    Args:
        src: The sparse matrix to modify.
        values: Values for the inserted diagonal entries, indexed by diagonal
            position. They must be assignable to a
            ``(num_diag,) + value.size()[1:]`` slice. If ``None``, the inserted
            entries get value 1. Only used when ``src`` carries values.
        k: Diagonal offset (0 = main, > 0 above, < 0 below).

    Returns:
        A new ``SparseTensor`` with the same sparse sizes as ``src``.
    """
    src = remove_diag(src, k=k)
    row, col, value = src.coo()

    # Mask over the output's non-zeros: True marks a slot for an existing
    # (off-diagonal) entry, False marks a slot reserved for a new diagonal
    # entry. The slot layout is decided by the custom op.
    mask = torch.ops.torch_sparse.non_diag_mask(row, col, src.size(0),
                                                src.size(1), k)
    inv_mask = ~mask

    # First row index on the k-th diagonal, and the number of reserved slots.
    start = -k if k < 0 else 0
    num_diag = mask.numel() - row.numel()
    diag = torch.arange(start, start + num_diag, device=row.device)

    new_row = row.new_empty(mask.size(0))
    new_row[mask] = row
    new_row[inv_mask] = diag

    new_col = col.new_empty(mask.size(0))
    new_col[mask] = col
    new_col[inv_mask] = diag + k

    new_value: Optional[Tensor] = None
    if value is not None:
        new_value = value.new_empty((mask.size(0), ) + value.size()[1:])
        new_value[mask] = value
        if values is not None:
            new_value[inv_mask] = values
        else:
            # Include the trailing (feature) dims of ``value``: a plain
            # ``(num_diag,)`` tensor does not broadcast against
            # ``(num_diag, *feat)`` when ``value`` is multi-dimensional.
            new_value[inv_mask] = value.new_ones(
                (num_diag, ) + value.size()[1:])
    # NOTE(review): when ``src`` has no values, ``values`` is ignored and the
    # result stays value-less; confirm callers rely on this.

    # remove_diag cleared the diagonal, so every row/column touched by the
    # k-th diagonal gains exactly one non-zero.
    rowcount = src.storage._rowcount
    if rowcount is not None:
        rowcount = rowcount.clone()
        rowcount[start:start + num_diag] += 1

    colcount = src.storage._colcount
    if colcount is not None:
        colcount = colcount.clone()
        colcount[start + k:start + num_diag + k] += 1

    storage = SparseStorage(row=new_row, rowptr=None, col=new_col,
                            value=new_value, sparse_sizes=src.sparse_sizes(),
                            rowcount=rowcount, colptr=None, colcount=colcount,
                            csr2csc=None, csc2csr=None, is_sorted=True)
    return src.from_storage(storage)


def fill_diag(src: SparseTensor, fill_value: float,
              k: int = 0) -> SparseTensor:
    """Return a copy of ``src`` whose ``k``-th diagonal holds ``fill_value``.

    The diagonal length follows from the sparse sizes and the offset ``k``.
    If ``src`` carries no values, ``fill_value`` is not used: the call
    degenerates to ``set_diag(src, None, k)``.

    Args:
        src: The sparse matrix to modify.
        fill_value: Scalar written to every diagonal position.
        k: Diagonal offset (0 = main, > 0 above, < 0 below).

    Returns:
        A new ``SparseTensor`` produced by ``set_diag``.
    """
    n_rows, n_cols = src.sparse_size(0), src.sparse_size(1)
    if k < 0:
        num_diag = min(n_rows + k, n_cols)
    else:
        num_diag = min(n_rows, n_cols - k)

    value = src.storage.value()
    if value is None:
        return set_diag(src, None, k)

    # One fill row per diagonal position, keeping the value's trailing dims.
    fill = value.new_full([num_diag] + src.sizes()[2:], fill_value)
    return set_diag(src, fill, k)


def get_diag(src: SparseTensor, k: int = 0) -> Tensor:
    """Return the ``k``-th diagonal of ``src`` as a dense tensor.

    Positions without a stored entry are zero. When ``src`` carries no
    values, every stored entry is treated as having value 1.

    Args:
        src: The sparse matrix to read.
        k: Diagonal offset (0 = main, > 0 above, < 0 below). Defaults to the
            main diagonal, matching the original behavior.

    Returns:
        A tensor of shape ``(num_diag,) + value.size()[1:]``, where
        ``num_diag`` is the length of the ``k``-th diagonal (0 if ``k`` lies
        outside the matrix).
    """
    row, col, value = src.coo()

    if value is None:
        value = torch.ones(row.size(0), device=row.device)

    num_rows, num_cols = src.size(0), src.size(1)
    if k < 0:
        num_diag = min(num_rows + k, num_cols)
    else:
        num_diag = min(num_rows, num_cols - k)
    num_diag = max(num_diag, 0)

    sizes = list(value.size())
    sizes[0] = num_diag
    out = value.new_zeros(sizes)

    # Entry (r, c) lies on the k-th diagonal iff c - r == k.
    mask = (col - row) == k
    # Diagonal position: the row index for k >= 0, the column index for k < 0.
    index = row[mask] if k >= 0 else col[mask]
    out[index] = value[mask]
    return out


# Expose the functional helpers above as SparseTensor methods, e.g.
# ``adj.set_diag(values, k=1)`` or ``adj.fill_diag(1.0)``.
SparseTensor.remove_diag = lambda self, k=0: remove_diag(self, k=k)
SparseTensor.set_diag = (
    lambda self, values=None, k=0: set_diag(self, values=values, k=k))
SparseTensor.fill_diag = (
    lambda self, fill_value, k=0: fill_diag(self, fill_value=fill_value, k=k))
SparseTensor.get_diag = lambda self: get_diag(self)