1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57
|
from math import sqrt
import torch
from torch_geometric.data import Data
from torch_geometric.transforms import NormalizeRotation
def test_normalize_rotation():
assert str(NormalizeRotation()) == 'NormalizeRotation()'
pos = torch.tensor([
[-2.0, -2.0],
[-1.0, -1.0],
[0.0, 0.0],
[1.0, 1.0],
[2.0, 2.0],
])
normal = torch.tensor([
[-1.0, 1.0],
[-1.0, 1.0],
[-1.0, 1.0],
[-1.0, 1.0],
[-1.0, 1.0],
])
data = Data(pos=pos)
data.normal = normal
data = NormalizeRotation()(data)
assert len(data) == 2
expected_pos = torch.tensor([
[-2 * sqrt(2), 0.0],
[-sqrt(2), 0.0],
[0.0, 0.0],
[sqrt(2), 0.0],
[2 * sqrt(2), 0.0],
])
expected_normal = [[0, 1], [0, 1], [0, 1], [0, 1], [0, 1]]
assert torch.allclose(data.pos, expected_pos, atol=1e-04)
assert data.normal.tolist() == expected_normal
data = Data(pos=pos)
data.normal = normal
data = NormalizeRotation(max_points=3)(data)
assert len(data) == 2
assert torch.allclose(data.pos, expected_pos, atol=1e-04)
assert data.normal.tolist() == expected_normal
data = Data(pos=pos)
data.normal = normal
data = NormalizeRotation(sort=True)(data)
assert len(data) == 2
assert torch.allclose(data.pos, expected_pos, atol=1e-04)
assert data.normal.tolist() == expected_normal
|