1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76
|
from itertools import product
import pytest
import torch
from torch import Tensor
from torch_cluster import fps
from torch_cluster.testing import devices, grad_dtypes, tensor
@torch.jit.script
def fps2(x: Tensor, ratio: Tensor) -> Tensor:
    """TorchScript-compiled wrapper around ``fps`` with a tensor ratio.

    Forwards ``x`` and ``ratio`` positionally, passing ``None`` as the
    second argument and ``False`` as the fourth. The tests in this file
    pass a batch vector in the second position and compare this wrapper
    against ``random_start=False`` calls, so those are presumably the
    ``batch`` and ``random_start`` parameters.

    Args:
        x: Tensor handed to ``fps`` as its first argument.
        ratio: Sampling ratio, supplied as a tensor.

    Returns:
        The tensor returned by ``fps``.
    """
    # Arguments stay positional so the scripted call matches the original.
    sampled = fps(x, None, ratio, False)
    return sampled
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_fps(dtype, device):
    """Check ``fps`` with ``random_start=False`` on eight fixed 2-D points.

    The points are split by ``batch`` / ``ptr`` into two groups of four
    (rows 0-3 and rows 4-7). Every supported way of passing the grouping
    and the ratio (omitted, Python float, 0-dim tensor, 1-dim tensor) must
    select the same indices. The ungrouped calls and the scripted ``fps2``
    wrapper are compared after sorting.
    """
    x = tensor([
        [-1, -1],
        [-1, +1],
        [+1, +1],
        [+1, -1],
        [-2, -2],
        [-2, +2],
        [+2, +2],
        [+2, -2],
    ], dtype, device)
    batch = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)
    ptr_list = [0, 4, 8]
    ptr = torch.tensor(ptr_list, device=device)

    # Grouped sampling: each entry is (extra positional args, keyword args).
    grouped_expected = [0, 2, 4, 6]
    grouped_cases = [
        ((batch, ), {}),
        ((batch, ), {'ratio': 0.5}),
        ((batch, ), {'ratio': torch.tensor(0.5, device=device)}),
        ((), {'ptr': ptr_list, 'ratio': 0.5}),
        ((), {'ptr': ptr, 'ratio': 0.5}),
        ((batch, ), {'ratio': torch.tensor([0.5, 0.5], device=device)}),
    ]
    for extra_args, kwargs in grouped_cases:
        out = fps(x, *extra_args, random_start=False, **kwargs)
        assert out.tolist() == grouped_expected

    # Ungrouped sampling: all eight points form a single set. Only the
    # selected set is checked, so the indices are sorted before comparing.
    ungrouped_expected = [0, 5, 6, 7]
    ungrouped_cases = [
        {},
        {'ratio': 0.5},
        {'ratio': torch.tensor(0.5, device=device)},
        {'ratio': torch.tensor([0.5], device=device)},
    ]
    for kwargs in ungrouped_cases:
        out = fps(x, random_start=False, **kwargs)
        assert sorted(out.tolist()) == ungrouped_expected

    # The TorchScript wrapper must agree with the eager calls above.
    out = fps2(x, torch.tensor([0.5], device=device))
    assert sorted(out.tolist()) == ungrouped_expected
@pytest.mark.parametrize('device', devices)
def test_random_fps(device):
    """Smoke-test ``fps`` on random 3-D points split into two equal groups.

    Runs five rounds. Each round draws ``2 * N`` random points, assigns
    the first ``N`` to group 0 and the rest to group 1, and calls ``fps``
    with ``ratio=0.5`` and ``random_start`` left at its default. Only the
    range of the returned indices is checked.
    """
    N = 1024
    total = 2 * N
    for _ in range(5):
        pos = torch.randn((total, 3), device=device)

        # First N entries stay 0, the remaining N become 1.
        batch = torch.zeros(total, dtype=torch.long, device=device)
        batch[N:] = 1

        idx = fps(pos, batch, ratio=0.5)

        # Every sampled index must refer to an existing row of ``pos``.
        assert idx.min() >= 0
        assert idx.max() < total
|