File: storage.py

package info (click to toggle)
pytorch-sparse 0.6.18-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 984 kB
  • sloc: python: 3,646; cpp: 2,444; sh: 54; makefile: 6
file content (801 lines) | stat: -rw-r--r-- 25,090 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
import warnings
from typing import List, Optional, Tuple

import torch
from torch_scatter import scatter_add, segment_csr

from torch_sparse.utils import Final, index_sort

# The three sparse layout names used in this module; ``get_layout`` accepts
# exactly these values.
layouts: Final[List[str]] = ['coo', 'csr', 'csc']


def get_layout(layout: Optional[str] = None) -> str:
    """Resolve the layout name passed to layout-aware storage methods.

    Args:
        layout: one of ``'coo'``, ``'csr'`` or ``'csc'``. ``None`` falls
            back to ``'coo'`` and emits a ``UserWarning``.

    Returns:
        The validated layout name.
    """
    if layout is not None:
        # Only the three supported layout names are accepted.
        assert layout in ['coo', 'csr', 'csc']
        return layout

    warnings.warn('`layout` argument unset, using default layout '
                  '"coo". This may lead to unexpected behaviour.')
    return 'coo'


@torch.jit.script
class SparseStorage(object):
    """TorchScript-compatible storage for a two-dimensional sparse matrix.

    Non-zeros are kept ordered by ``(row, col)`` and described by a ``col``
    index tensor together with COO ``row`` indices and/or a CSR ``rowptr``,
    plus optional per-entry ``value``s. Derived tensors (``row``/``rowptr``
    when only one was given, ``rowcount``, ``colptr``, ``colcount`` and the
    ``csr2csc``/``csc2csr`` permutations) are computed lazily on first
    access and cached on the instance.
    """

    _row: Optional[torch.Tensor]  # row index of every non-zero
    _rowptr: Optional[torch.Tensor]  # CSR row pointer, length M + 1
    _col: torch.Tensor  # column index of every non-zero (always present)
    _value: Optional[torch.Tensor]  # per-non-zero values; size(0) == nnz
    _sparse_sizes: Tuple[int, int]  # (M, N)
    _rowcount: Optional[torch.Tensor]  # cached: non-zeros per row
    _colptr: Optional[torch.Tensor]  # cached: CSC column pointer, length N + 1
    _colcount: Optional[torch.Tensor]  # cached: non-zeros per column
    _csr2csc: Optional[torch.Tensor]  # cached: CSR-ordered -> CSC-ordered perm
    _csc2csr: Optional[torch.Tensor]  # cached: inverse of ``_csr2csc``

    def __init__(
        self,
        row: Optional[torch.Tensor] = None,
        rowptr: Optional[torch.Tensor] = None,
        col: Optional[torch.Tensor] = None,
        value: Optional[torch.Tensor] = None,
        sparse_sizes: Optional[Tuple[Optional[int], Optional[int]]] = None,
        rowcount: Optional[torch.Tensor] = None,
        colptr: Optional[torch.Tensor] = None,
        colcount: Optional[torch.Tensor] = None,
        csr2csc: Optional[torch.Tensor] = None,
        csc2csr: Optional[torch.Tensor] = None,
        is_sorted: bool = False,
        trust_data: bool = False,
    ):
        """Validate the given tensors and build the storage.

        Args:
            row: optional ``torch.long`` row index per non-zero.
            rowptr: optional ``torch.long`` CSR row pointer. At least one of
                ``row``/``rowptr`` must be given.
            col: ``torch.long`` 1-D column index per non-zero (required).
            value: optional values with ``value.size(0) == nnz``.
            sparse_sizes: ``(M, N)``. A missing tuple or a ``None`` entry is
                inferred from ``rowptr``/``row`` (for M) or ``col`` (for N).
            rowcount, colptr, colcount, csr2csc, csc2csr: optional
                precomputed caches; they are shape-checked, not recomputed.
            is_sorted: if ``True``, skip checking/sorting entries by
                ``(row, col)``.
            trust_data: if ``True``, skip checking the largest row/col index
                against explicitly given ``sparse_sizes``.
        """
        assert row is not None or rowptr is not None
        assert col is not None
        assert col.dtype == torch.long
        assert col.dim() == 1
        col = col.contiguous()

        # Number of rows: explicit size if given, otherwise inferred from the
        # pointer length or the largest row index.
        M: int = 0
        if sparse_sizes is None or sparse_sizes[0] is None:
            if rowptr is not None:
                M = rowptr.numel() - 1
            elif row is not None and row.numel() > 0:
                M = int(row.max()) + 1
        else:
            _M = sparse_sizes[0]
            assert _M is not None
            M = _M
            if rowptr is not None:
                assert rowptr.numel() - 1 == M
            elif row is not None and row.numel() > 0:
                assert trust_data or int(row.max()) < M

        # Number of columns: explicit size if given, otherwise inferred from
        # the largest column index.
        N: int = 0
        if sparse_sizes is None or sparse_sizes[1] is None:
            if col.numel() > 0:
                N = int(col.max()) + 1
        else:
            _N = sparse_sizes[1]
            assert _N is not None
            N = _N
            if col.numel() > 0:
                assert trust_data or int(col.max()) < N

        sparse_sizes = (M, N)

        if row is not None:
            assert row.dtype == torch.long
            assert row.device == col.device
            assert row.dim() == 1
            assert row.numel() == col.numel()
            row = row.contiguous()

        if rowptr is not None:
            assert rowptr.dtype == torch.long
            assert rowptr.device == col.device
            assert rowptr.dim() == 1
            assert rowptr.numel() - 1 == sparse_sizes[0]
            rowptr = rowptr.contiguous()

        if value is not None:
            assert value.device == col.device
            assert value.size(0) == col.size(0)
            value = value.contiguous()

        if rowcount is not None:
            assert rowcount.dtype == torch.long
            assert rowcount.device == col.device
            assert rowcount.dim() == 1
            assert rowcount.numel() == sparse_sizes[0]
            rowcount = rowcount.contiguous()

        if colptr is not None:
            assert colptr.dtype == torch.long
            assert colptr.device == col.device
            assert colptr.dim() == 1
            assert colptr.numel() - 1 == sparse_sizes[1]
            colptr = colptr.contiguous()

        if colcount is not None:
            assert colcount.dtype == torch.long
            assert colcount.device == col.device
            assert colcount.dim() == 1
            assert colcount.numel() == sparse_sizes[1]
            colcount = colcount.contiguous()

        if csr2csc is not None:
            assert csr2csc.dtype == torch.long
            assert csr2csc.device == col.device
            assert csr2csc.dim() == 1
            assert csr2csc.numel() == col.size(0)
            csr2csc = csr2csc.contiguous()

        if csc2csr is not None:
            assert csc2csr.dtype == torch.long
            assert csc2csr.device == col.device
            assert csc2csr.dim() == 1
            assert csc2csr.numel() == col.size(0)
            csc2csr = csc2csr.contiguous()

        self._row = row
        self._rowptr = rowptr
        self._col = col
        self._value = value
        self._sparse_sizes = tuple(sparse_sizes)
        self._rowcount = rowcount
        self._colptr = colptr
        self._colcount = colcount
        self._csr2csc = csr2csc
        self._csc2csr = csc2csr

        # Unless the caller promised sorted input, check that the linear
        # indices ``row * N + col`` never decrease; if they do, sort entries
        # (and values) row-major. The CSR<->CSC permutations refer to the old
        # order and are therefore dropped.
        if not is_sorted and self._col.numel() > 0:
            idx = self._col.new_zeros(self._col.numel() + 1)
            idx[1:] = self.row()
            idx[1:] *= self._sparse_sizes[1]
            idx[1:] += self._col
            if (idx[1:] < idx[:-1]).any():
                max_value = self._sparse_sizes[0] * self._sparse_sizes[1]
                _, perm = index_sort(idx[1:], max_value)
                self._row = self.row()[perm]
                self._col = self._col[perm]
                if value is not None:
                    self._value = value[perm]
                self._csr2csc = None
                self._csc2csr = None

    @classmethod
    def empty(cls):
        """Return a storage of sparse size ``(0, 0)`` with no non-zeros."""
        row = torch.tensor([], dtype=torch.long)
        col = torch.tensor([], dtype=torch.long)
        return SparseStorage(
            row=row,
            rowptr=None,
            col=col,
            value=None,
            sparse_sizes=(0, 0),
            rowcount=None,
            colptr=None,
            colcount=None,
            csr2csc=None,
            csc2csr=None,
            is_sorted=True,
            trust_data=True,
        )

    def has_row(self) -> bool:
        return self._row is not None

    def row(self) -> torch.Tensor:
        """Return the row index of every non-zero.

        Derived from ``rowptr`` (and cached) when only the pointer is held.

        Raises:
            ValueError: if neither ``row`` nor ``rowptr`` is available.
        """
        row = self._row
        if row is not None:
            return row

        rowptr = self._rowptr
        if rowptr is not None:
            row = torch.ops.torch_sparse.ptr2ind(rowptr, self._col.numel())
            self._row = row
            return row

        raise ValueError('SparseStorage holds neither `row` nor `rowptr`')

    def has_rowptr(self) -> bool:
        return self._rowptr is not None

    def rowptr(self) -> torch.Tensor:
        """Return the CSR row pointer (length ``M + 1``).

        Derived from ``row`` (and cached) when only indices are held.

        Raises:
            ValueError: if neither ``row`` nor ``rowptr`` is available.
        """
        rowptr = self._rowptr
        if rowptr is not None:
            return rowptr

        row = self._row
        if row is not None:
            rowptr = torch.ops.torch_sparse.ind2ptr(row, self._sparse_sizes[0])
            self._rowptr = rowptr
            return rowptr

        raise ValueError('SparseStorage holds neither `row` nor `rowptr`')

    def col(self) -> torch.Tensor:
        return self._col

    def has_value(self) -> bool:
        return self._value is not None

    def value(self) -> Optional[torch.Tensor]:
        return self._value

    def set_value_(
        self,
        value: Optional[torch.Tensor],
        layout: Optional[str] = None,
    ):
        """Replace the stored values in place and return ``self``.

        When ``layout == 'csc'``, ``value`` is given in column-major order and
        is permuted back to the storage's row-major order via ``csc2csr()``.
        """
        if value is not None:
            if get_layout(layout) == 'csc':
                value = value[self.csc2csr()]
            value = value.contiguous()
            assert value.device == self._col.device
            assert value.size(0) == self._col.numel()

        self._value = value
        return self

    def set_value(
        self,
        value: Optional[torch.Tensor],
        layout: Optional[str] = None,
    ):
        """Return a new storage with ``value``; indices and caches are shared.

        When ``layout == 'csc'``, ``value`` is given in column-major order and
        is permuted back to the storage's row-major order via ``csc2csr()``.
        """
        if value is not None:
            if get_layout(layout) == 'csc':
                value = value[self.csc2csr()]
            value = value.contiguous()
            assert value.device == self._col.device
            assert value.size(0) == self._col.numel()

        return SparseStorage(
            row=self._row,
            rowptr=self._rowptr,
            col=self._col,
            value=value,
            sparse_sizes=self._sparse_sizes,
            rowcount=self._rowcount,
            colptr=self._colptr,
            colcount=self._colcount,
            csr2csc=self._csr2csc,
            csc2csr=self._csc2csr,
            is_sorted=True,
            trust_data=True,
        )

    def sparse_sizes(self) -> Tuple[int, int]:
        return self._sparse_sizes

    def sparse_size(self, dim: int) -> int:
        return self._sparse_sizes[dim]

    def sparse_resize(self, sparse_sizes: Tuple[int, int]):
        """Return a storage whose sparse sizes are ``sparse_sizes``.

        Pointer/count caches are padded when a dimension grows and truncated
        when it shrinks; index tensors are reused unchanged.
        NOTE(review): shrinking does not check that no non-zero falls outside
        the new bounds (the result is built with ``trust_data=True``).
        """
        assert len(sparse_sizes) == 2
        old_sparse_sizes, nnz = self._sparse_sizes, self._col.numel()

        diff_0 = sparse_sizes[0] - old_sparse_sizes[0]
        rowcount, rowptr = self._rowcount, self._rowptr
        if diff_0 > 0:
            # New trailing rows are empty: their pointers all equal nnz.
            if rowptr is not None:
                rowptr = torch.cat([rowptr, rowptr.new_full((diff_0, ), nnz)])
            if rowcount is not None:
                rowcount = torch.cat([rowcount, rowcount.new_zeros(diff_0)])
        elif diff_0 < 0:
            if rowptr is not None:
                rowptr = rowptr[:diff_0]
            if rowcount is not None:
                rowcount = rowcount[:diff_0]

        diff_1 = sparse_sizes[1] - old_sparse_sizes[1]
        colcount, colptr = self._colcount, self._colptr
        if diff_1 > 0:
            # New trailing columns are empty: their pointers all equal nnz.
            if colptr is not None:
                colptr = torch.cat([colptr, colptr.new_full((diff_1, ), nnz)])
            if colcount is not None:
                colcount = torch.cat([colcount, colcount.new_zeros(diff_1)])
        elif diff_1 < 0:
            if colptr is not None:
                colptr = colptr[:diff_1]
            if colcount is not None:
                colcount = colcount[:diff_1]

        return SparseStorage(
            row=self._row,
            rowptr=rowptr,
            col=self._col,
            value=self._value,
            sparse_sizes=sparse_sizes,
            rowcount=rowcount,
            colptr=colptr,
            colcount=colcount,
            csr2csc=self._csr2csc,
            csc2csr=self._csc2csr,
            is_sorted=True,
            trust_data=True,
        )

    def sparse_reshape(self, num_rows: int, num_cols: int):
        """Reinterpret the matrix, in row-major order, as
        ``num_rows x num_cols``; at most one of the two may be ``-1`` to have
        it inferred. The total number of elements must be preserved."""
        assert num_rows > 0 or num_rows == -1
        assert num_cols > 0 or num_cols == -1
        assert num_rows > 0 or num_cols > 0

        total = self.sparse_size(0) * self.sparse_size(1)

        if num_rows == -1:
            num_rows = total // num_cols

        if num_cols == -1:
            num_cols = total // num_rows

        assert num_rows * num_cols == total

        # Row-major linear index is invariant under reshape, so the result is
        # still sorted.
        idx = self.sparse_size(1) * self.row() + self.col()

        row = torch.div(idx, num_cols, rounding_mode='floor')
        col = idx % num_cols
        assert row.dtype == torch.long and col.dtype == torch.long

        return SparseStorage(
            row=row,
            rowptr=None,
            col=col,
            value=self._value,
            sparse_sizes=(num_rows, num_cols),
            rowcount=None,
            colptr=None,
            colcount=None,
            csr2csc=None,
            csc2csr=None,
            is_sorted=True,
            trust_data=True,
        )

    def has_rowcount(self) -> bool:
        return self._rowcount is not None

    def rowcount(self) -> torch.Tensor:
        """Return (and cache) the number of non-zeros in every row."""
        rowcount = self._rowcount
        if rowcount is not None:
            return rowcount

        rowptr = self.rowptr()
        rowcount = rowptr[1:] - rowptr[:-1]
        self._rowcount = rowcount
        return rowcount

    def has_colptr(self) -> bool:
        return self._colptr is not None

    def colptr(self) -> torch.Tensor:
        """Return (and cache) the CSC column pointer (length ``N + 1``)."""
        colptr = self._colptr
        if colptr is not None:
            return colptr

        csr2csc = self._csr2csc
        if csr2csc is not None:
            # Column indices in CSC order are sorted, so they compress
            # directly into a pointer.
            colptr = torch.ops.torch_sparse.ind2ptr(self._col[csr2csc],
                                                    self._sparse_sizes[1])
        else:
            colptr = self._col.new_zeros(self._sparse_sizes[1] + 1)
            torch.cumsum(self.colcount(), dim=0, out=colptr[1:])
        self._colptr = colptr
        return colptr

    def has_colcount(self) -> bool:
        return self._colcount is not None

    def colcount(self) -> torch.Tensor:
        """Return (and cache) the number of non-zeros in every column."""
        colcount = self._colcount
        if colcount is not None:
            return colcount

        colptr = self._colptr
        if colptr is not None:
            colcount = colptr[1:] - colptr[:-1]
        else:
            colcount = scatter_add(
                torch.ones_like(self._col),
                self._col,
                dim_size=self._sparse_sizes[1],
            )
        self._colcount = colcount
        return colcount

    def has_csr2csc(self) -> bool:
        return self._csr2csc is not None

    def csr2csc(self) -> torch.Tensor:
        """Return (and cache) the permutation that reorders the stored
        (row-major) entries into column-major order."""
        csr2csc = self._csr2csc
        if csr2csc is not None:
            return csr2csc

        idx = self._sparse_sizes[0] * self._col + self.row()
        max_value = self._sparse_sizes[0] * self._sparse_sizes[1]
        _, csr2csc = index_sort(idx, max_value)
        self._csr2csc = csr2csc
        return csr2csc

    def has_csc2csr(self) -> bool:
        return self._csc2csr is not None

    def csc2csr(self) -> torch.Tensor:
        """Return (and cache) the inverse of ``csr2csc()``."""
        csc2csr = self._csc2csr
        if csc2csr is not None:
            return csc2csr

        # Sorting a permutation yields its inverse as the sort indices.
        max_value = self._sparse_sizes[0] * self._sparse_sizes[1]
        _, csc2csr = index_sort(self.csr2csc(), max_value)
        self._csc2csr = csc2csr
        return csc2csr

    def is_coalesced(self) -> bool:
        """Return ``True`` if linear indices strictly increase, i.e. there
        are no duplicate ``(row, col)`` entries."""
        idx = self._col.new_full((self._col.numel() + 1, ), -1)
        idx[1:] = self._sparse_sizes[1] * self.row() + self._col
        return bool((idx[1:] > idx[:-1]).all())

    def coalesce(self, reduce: str = "add"):
        """Merge duplicate ``(row, col)`` entries.

        Values of merged entries are combined with ``segment_csr`` using
        ``reduce``. Returns ``self`` unchanged when there are no duplicates.
        Relies on entries being row-major sorted so duplicates are adjacent.
        """
        idx = self._col.new_full((self._col.numel() + 1, ), -1)
        idx[1:] = self._sparse_sizes[1] * self.row() + self._col
        # ``mask`` marks the first entry of each run of equal linear indices.
        mask = idx[1:] > idx[:-1]

        if mask.all():  # Skip if indices are already coalesced.
            return self

        row = self.row()[mask]
        col = self._col[mask]

        value = self._value
        if value is not None:
            ptr = mask.nonzero().flatten()
            ptr = torch.cat([ptr, ptr.new_full((1, ), value.size(0))])
            value = segment_csr(value, ptr, reduce=reduce)

        return SparseStorage(
            row=row,
            rowptr=None,
            col=col,
            value=value,
            sparse_sizes=self._sparse_sizes,
            rowcount=None,
            colptr=None,
            colcount=None,
            csr2csc=None,
            csc2csr=None,
            is_sorted=True,
            trust_data=True,
        )

    def fill_cache_(self):
        """Compute every lazily derived tensor now; return ``self``."""
        self.row()
        self.rowptr()
        self.rowcount()
        self.colptr()
        self.colcount()
        self.csr2csc()
        self.csc2csr()
        return self

    def clear_cache_(self):
        """Drop the cached counts, ``colptr`` and permutations (``row`` and
        ``rowptr`` are kept); return ``self``."""
        self._rowcount = None
        self._colptr = None
        self._colcount = None
        self._csr2csc = None
        self._csc2csr = None
        return self

    def cached_keys(self) -> List[str]:
        """Return the names of the currently populated caches."""
        keys: List[str] = []
        if self.has_rowcount():
            keys.append('rowcount')
        if self.has_colptr():
            keys.append('colptr')
        if self.has_colcount():
            keys.append('colcount')
        if self.has_csr2csc():
            keys.append('csr2csc')
        if self.has_csc2csr():
            keys.append('csc2csr')
        return keys

    def num_cached_keys(self) -> int:
        return len(self.cached_keys())

    def copy(self):
        """Return a new storage sharing every tensor with ``self``."""
        return SparseStorage(
            row=self._row,
            rowptr=self._rowptr,
            col=self._col,
            value=self._value,
            sparse_sizes=self._sparse_sizes,
            rowcount=self._rowcount,
            colptr=self._colptr,
            colcount=self._colcount,
            csr2csc=self._csr2csc,
            csc2csr=self._csc2csr,
            is_sorted=True,
            trust_data=True,
        )

    def clone(self):
        """Return a new storage holding clones of every tensor."""
        row = self._row
        if row is not None:
            row = row.clone()
        rowptr = self._rowptr
        if rowptr is not None:
            rowptr = rowptr.clone()
        col = self._col.clone()
        value = self._value
        if value is not None:
            value = value.clone()
        rowcount = self._rowcount
        if rowcount is not None:
            rowcount = rowcount.clone()
        colptr = self._colptr
        if colptr is not None:
            colptr = colptr.clone()
        colcount = self._colcount
        if colcount is not None:
            colcount = colcount.clone()
        csr2csc = self._csr2csc
        if csr2csc is not None:
            csr2csc = csr2csc.clone()
        csc2csr = self._csc2csr
        if csc2csr is not None:
            csc2csr = csc2csr.clone()

        return SparseStorage(
            row=row,
            rowptr=rowptr,
            col=col,
            value=value,
            sparse_sizes=self._sparse_sizes,
            rowcount=rowcount,
            colptr=colptr,
            colcount=colcount,
            csr2csc=csr2csc,
            csc2csr=csc2csr,
            is_sorted=True,
            trust_data=True,
        )

    def type(self, dtype: torch.dtype, non_blocking: bool = False):
        """Return a storage whose values are cast to ``dtype``.

        Returns ``self`` when there are no values or they already have
        ``dtype``.
        """
        value = self._value
        if value is not None:
            if dtype == value.dtype:
                return self
            else:
                return self.set_value(
                    value.to(dtype=dtype, non_blocking=non_blocking),
                    layout='coo',
                )
        else:
            return self

    def type_as(self, tensor: torch.Tensor, non_blocking: bool = False):
        return self.type(dtype=tensor.dtype, non_blocking=non_blocking)

    def to_device(self, device: torch.device, non_blocking: bool = False):
        """Return a storage with every tensor moved to ``device``, or
        ``self`` if it already lives there."""
        if device == self._col.device:
            return self

        row = self._row
        if row is not None:
            row = row.to(device, non_blocking=non_blocking)
        rowptr = self._rowptr
        if rowptr is not None:
            rowptr = rowptr.to(device, non_blocking=non_blocking)
        col = self._col.to(device, non_blocking=non_blocking)
        value = self._value
        if value is not None:
            value = value.to(device, non_blocking=non_blocking)
        rowcount = self._rowcount
        if rowcount is not None:
            rowcount = rowcount.to(device, non_blocking=non_blocking)
        colptr = self._colptr
        if colptr is not None:
            colptr = colptr.to(device, non_blocking=non_blocking)
        colcount = self._colcount
        if colcount is not None:
            colcount = colcount.to(device, non_blocking=non_blocking)
        csr2csc = self._csr2csc
        if csr2csc is not None:
            csr2csc = csr2csc.to(device, non_blocking=non_blocking)
        csc2csr = self._csc2csr
        if csc2csr is not None:
            csc2csr = csc2csr.to(device, non_blocking=non_blocking)

        return SparseStorage(
            row=row,
            rowptr=rowptr,
            col=col,
            value=value,
            sparse_sizes=self._sparse_sizes,
            rowcount=rowcount,
            colptr=colptr,
            colcount=colcount,
            csr2csc=csr2csc,
            csc2csr=csc2csr,
            is_sorted=True,
            trust_data=True,
        )

    def device_as(self, tensor: torch.Tensor, non_blocking: bool = False):
        return self.to_device(device=tensor.device, non_blocking=non_blocking)

    def cuda(self):
        """Return a storage with every tensor moved via ``Tensor.cuda()``,
        or ``self`` if ``col`` is already on that device."""
        new_col = self._col.cuda()
        if new_col.device == self._col.device:
            return self

        row = self._row
        if row is not None:
            row = row.cuda()
        rowptr = self._rowptr
        if rowptr is not None:
            rowptr = rowptr.cuda()
        value = self._value
        if value is not None:
            value = value.cuda()
        rowcount = self._rowcount
        if rowcount is not None:
            rowcount = rowcount.cuda()
        colptr = self._colptr
        if colptr is not None:
            colptr = colptr.cuda()
        colcount = self._colcount
        if colcount is not None:
            colcount = colcount.cuda()
        csr2csc = self._csr2csc
        if csr2csc is not None:
            csr2csc = csr2csc.cuda()
        csc2csr = self._csc2csr
        if csc2csr is not None:
            csc2csr = csc2csr.cuda()

        return SparseStorage(
            row=row,
            rowptr=rowptr,
            col=new_col,
            value=value,
            sparse_sizes=self._sparse_sizes,
            rowcount=rowcount,
            colptr=colptr,
            colcount=colcount,
            csr2csc=csr2csc,
            csc2csr=csc2csr,
            is_sorted=True,
            trust_data=True,
        )

    def pin_memory(self):
        """Return a storage whose tensors are copies in pinned memory."""
        row = self._row
        if row is not None:
            row = row.pin_memory()
        rowptr = self._rowptr
        if rowptr is not None:
            rowptr = rowptr.pin_memory()
        col = self._col.pin_memory()
        value = self._value
        if value is not None:
            value = value.pin_memory()
        rowcount = self._rowcount
        if rowcount is not None:
            rowcount = rowcount.pin_memory()
        colptr = self._colptr
        if colptr is not None:
            colptr = colptr.pin_memory()
        colcount = self._colcount
        if colcount is not None:
            colcount = colcount.pin_memory()
        csr2csc = self._csr2csc
        if csr2csc is not None:
            csr2csc = csr2csc.pin_memory()
        csc2csr = self._csc2csr
        if csc2csr is not None:
            csc2csr = csc2csr.pin_memory()

        return SparseStorage(
            row=row,
            rowptr=rowptr,
            col=col,
            value=value,
            sparse_sizes=self._sparse_sizes,
            rowcount=rowcount,
            colptr=colptr,
            colcount=colcount,
            csr2csc=csr2csc,
            csc2csr=csc2csr,
            is_sorted=True,
            trust_data=True,
        )

    def is_pinned(self) -> bool:
        """Return ``True`` only if every tensor held is in pinned memory."""
        pinned = True
        row = self._row
        if row is not None:
            pinned = pinned and row.is_pinned()
        rowptr = self._rowptr
        if rowptr is not None:
            pinned = pinned and rowptr.is_pinned()
        # Combine with the running result; overwriting it here would ignore
        # an unpinned ``row``/``rowptr``.
        pinned = pinned and self._col.is_pinned()
        value = self._value
        if value is not None:
            pinned = pinned and value.is_pinned()
        rowcount = self._rowcount
        if rowcount is not None:
            pinned = pinned and rowcount.is_pinned()
        colptr = self._colptr
        if colptr is not None:
            pinned = pinned and colptr.is_pinned()
        colcount = self._colcount
        if colcount is not None:
            pinned = pinned and colcount.is_pinned()
        csr2csc = self._csr2csc
        if csr2csc is not None:
            pinned = pinned and csr2csc.is_pinned()
        csc2csr = self._csc2csr
        if csc2csr is not None:
            pinned = pinned and csc2csr.is_pinned()
        return pinned


