File: test_logsumexp.py

package info (click to toggle)
pytorch-scatter 2.1.2-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,128 kB
  • sloc: python: 1,574; cpp: 1,379; sh: 58; makefile: 13
file content (31 lines) | stat: -rw-r--r-- 753 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
from torch_scatter import scatter_logsumexp


def test_logsumexp():
    """Check ``scatter_logsumexp`` against a per-group ``torch.logsumexp``.

    Three things are compared against a plain-PyTorch reference built from
    contiguous slices of ``inputs``:

    - the forward values;
    - the gradients w.r.t. ``inputs``;
    - the output of the TorchScript-compiled function.

    Group 3 receives no elements, and its output is expected to be ``0.0``.
    """
    inputs = torch.tensor([
        0.5,
        0.5,
        0.0,
        -2.1,
        3.2,
        7.0,
        -1.0,
        -100.0,
    ])
    inputs.requires_grad_()
    # ``index`` is sorted, so every group is a contiguous run of ``inputs``
    # and its slice can be taken with ``split``. Deriving the sizes from
    # ``index`` keeps the two in sync.
    index = torch.tensor([0, 0, 1, 1, 1, 2, 4, 4])
    splits = torch.bincount(index).tolist()  # [2, 3, 1, 0, 2]

    outputs = scatter_logsumexp(inputs, index)

    # Build the reference on an independent leaf tensor. The two backward
    # passes below then accumulate into separate ``.grad`` buffers.
    ref_inputs = inputs.detach().clone().requires_grad_()
    expected = torch.stack([
        torch.logsumexp(src, dim=0) if src.numel() > 0 else src.new_zeros(())
        for src in ref_inputs.split(splits)
    ])

    # Compare with a tolerance rather than exact equality: both sides are
    # float32 reductions that are not guaranteed to round identically.
    torch.testing.assert_close(outputs.detach(), expected.detach())

    # Use a fixed-seed local generator for the upstream gradient. This keeps
    # the test reproducible without touching the global RNG state.
    generator = torch.Generator().manual_seed(12345)
    grad_out = torch.randn(outputs.shape, generator=generator)
    outputs.backward(grad_out)
    expected.backward(grad_out)
    torch.testing.assert_close(inputs.grad, ref_inputs.grad)

    # The scripted function runs the same code path as eager mode, so its
    # result is required to match exactly.
    jit = torch.jit.script(scatter_logsumexp)
    assert jit(inputs, index).tolist() == outputs.tolist()