1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29
|
import torch
from torch_geometric.data import Data
from torch_geometric.transforms import RandomJitter
def _apply_jitter(pos, translate):
    """Wrap a fresh copy of ``pos`` in ``Data`` and apply ``RandomJitter``.

    ``pos`` is cloned so every case starts from the same untouched input:
    if the transform ever modified ``data.pos`` in place, later cases would
    otherwise see perturbed positions and equality checks against ``pos``
    would compare a tensor with itself.

    Args:
        pos: Input node positions.
        translate: Value forwarded to ``RandomJitter`` (scalar or
            per-dimension sequence).

    Returns:
        The transformed ``Data`` object.
    """
    data = RandomJitter(translate)(Data(pos=pos.clone()))
    # Only ``pos`` was set; the transform should not add attributes.
    assert len(data) == 1
    # Jitter perturbs values only; the shape must be preserved.
    assert data.pos.size() == pos.size()
    return data


def test_random_jitter():
    """Check ``RandomJitter``'s repr and that offsets respect ``translate``."""
    assert str(RandomJitter(0.1)) == 'RandomJitter(0.1)'

    pos = torch.tensor([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])

    # A translate of 0 must leave the positions unchanged.
    data = _apply_jitter(pos, 0)
    assert torch.allclose(data.pos, pos)

    # A scalar translate bounds the offset in every dimension.
    data = _apply_jitter(pos, 0.1)
    assert data.pos.min() >= -0.1
    assert data.pos.max() <= 0.1

    # A per-dimension translate bounds each column independently.
    data = _apply_jitter(pos, [0.1, 1])
    assert data.pos[:, 0].min() >= -0.1
    assert data.pos[:, 0].max() <= 0.1
    assert data.pos[:, 1].min() >= -1
    assert data.pos[:, 1].max() <= 1
|