File: test_enzymes.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (71 lines) | stat: -rw-r--r-- 2,451 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import pytest
import torch

from torch_geometric.loader import DataListLoader, DataLoader, DenseDataLoader
from torch_geometric.testing import onlyOnline
from torch_geometric.transforms import ToDense


@onlyOnline
def test_enzymes(get_dataset):
    """End-to-end checks on the ENZYMES graph-classification dataset.

    Covers dataset-level statistics, the supported indexing forms (int,
    slice, float-fraction slice, index tensor, boolean mask), and three
    loader flavors: ``DataLoader`` (collated sparse batch),
    ``DataListLoader`` (list of graphs) and ``DenseDataLoader`` (padded
    dense tensors produced by the ``ToDense`` transform).

    Every loader uses ``batch_size=len(dataset)`` and therefore must yield
    exactly one batch. This is asserted explicitly, so the per-batch checks
    cannot pass vacuously on an empty iterator.
    """
    dataset = get_dataset(name='ENZYMES')
    assert len(dataset) == 600
    assert dataset.num_features == 3
    assert dataset.num_classes == 6
    assert str(dataset) == 'ENZYMES(600)'

    # A single graph exposes three stored attributes
    # (presumably x, edge_index and y -- confirm against the dataset).
    assert len(dataset[0]) == 3
    assert len(dataset.shuffle()) == 600
    # With return_perm=True, shuffle() also returns the permutation it used,
    # giving a pair.
    assert len(dataset.shuffle(return_perm=True)) == 2
    assert len(dataset[:100]) == 100
    # Float slice bounds act as fractions of the dataset: 10%..20% of 600.
    assert len(dataset[0.1:0.2]) == 60
    assert len(dataset[torch.arange(100, dtype=torch.long)]) == 100
    mask = torch.zeros(600, dtype=torch.bool)
    mask[:100] = True
    assert len(dataset[mask]) == 100

    # Sparse mini-batching: the whole dataset is collated into one batch.
    loader = DataLoader(dataset, batch_size=len(dataset))
    assert len(loader) == 1
    batch = next(iter(loader))
    assert batch.num_graphs == len(batch) == 600

    avg_num_nodes = batch.num_nodes / batch.num_graphs
    assert pytest.approx(avg_num_nodes, abs=1e-2) == 32.63

    # The graphs are undirected (asserted below), so each edge is stored in
    # both directions; halve the count to get undirected edges per graph.
    avg_num_edges = batch.num_edges / (2 * batch.num_graphs)
    assert pytest.approx(avg_num_edges, abs=1e-2) == 62.14

    assert list(batch.x.size()) == [batch.num_nodes, 3]
    assert list(batch.y.size()) == [batch.num_graphs]
    assert batch.y.max() + 1 == 6
    assert list(batch.batch.size()) == [batch.num_nodes]
    assert batch.ptr.numel() == batch.num_graphs + 1

    assert batch.has_isolated_nodes()
    assert not batch.has_self_loops()
    assert batch.is_undirected()

    # List loader: a single batch holding all 600 graphs.
    loader = DataListLoader(dataset, batch_size=len(dataset))
    assert len(loader) == 1
    data_list = next(iter(loader))
    assert len(data_list) == 600

    # Dense loader: every graph is padded to 126 nodes by ToDense
    # (presumably the largest graph size in ENZYMES -- confirm).
    dataset.transform = ToDense(num_nodes=126)
    loader = DenseDataLoader(dataset, batch_size=len(dataset))
    assert len(loader) == 1
    data = next(iter(loader))
    assert list(data.x.size()) == [600, 126, 3]
    assert list(data.adj.size()) == [600, 126, 126]
    assert list(data.mask.size()) == [600, 126]
    assert list(data.y.size()) == [600, 1]


@onlyOnline
def test_enzymes_with_node_attr(get_dataset):
    """ENZYMES loaded with ``use_node_attr=True`` reports 21 node features.

    The feature width must be reported identically by ``num_node_features``
    and ``num_features``. The dataset carries no edge features.
    """
    ds = get_dataset(name='ENZYMES', use_node_attr=True)
    node_dim = 21  # wider than the 3 features reported without node attrs
    assert ds.num_node_features == node_dim
    assert ds.num_features == node_dim
    assert ds.num_edge_features == 0


@onlyOnline
def test_cleaned_enzymes(get_dataset):
    """The ``cleaned=True`` variant of ENZYMES holds 595 graphs."""
    cleaned = get_dataset(name='ENZYMES', cleaned=True)
    num_graphs = len(cleaned)
    assert num_graphs == 595