File: test_rooted_subgraph.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (89 lines) | stat: -rw-r--r-- 2,984 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torch

from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.testing import withPackage
from torch_geometric.transforms import RootedEgoNets, RootedRWSubgraph


def test_rooted_ego_nets():
    """Check ``RootedEgoNets(num_hops=1)`` on the 3-node path graph 0 - 1 - 2.

    The transform is first checked in its compact form (subgraph index
    tensors next to the untouched original attributes). It is then checked
    after ``map_data()``, which materialises one feature row per subgraph
    node/edge.
    """
    x = torch.randn(3, 8)
    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
    edge_attr = torch.randn(4, 8)
    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

    transform = RootedEgoNets(num_hops=1)
    assert str(transform) == 'RootedEgoNets(num_hops=1)'

    # Expected subgraph layout. The pre- and post-``map_data`` assertions
    # share these literals so the two phases cannot silently drift apart.
    n_id = [0, 1, 0, 1, 2, 1, 2]
    n_sub_batch = [0, 0, 1, 1, 1, 2, 2]
    e_id = [0, 1, 0, 1, 2, 3, 2, 3]
    e_sub_batch = [0, 0, 1, 1, 1, 1, 2, 2]
    sub_edge_index = [[0, 1, 2, 3, 3, 4, 5, 6],
                      [1, 0, 3, 2, 4, 3, 6, 5]]

    out = transform(data)
    assert len(out) == 8

    # The original graph attributes are carried through unchanged.
    assert torch.equal(out.x, data.x)
    assert torch.equal(out.edge_index, data.edge_index)
    assert torch.equal(out.edge_attr, data.edge_attr)

    assert out.sub_edge_index.tolist() == sub_edge_index
    assert out.n_id.tolist() == n_id
    assert out.n_sub_batch.tolist() == n_sub_batch
    assert out.e_id.tolist() == e_id
    assert out.e_sub_batch.tolist() == e_sub_batch

    out = out.map_data()
    assert len(out) == 4

    # Mapped features are pure index gathers of the inputs, so they must match
    # exactly; torch.equal (not allclose) catches any unintended numeric change.
    assert torch.equal(out.x, x[n_id])
    assert out.edge_index.tolist() == sub_edge_index
    assert torch.equal(out.edge_attr, edge_attr[e_id])
    assert out.n_sub_batch.tolist() == n_sub_batch


@withPackage('torch_cluster')
def test_rooted_rw_subgraph():
    """Check ``RootedRWSubgraph(walk_length=1)`` on the path graph 0 - 1 - 2.

    The graph has no node features, so only structural outputs are compared:
    the subgraph assignment vector and the number of subgraph edges, both
    before and after ``map_data()``.
    """
    graph = Data(
        edge_index=torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]),
        num_nodes=3,
    )

    rw = RootedRWSubgraph(walk_length=1)
    assert str(rw) == 'RootedRWSubgraph(walk_length=1)'

    # Each of the three rooted subgraphs is expected to contain two nodes.
    expected_assignment = [0, 0, 1, 1, 2, 2]

    rooted = rw(graph)
    assert len(rooted) == 7
    assert rooted.n_sub_batch.tolist() == expected_assignment
    assert rooted.sub_edge_index.size() == (2, 6)

    flat = rooted.map_data()
    assert len(flat) == 3
    assert flat.edge_index.size() == (2, 6)
    assert flat.num_nodes == 6
    assert flat.n_sub_batch.tolist() == expected_assignment


def test_rooted_subgraph_minibatch():
    """Collate two ``RootedEgoNets`` outputs and flatten them via ``map_data``.

    The same 3-node path graph is placed in one mini-batch twice, so every
    expected size below is double the single-graph value.
    """
    node_feat = torch.randn(3, 8)
    edges = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
    edge_feat = torch.randn(4, 8)
    rooted = RootedEgoNets(num_hops=1)(
        Data(x=node_feat, edge_index=edges, edge_attr=edge_feat))

    mapped = next(iter(DataLoader([rooted, rooted], batch_size=2))).map_data()
    assert mapped.num_graphs == len(mapped) == 2

    expected_sizes = (
        ('x', (14, 8)),
        ('edge_index', (2, 16)),
        ('edge_attr', (16, 8)),
        ('n_sub_batch', (14, )),
        ('batch', (14, )),
        ('ptr', (3, )),
    )
    for attr, size in expected_sizes:
        assert getattr(mapped, attr).size() == size, attr

    # Index ranges span both collated graphs: 14 subgraph nodes in total and
    # six subgraph ids (0..5), so the second copy is presumably offset.
    assert mapped.edge_index.min().item() == 0
    assert mapped.edge_index.max().item() == 13
    assert mapped.n_sub_batch.min().item() == 0
    assert mapped.n_sub_batch.max().item() == 5