1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112
|
from itertools import product
import pytest
import torch
from torch.autograd import gradcheck
from torch_scatter import gather_coo, gather_csr
from torch_scatter.testing import devices, dtypes, tensor
# Shared fixtures for every test in this module. In each case
# `expected[..., i]` equals `src[..., index[..., i]]`, and `indptr` encodes the
# same selection as offsets: output positions indptr[k]:indptr[k + 1] hold
# source element k.
tests = [
    # 1-D source; indptr has an empty range (5 -> 5), so src[2] is never used.
    dict(
        src=[1, 2, 3, 4],
        index=[0, 0, 1, 1, 1, 3],
        indptr=[0, 2, 5, 5, 6],
        expected=[1, 1, 2, 2, 2, 4],
    ),
    # Same 1-D selection, but every source element is a row of two values.
    dict(
        src=[[1, 2], [3, 4], [5, 6], [7, 8]],
        index=[0, 0, 1, 1, 1, 3],
        indptr=[0, 2, 5, 5, 6],
        expected=[[1, 2], [1, 2], [3, 4], [3, 4], [3, 4], [7, 8]],
    ),
    # 2-D index/indptr: each row of src has its own selection along dim 1.
    dict(
        src=[[1, 3, 5, 7], [2, 4, 6, 8]],
        index=[[0, 0, 1, 1, 1, 3], [0, 0, 0, 1, 1, 2]],
        indptr=[[0, 2, 5, 5, 6], [0, 3, 5, 6, 6]],
        expected=[[1, 1, 3, 3, 3, 7], [2, 2, 2, 4, 4, 6]],
    ),
    # 3-D source with 2-D index/indptr: per-row selection of two-value rows.
    dict(
        src=[[[1, 2], [3, 4], [5, 6]], [[7, 9], [10, 11], [12, 13]]],
        index=[[0, 0, 1], [0, 2, 2]],
        indptr=[[0, 2, 3, 3], [0, 1, 1, 3]],
        expected=[[[1, 2], [1, 2], [3, 4]], [[7, 9], [12, 13], [12, 13]]],
    ),
    # A single source element per row, repeated twice.
    dict(
        src=[[1], [2]],
        index=[[0, 0], [0, 0]],
        indptr=[[0, 2], [0, 2]],
        expected=[[1, 1], [2, 2]],
    ),
    # As above, with a trailing dimension of size two on the source.
    dict(
        src=[[[1, 1]], [[2, 2]]],
        index=[[0, 0], [0, 0]],
        indptr=[[0, 2], [0, 2]],
        expected=[[[1, 1], [1, 1]], [[2, 2], [2, 2]]],
    ),
]
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_forward(test, dtype, device):
    """Check that both gather ops reproduce ``test['expected']``.

    ``gather_csr`` is driven by ``test['indptr']`` and ``gather_coo`` by
    ``test['index']``. Both receive the same ``src`` tensor, built with the
    parametrized ``dtype`` on ``device``; the pointer tensors are always
    ``torch.long``.
    """
    source = tensor(test['src'], dtype, device)
    want = tensor(test['expected'], dtype, device)
    offsets = tensor(test['indptr'], torch.long, device)
    positions = tensor(test['index'], torch.long, device)

    csr_result = gather_csr(source, offsets)
    assert (csr_result == want).all()

    coo_result = gather_coo(source, positions)
    assert (coo_result == want).all()
@pytest.mark.parametrize('test,device', product(tests, devices))
def test_backward(test, device):
    """Run ``gradcheck`` on both gather ops with respect to ``src``.

    Only ``device`` is parametrized; ``src`` is always built as
    ``torch.double`` and is the only input that requires a gradient.
    """
    source = tensor(test['src'], torch.double, device).requires_grad_()
    positions = tensor(test['index'], torch.long, device)
    offsets = tensor(test['indptr'], torch.long, device)

    # Each op is called as op(src, pointer, None): the trailing None is passed
    # explicitly for the ops' third positional argument.
    for op, pointer in ((gather_csr, offsets), (gather_coo, positions)):
        assert gradcheck(op, (source, pointer, None)) is True
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_out(test, dtype, device):
    """Check both gather ops when writing into a caller-supplied tensor.

    The output buffer is passed as the third positional argument. It is
    pre-filled with a sentinel before each call and compared against
    ``test['expected']`` afterwards.
    """
    source = tensor(test['src'], dtype, device)
    positions = tensor(test['index'], torch.long, device)
    offsets = tensor(test['indptr'], torch.long, device)
    want = tensor(test['expected'], dtype, device)

    # -2 does not occur in any `expected` value of the fixtures, so an element
    # the op leaves untouched shows up as a mismatch.
    sentinel = -2

    # The buffer has src's shape, except that the dimension matching the last
    # dimension of `index` takes index's length along it.
    shape = list(source.shape)
    shape[positions.dim() - 1] = positions.shape[-1]
    buffer = source.new_full(shape, sentinel)

    gather_csr(source, offsets, buffer)
    assert (buffer == want).all()

    # Reset so the second op cannot pass on the first op's leftovers.
    buffer.fill_(sentinel)

    gather_coo(source, positions, buffer)
    assert (buffer == want).all()
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_non_contiguous(test, dtype, device):
    """Check both gather ops on inputs whose memory layout was permuted.

    Every input with more than one dimension is transposed on dims 0/1,
    materialised with ``contiguous()``, and transposed back. Values and shape
    are unchanged, but the tensor is now a transposed view of the copy
    (typically no longer contiguous). One-dimensional inputs are used as-is.
    """

    def _relayout(t):
        # transpose -> contiguous -> transpose keeps the logical tensor but
        # backs it with storage laid out in the swapped dimension order.
        if t.dim() > 1:
            return t.transpose(0, 1).contiguous().transpose(0, 1)
        return t

    want = tensor(test['expected'], dtype, device)
    source = _relayout(tensor(test['src'], dtype, device))
    positions = _relayout(tensor(test['index'], torch.long, device))
    offsets = _relayout(tensor(test['indptr'], torch.long, device))

    csr_result = gather_csr(source, offsets)
    assert (csr_result == want).all()

    coo_result = gather_coo(source, positions)
    assert (coo_result == want).all()
|