def share_memory_(self) -> 'SparseStorage':
    """Move every tensor held by the storage into shared memory, in place.

    Defined at module level and attached to ``SparseStorage`` as a method
    after the class is scripted, so it runs in eager Python only.

    Returns:
        ``self`` -- matching the annotation and ``Tensor.share_memory_`` --
        so the call can be chained. (Previously the function fell off the
        end and returned ``None``.)
    """
    tensors = (
        self._row,
        self._rowptr,
        self._col,
        self._value,
        self._rowcount,
        self._colptr,
        self._colcount,
        self._csr2csc,
        self._csc2csr,
    )
    for tensor in tensors:
        if tensor is not None:
            tensor.share_memory_()
    return self


def is_shared(self) -> bool:
    """Report whether every tensor held by the storage is in shared memory.

    Optional tensors that are ``None`` are skipped; ``col`` is always
    checked. Evaluation stops at the first tensor that is not shared.
    """
    leading = (self._row, self._rowptr)
    trailing = (
        self._value,
        self._rowcount,
        self._colptr,
        self._colcount,
        self._csr2csc,
        self._csc2csr,
    )
    return (
        all(t.is_shared() for t in leading if t is not None)
        and self._col.is_shared()
        and all(t.is_shared() for t in trailing if t is not None)
    )


# Attach the eager-Python helpers defined above as SparseStorage methods.
# They live outside the ``@torch.jit.script`` class body, presumably because
# TorchScript cannot compile them -- NOTE(review): confirm.
setattr(SparseStorage, 'share_memory_', share_memory_)
setattr(SparseStorage, 'is_shared', is_shared)