File: test_broadcasting.py

package info (click to toggle)
pytorch-scatter 2.1.2-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 1,128 kB
  • sloc: python: 1,574; cpp: 1,379; sh: 58; makefile: 13
file content (26 lines) | stat: -rw-r--r-- 944 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from itertools import product

import pytest
import torch
from torch_scatter import scatter
from torch_scatter.testing import devices, reductions


def _check_broadcast(src, index, full_index, reduce):
    """Scatter ``src`` along dim 2 with a broadcastable ``index`` and verify it.

    Asserts that the output keeps ``src``'s shape (``dim_size`` is set to
    ``src.size(2)``). Also asserts that the output matches scattering with
    ``full_index``, which is ``index`` explicitly expanded to ``src``'s shape.
    That second check catches a wrong broadcast, not just a wrong output size.
    """
    dim_size = src.size(2)
    out = scatter(src, index, dim=2, dim_size=dim_size, reduce=reduce)
    assert out.size() == src.size()
    expected = scatter(src, full_index, dim=2, dim_size=dim_size,
                       reduce=reduce)
    # Small atol: allow for a possible summation-order difference between
    # the two calls on some devices (assumption, not verified here).
    assert torch.allclose(out, expected, atol=1e-6)


@pytest.mark.parametrize('reduce,device', product(reductions, devices))
def test_broadcasting(reduce, device):
    """Check that ``scatter`` broadcasts lower-rank or singleton indices.

    A ``(B, C, H, W)`` tensor is scattered along dim 2 with ``dim_size=H``
    using three different index shapes. Each result must keep the input
    shape and equal the result obtained with the explicitly expanded index.
    """
    B, C, H, W = (4, 3, 8, 8)

    # 1-D index along dim 2: broadcast over leading B, C and trailing W.
    src = torch.randn((B, C, H, W), device=device)
    index = torch.randint(0, H, (H, )).to(device, torch.long)
    full_index = index.view(1, 1, H, 1).expand_as(src)
    _check_broadcast(src, index, full_index, reduce)

    # Full-rank index with a singleton C dimension: broadcast over C.
    src = torch.randn((B, C, H, W), device=device)
    index = torch.randint(0, H, (B, 1, H, W)).to(device, torch.long)
    full_index = index.expand_as(src)
    _check_broadcast(src, index, full_index, reduce)

    # Index covering dims 0..2 only: broadcast over the trailing W.
    src = torch.randn((B, C, H, W), device=device)
    index = torch.randint(0, H, (B, C, H)).to(device, torch.long)
    full_index = index.unsqueeze(-1).expand_as(src)
    _check_broadcast(src, index, full_index, reduce)