File: test_std.py

package info (click to toggle)
pytorch-scatter 2.1.2-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,128 kB
  • sloc: python: 1,574; cpp: 1,379; sh: 58; makefile: 13
file content (18 lines) | stat: -rw-r--r-- 601 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch
from torch_scatter import scatter_std


def test_std():
    """Check ``scatter_std`` against ``Tensor.std`` in eager and scripted form.

    Row 0 of the input is scattered entirely into slot 0 and row 1 entirely
    into slot 1, so each output row is expected to hold the row's unbiased
    standard deviation in its own slot and zero in the other one.
    """
    values = torch.tensor(
        [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]],
        dtype=torch.float,
        requires_grad=True,
    )
    slots = torch.tensor(
        [[0, 0, 0, 0, 0], [1, 1, 1, 1, 1]],
        dtype=torch.long,
    )

    result = scatter_std(values, slots, dim=-1, unbiased=True)

    # Both rows are permutations of 0..4, so they share one standard
    # deviation; the first row's value serves as the reference for both.
    row_std = values.std(dim=-1, unbiased=True)[0]
    reference = torch.tensor([[row_std, 0], [0, row_std]])
    assert torch.allclose(result, reference)

    # Only verifies that a backward pass runs; gradients are not inspected.
    upstream_grad = torch.randn_like(result)
    result.backward(upstream_grad)

    # The TorchScript-compiled function must reproduce the eager output
    # exactly.
    scripted = torch.jit.script(scatter_std)
    scripted_result = scripted(values, slots, dim=-1, unbiased=True)
    assert scripted_result.tolist() == result.tolist()