File: test_dataset_summary.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (99 lines) | stat: -rw-r--r-- 3,447 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import dataclasses
import math

import torch
from torch import Tensor

from torch_geometric.data.summary import Stats, Summary
from torch_geometric.datasets import FakeDataset, FakeHeteroDataset
from torch_geometric.testing import withPackage


def check_stats(stats: Stats, expected: Tensor):
    """Assert that ``stats`` holds the summary statistics of ``expected``.

    ``expected`` is cast to ``torch.float``. Each field of ``stats`` is then
    compared for exact equality against the matching reduction of that
    tensor. A field whose expected value is NaN (e.g. ``std`` of a
    single-element tensor) must itself be NaN, because ``nan == nan`` is
    always False and a plain ``==`` would fail spuriously.

    Args:
        stats: The statistics object under test.
        expected: Raw values the statistics should have been computed from.
    """
    expected = expected.to(torch.float)
    reference = (
        ('mean', expected.mean),
        ('std', expected.std),
        ('min', expected.min),
        ('quantile25', lambda: expected.quantile(0.25)),
        ('median', expected.median),
        ('quantile75', lambda: expected.quantile(0.75)),
        ('max', expected.max),
    )
    # Reductions are evaluated lazily, one field at a time, in the same
    # order as the individual asserts this helper used to make.
    for field, reduce in reference:
        actual = getattr(stats, field)
        target = float(reduce())
        if math.isnan(target):
            assert math.isnan(actual), f'{field}: expected NaN, got {actual}'
        else:
            assert actual == target, (
                f'{field}: expected {target}, got {actual}')


def test_dataset_summary():
    """``get_summary`` on a ``FakeDataset`` reports its name, size and stats.

    Node/edge statistics are checked against counts gathered directly from
    the graphs of the dataset.
    """
    dataset = FakeDataset(num_graphs=10)
    # Per-graph counts, keyed by the attribute name shared by the graphs
    # and the summary object.
    expected = {
        attr: torch.tensor([getattr(graph, attr) for graph in dataset])
        for attr in ('num_nodes', 'num_edges')
    }

    summary = dataset.get_summary()

    assert summary.name == 'FakeDataset'
    assert summary.num_graphs == 10

    for attr, counts in expected.items():
        check_stats(getattr(summary, attr), counts)


@withPackage('tabulate')
def test_dataset_summary_representation():
    """Rendering of a ``FakeDataset`` summary ignores ``per_type``.

    Summaries built with ``per_type=False`` and ``per_type=True`` must
    produce identical string output.
    """
    dataset = FakeDataset(num_graphs=10)

    summaries = [
        Summary.from_dataset(dataset, per_type=flag)
        for flag in (False, True)
    ]
    # Exactly one distinct rendering means both strings are equal.
    renderings = {str(summary) for summary in summaries}

    assert len(renderings) == 1


@withPackage('tabulate')
def test_dataset_summary_hetero():
    """A hetero summary without per-type stats matches a homogeneous one.

    Summarizing ``FakeHeteroDataset`` with ``per_type=False`` must equal
    summarizing the list of its graphs after ``to_homogeneous()``. Equality
    is checked both by ``==`` and by rendered text.
    """
    hetero_dataset = FakeHeteroDataset(num_graphs=10)
    flat_summary = Summary.from_dataset(hetero_dataset, per_type=False)

    homogeneous_graphs = [graph.to_homogeneous() for graph in hetero_dataset]
    homo_summary = Summary.from_dataset(homogeneous_graphs)
    # Overwrite the name so the comparison does not hinge on how a plain
    # list gets named by ``from_dataset``.
    homo_summary.name = 'FakeHeteroDataset'

    assert flat_summary == homo_summary
    assert str(flat_summary) == str(homo_summary)


@withPackage('tabulate')
def test_dataset_summary_hetero_representation_length():
    """The rendered hetero summary has the expected number of lines.

    The expected count is three tables, each with one row per ``Stats``
    field plus a fixed number of non-statistic lines.
    """
    dataset = FakeHeteroDataset(num_graphs=10)
    summary = Summary.from_dataset(dataset)
    num_lines = len(str(summary).splitlines())

    # Count real dataclass fields via the public API. Unlike
    # `Stats.__dataclass_fields__`, `dataclasses.fields` skips ClassVar and
    # InitVar pseudo-fields, which would not produce table rows.
    stats_len = len(dataclasses.fields(Stats))
    # Non-statistic lines contributed by each table. This is tied to the
    # current tabulate layout (presumably title, header and border rows).
    len_header_and_border = 5
    num_tables = 3  # general, stats per node type, stats per edge type

    assert num_lines == num_tables * (stats_len + len_header_and_border)


def test_dataset_summary_hetero_per_type_check():
    """Per-type statistics of a ``FakeHeteroDataset`` match manual counts.

    Checks the overall node/edge statistics first. It then checks that the
    summary holds one entry per node type and per edge type, and that each
    entry's statistics match counts gathered from the graphs directly.
    """
    dataset = FakeHeteroDataset(num_graphs=10)
    total_nodes = torch.tensor([graph.num_nodes for graph in dataset])
    total_edges = torch.tensor([graph.num_edges for graph in dataset])

    summary = dataset.get_summary()

    assert summary.name == 'FakeHeteroDataset'
    assert summary.num_graphs == 10

    check_stats(summary.num_nodes, total_nodes)
    check_stats(summary.num_edges, total_edges)

    # Node counts of each node type, one entry per graph.
    nodes_by_type = {
        node_type: torch.tensor(
            [graph[node_type].num_nodes for graph in dataset])
        for node_type in dataset.node_types
    }
    assert len(summary.num_nodes_per_type) == len(dataset.node_types)
    for node_type, node_stats in summary.num_nodes_per_type.items():
        check_stats(node_stats, nodes_by_type[node_type])

    # Edge counts of each edge type, one entry per graph.
    edges_by_type = {
        edge_type: torch.tensor(
            [graph[edge_type].num_edges for graph in dataset])
        for edge_type in dataset.edge_types
    }
    assert len(summary.num_edges_per_type) == len(dataset.edge_types)
    for edge_type, edge_stats in summary.num_edges_per_type.items():
        check_stats(edge_stats, edges_by_type[edge_type])