File: test_tree_graph.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (25 lines) | stat: -rw-r--r-- 825 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import pytest

from torch_geometric.datasets.graph_generator import TreeGraph


@pytest.mark.parametrize('undirected', [False, True])
def test_tree_graph(undirected):
    """Check ``TreeGraph`` output for a tree of depth 2 and branching 2.

    Such a tree has ``1 + 2 + 4 = 7`` nodes. In the directed case every
    parent points to each of its children. In the undirected case each of
    those edges additionally appears reversed, with the columns ordered by
    source node.
    """
    gen = TreeGraph(depth=2, branch=2, undirected=undirected)
    expected_repr = f'TreeGraph(depth=2, branch=2, undirected={undirected})'
    assert str(gen) == expected_repr

    out = gen()
    # The generated object is expected to carry exactly three attributes
    # (presumably edge_index, depth and the node count -- confirm in
    # TreeGraph.__call__).
    assert len(out) == 3
    assert out.num_nodes == 7
    # Root at depth 0, two children at depth 1, four leaves at depth 2.
    assert out.depth.tolist() == [0, 1, 1, 2, 2, 2, 2]

    if undirected:
        want = [
            [0, 0, 1, 1, 1, 2, 2, 2, 3, 4, 5, 6],
            [1, 2, 0, 3, 4, 0, 5, 6, 1, 1, 2, 2],
        ]
    else:
        want = [
            [0, 0, 1, 1, 2, 2],
            [1, 2, 3, 4, 5, 6],
        ]
    assert out.edge_index.tolist() == want