# Shape tests for torch_scatter.scatter when `index` is broadcast against `src`.
from itertools import product
import pytest
import torch
from torch_scatter import scatter
from torch_scatter.testing import devices, reductions
@pytest.mark.parametrize('reduce,device', product(reductions, devices))
def test_broadcasting(reduce, device):
    """Check that ``scatter`` along dim 2 returns a ``src``-shaped output.

    For each ``(reduce, device)`` pair, a random ``(B, C, H, W)`` source is
    scattered with ``dim_size=H`` using index tensors of three shapes, in
    order: ``(H,)``, ``(B, 1, H, W)`` and ``(H,)`` again. Only the output
    size is asserted; the values are not inspected.
    """
    B, C, H, W = 4, 3, 8, 8
    expected_shape = (B, C, H, W)

    # The 1-D shape appears twice, matching the original sequence of three
    # back-to-back scenarios.
    index_shapes = [(H, ), (B, 1, H, W), (H, )]

    for index_shape in index_shapes:
        # Fresh random inputs per scenario; index values lie in [0, H).
        src = torch.randn(expected_shape, device=device)
        index = torch.randint(0, H, index_shape).to(device, torch.long)
        out = scatter(src, index, dim=2, dim_size=H, reduce=reduce)
        assert out.size() == expected_shape
|