File: test_zero_tensors.py

package info (click to toggle)
pytorch-scatter 2.1.2-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,128 kB
  • sloc: python: 1,574; cpp: 1,379; sh: 58; makefile: 13
file content (36 lines) | stat: -rw-r--r-- 1,231 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from itertools import product

import pytest
import torch
from torch_scatter import (gather_coo, gather_csr, scatter, segment_coo,
                           segment_csr)
from torch_scatter.testing import devices, grad_dtypes, reductions, tensor


@pytest.mark.parametrize('reduce,dtype,device',
                         product(reductions, grad_dtypes, devices))
def test_zero_elements(reduce, dtype, device):
    x = torch.randn(0, 0, 0, 16, dtype=dtype, device=device,
                    requires_grad=True)
    index = tensor([], torch.long, device)
    indptr = tensor([], torch.long, device)

    out = scatter(x, index, dim=0, dim_size=0, reduce=reduce)
    out.backward(torch.randn_like(out))
    assert out.size() == (0, 0, 0, 16)

    out = segment_coo(x, index, dim_size=0, reduce=reduce)
    out.backward(torch.randn_like(out))
    assert out.size() == (0, 0, 0, 16)

    out = gather_coo(x, index)
    out.backward(torch.randn_like(out))
    assert out.size() == (0, 0, 0, 16)

    out = segment_csr(x, indptr, reduce=reduce)
    out.backward(torch.randn_like(out))
    assert out.size() == (0, 0, 0, 16)

    out = gather_csr(x, indptr)
    out.backward(torch.randn_like(out))
    assert out.size() == (0, 0, 0, 16)