File: test_grid.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (26 lines) | stat: -rw-r--r-- 850 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch

from torch_geometric.testing import is_full_test
from torch_geometric.utils import grid


def test_grid():
    """Check ``grid(height=3, width=2)`` eagerly and, in full tests, scripted.

    The expected edges connect every node to itself and to each node in its
    8-neighbourhood, diagonals included. Edges are listed grouped by source
    node, in ascending order.
    """
    # Neighbours of each source node, in the exact order the edges should
    # appear in the returned ``(row, col)`` pair.
    neighbors = {
        0: [0, 1, 2, 3],
        1: [0, 1, 2, 3],
        2: [0, 1, 2, 3, 4, 5],
        3: [0, 1, 2, 3, 4, 5],
        4: [2, 3, 4, 5],
        5: [2, 3, 4, 5],
    }
    expected_row = [src for src, dsts in neighbors.items() for _ in dsts]
    expected_col = [dst for dsts in neighbors.values() for dst in dsts]

    # Coordinates of each node: x alternates 0/1 across the two columns, and
    # y steps down from 2 to 0 as the node index grows.
    expected_pos = [[0, 2], [1, 2], [0, 1], [1, 1], [0, 0], [1, 0]]

    def check(fn):
        # Shared assertions for both the eager and the TorchScript variant.
        (row, col), pos = fn(height=3, width=2)
        assert row.tolist() == expected_row
        assert col.tolist() == expected_col
        assert pos.tolist() == expected_pos

    check(grid)

    if is_full_test():
        check(torch.jit.script(grid))