File: test_cat.py

package info (click to toggle)
pytorch-sparse 0.6.18-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 984 kB
  • sloc: python: 3,646; cpp: 2,444; sh: 54; makefile: 6
file content (48 lines) | stat: -rw-r--r-- 1,769 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import pytest
import torch

from torch_sparse.cat import cat
from torch_sparse.tensor import SparseTensor
from torch_sparse.testing import devices, tensor


@pytest.mark.parametrize('device', devices)
def test_cat(device):
    """Check `cat` on two small sparse matrices for several `dim` settings.

    Covers `dim=0`, `dim=1`, `dim=(0, 1)` and, after attaching a value
    tensor to the first matrix, `dim=-1`. Each case asserts the dense
    result (or value size) and which storage entries are present.
    """

    def build(index):
        # Split the two-row index list into row/col tensors, wrap them in a
        # SparseTensor and fill its cache before handing it back.
        rows, cols = tensor(index, torch.long, device)
        mat = SparseTensor(row=rows, col=cols)
        mat.fill_cache_()
        return mat

    lhs = build([[0, 0, 1], [0, 1, 2]])
    rhs = build([[0, 0, 1, 2], [0, 1, 1, 0]])

    # dim=0: the expected dense result has lhs's rows followed by rhs's rows.
    expected = [
        [1, 1, 0],
        [0, 0, 1],
        [1, 1, 0],
        [0, 1, 0],
        [1, 0, 0],
    ]
    joined = cat([lhs, rhs], dim=0)
    assert joined.to_dense().tolist() == expected
    store = joined.storage
    assert store.has_row()
    assert store.has_rowptr()
    assert store.has_rowcount()
    assert store.num_cached_keys() == 1

    # dim=1: the expected dense result is 3 x 5; rowptr is expected absent.
    expected = [
        [1, 1, 0, 1, 1],
        [0, 0, 1, 0, 1],
        [0, 0, 0, 1, 0],
    ]
    joined = cat([lhs, rhs], dim=1)
    assert joined.to_dense().tolist() == expected
    store = joined.storage
    assert store.has_row()
    assert not store.has_rowptr()
    assert store.num_cached_keys() == 2

    # dim=(0, 1): the expected dense result places lhs and rhs in separate
    # diagonal blocks of a 5 x 5 matrix.
    expected = [
        [1, 1, 0, 0, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 0, 1, 1],
        [0, 0, 0, 0, 1],
        [0, 0, 0, 1, 0],
    ]
    joined = cat([lhs, rhs], dim=(0, 1))
    assert joined.to_dense().tolist() == expected
    store = joined.storage
    assert store.has_row()
    assert store.has_rowptr()
    assert store.num_cached_keys() == 5

    # dim=-1: with a (nnz, 4) value attached, joining lhs with itself is
    # expected to yield a (nnz, 8) value tensor.
    feats = torch.randn((lhs.nnz(), 4), device=device)
    lhs = lhs.set_value_(feats, layout='coo')
    joined = cat([lhs, lhs], dim=-1)
    store = joined.storage
    assert store.value().size() == (lhs.nnz(), 8)
    assert store.has_row()
    assert store.has_rowptr()
    assert store.num_cached_keys() == 5