1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29
|
import torch
from torch_geometric.data import Data
from torch_geometric.transforms import GenerateMeshNormals
def test_generate_mesh_normals():
    """Check the vertex normals produced for a flat triangle fan.

    Six vertices are used: an apex at the origin plus five points on the
    line y = 1, all with z = 0. Four triangles share the apex (vertex 0).
    Since the whole mesh lies in the z = 0 plane, every vertex is expected
    to receive the normal (0, 0, -1). ``pos`` and ``face`` must come back
    unchanged, and ``norm`` must be the only attribute added.
    """
    transform = GenerateMeshNormals()
    assert str(transform) == 'GenerateMeshNormals()'

    # Apex first, then a row of five points along y = 1 (x from -2 to 2).
    pos = torch.tensor([
        [0.0, 0.0, 0.0],
        [-2.0, 1.0, 0.0],
        [-1.0, 1.0, 0.0],
        [0.0, 1.0, 0.0],
        [1.0, 1.0, 0.0],
        [2.0, 1.0, 0.0],
    ])

    # Triangles fanning out from the apex: (0,1,2), (0,2,3), (0,3,4), (0,4,5).
    # They are stored one triangle per column, giving a [3, num_faces] tensor;
    # .contiguous() keeps the same memory layout as a row-wise literal.
    triangles = [[0, k, k + 1] for k in range(1, 5)]
    face = torch.tensor(triangles).t().contiguous()

    out = transform(Data(pos=pos, face=face))

    # Only `norm` should be added next to the untouched `pos` and `face`.
    assert len(out) == 3
    assert out.pos.tolist() == pos.tolist()
    assert out.face.tolist() == face.tolist()

    # A planar mesh: every vertex shares the same unit normal.
    expected_norm = [[0.0, 0.0, -1.0] for _ in range(pos.size(0))]
    assert out.norm.tolist() == expected_norm
|