File: test_gather.py

package info (click to toggle)
pytorch-scatter 2.1.2-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,128 kB
  • sloc: python: 1,574; cpp: 1,379; sh: 58; makefile: 13
file content (112 lines) | stat: -rw-r--r-- 3,618 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from itertools import product

import pytest
import torch
from torch.autograd import gradcheck
from torch_scatter import gather_coo, gather_csr
from torch_scatter.testing import devices, dtypes, tensor

# Shared fixtures for all gather tests. Each case provides:
#   'src':      values to gather from (1-D up to 3-D nested lists),
#   'index':    COO-style index used by gather_coo,
#   'indptr':   CSR-style compressed pointer used by gather_csr,
#   'expected': the result both gather variants must produce.
# 'index' and 'indptr' describe the same segmentation, so both ops
# are checked against the single 'expected' tensor.
tests = [
    {
        'src': [1, 2, 3, 4],
        'index': [0, 0, 1, 1, 1, 3],
        'indptr': [0, 2, 5, 5, 6],
        'expected': [1, 1, 2, 2, 2, 4],
    },
    {
        'src': [[1, 2], [3, 4], [5, 6], [7, 8]],
        'index': [0, 0, 1, 1, 1, 3],
        'indptr': [0, 2, 5, 5, 6],
        'expected': [[1, 2], [1, 2], [3, 4], [3, 4], [3, 4], [7, 8]]
    },
    {
        'src': [[1, 3, 5, 7], [2, 4, 6, 8]],
        'index': [[0, 0, 1, 1, 1, 3], [0, 0, 0, 1, 1, 2]],
        'indptr': [[0, 2, 5, 5, 6], [0, 3, 5, 6, 6]],
        'expected': [[1, 1, 3, 3, 3, 7], [2, 2, 2, 4, 4, 6]],
    },
    {
        'src': [[[1, 2], [3, 4], [5, 6]], [[7, 9], [10, 11], [12, 13]]],
        'index': [[0, 0, 1], [0, 2, 2]],
        'indptr': [[0, 2, 3, 3], [0, 1, 1, 3]],
        'expected': [[[1, 2], [1, 2], [3, 4]], [[7, 9], [12, 13], [12, 13]]],
    },
    {
        'src': [[1], [2]],
        'index': [[0, 0], [0, 0]],
        'indptr': [[0, 2], [0, 2]],
        'expected': [[1, 1], [2, 2]],
    },
    {
        'src': [[[1, 1]], [[2, 2]]],
        'index': [[0, 0], [0, 0]],
        'indptr': [[0, 2], [0, 2]],
        'expected': [[[1, 1], [1, 1]], [[2, 2], [2, 2]]],
    },
]


@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_forward(test, dtype, device):
    """Both gather variants must reproduce the fixture's expected output."""
    data = tensor(test['src'], dtype, device)
    row = tensor(test['index'], torch.long, device)
    ptr = tensor(test['indptr'], torch.long, device)
    target = tensor(test['expected'], dtype, device)

    # CSR and COO encode the same segmentation, so both share one target.
    assert torch.all(gather_csr(data, ptr) == target)
    assert torch.all(gather_coo(data, row) == target)


@pytest.mark.parametrize('test,device', product(tests, devices))
def test_backward(test, device):
    """Gradients of both gather variants must pass torch's numeric gradcheck."""
    # gradcheck needs double precision and a leaf tensor requiring grad.
    inp = tensor(test['src'], torch.double, device).requires_grad_()
    row = tensor(test['index'], torch.long, device)
    ptr = tensor(test['indptr'], torch.long, device)

    assert gradcheck(gather_csr, (inp, ptr, None)) is True
    assert gradcheck(gather_coo, (inp, row, None)) is True


@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_out(test, dtype, device):
    """Gathering into a caller-supplied `out` tensor must fill it correctly."""
    data = tensor(test['src'], dtype, device)
    row = tensor(test['index'], torch.long, device)
    ptr = tensor(test['indptr'], torch.long, device)
    target = tensor(test['expected'], dtype, device)

    # Output has src's shape except the gathered dim, which matches index.
    shape = list(data.size())
    shape[row.dim() - 1] = row.size(-1)
    # Poison with -2 so untouched entries would be caught by the compare.
    result = data.new_full(shape, -2)

    gather_csr(data, ptr, result)
    assert torch.all(result == target)

    # Re-poison and reuse the same buffer for the COO variant.
    result.fill_(-2)
    gather_coo(data, row, result)
    assert torch.all(result == target)


@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_non_contiguous(test, dtype, device):
    """Gather must also work when its inputs are not memory-contiguous."""

    def break_contiguity(t):
        # transpose -> contiguous -> transpose back yields the same values
        # with non-contiguous strides; only possible for dim > 1.
        if t.dim() > 1:
            return t.transpose(0, 1).contiguous().transpose(0, 1)
        return t

    data = break_contiguity(tensor(test['src'], dtype, device))
    row = break_contiguity(tensor(test['index'], torch.long, device))
    ptr = break_contiguity(tensor(test['indptr'], torch.long, device))
    target = tensor(test['expected'], dtype, device)

    assert torch.all(gather_csr(data, ptr) == target)
    assert torch.all(gather_coo(data, row) == target)