File: test_subgraph.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (124 lines) | stat: -rw-r--r-- 4,302 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import torch

from torch_geometric.nn import GCNConv, Linear
from torch_geometric.testing import withDevice, withPackage
from torch_geometric.utils import (
    bipartite_subgraph,
    get_num_hops,
    index_to_mask,
    k_hop_subgraph,
    subgraph,
)


def test_get_num_hops():
    """``get_num_hops`` should report 2 for a model with two ``GCNConv`` layers.

    The model stacks two ``GCNConv`` layers followed by a ``Linear`` head. The
    test expects one hop per convolution, so the result is 2. The model is also
    run once, so its ``forward`` is exercised and cannot silently rot.
    """
    class GNN(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = GCNConv(3, 16, normalize=False)
            self.conv2 = GCNConv(16, 16, normalize=False)
            self.lin = Linear(16, 2)

        def forward(self, x, edge_index):
            # `torch` has no `F` attribute; the functional API lives in
            # `torch.nn.functional`. The former `torch.F.relu` spelling
            # would raise AttributeError as soon as `forward` ran.
            x = torch.nn.functional.relu(self.conv1(x, edge_index))
            x = self.conv2(x, edge_index)
            return self.lin(x)

    model = GNN()
    assert get_num_hops(model) == 2

    # Run the model on a tiny 4-node chain to make sure `forward` is valid.
    x = torch.randn(4, 3)
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
    assert model(x, edge_index).size() == (4, 2)


def test_subgraph():
    """Check ``subgraph`` on a 7-node path graph for every subset encoding.

    The node subset {3, 4, 5} is given in three forms: an index tensor, a
    boolean mask, and a plain Python list. All three must yield the same
    induced edges, edge attributes and edge mask, both with and without node
    relabeling.
    """
    edge_index = torch.tensor([
        [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6],
        [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5],
    ])
    edge_attr = torch.tensor(
        [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0])

    idx = torch.tensor([3, 4, 5])
    mask = index_to_mask(idx, 7)
    indices = idx.tolist()

    # Edges at positions 6..9 are the only ones with both endpoints in
    # {3, 4, 5}. Their attributes are used for both branches below, so the
    # two comparisons cannot drift apart (previously one used int literals).
    expected_attr = [7.0, 8.0, 9.0, 10.0]

    for subset in [idx, mask, indices]:
        out = subgraph(subset, edge_index, edge_attr, return_edge_mask=True)
        assert out[0].tolist() == [[3, 4, 4, 5], [4, 3, 5, 4]]
        assert out[1].tolist() == expected_attr
        assert out[2].tolist() == [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0]

        # With relabeling, nodes 3, 4, 5 are expected to map to 0, 1, 2.
        out = subgraph(subset, edge_index, edge_attr, relabel_nodes=True)
        assert out[0].tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]
        assert out[1].tolist() == expected_attr


@withDevice
@withPackage('pandas')
def test_subgraph_large_index(device):
    """Relabeling in ``subgraph`` must cope with a very large node id.

    A single self-loop on node 50,000,000 is kept and relabeled to ``(0, 0)``.
    """
    node = 50_000_000
    subset = torch.tensor([node], device=device)
    self_loop = torch.tensor([[node], [node]], device=device)

    relabeled, _ = subgraph(subset, self_loop, relabel_nodes=True)
    assert relabeled.tolist() == [[0], [0]]


def test_bipartite_subgraph():
    """Check ``bipartite_subgraph`` across all supported subset encodings.

    Source nodes {2, 3, 5} (of 7) and destination nodes {2, 3} (of 4) are
    given in four forms: index tensors, boolean masks, Python lists, and a
    mixed mask/index pair. Each form must keep exactly the same edges.
    """
    edge_index = torch.tensor([
        [0, 5, 2, 3, 3, 4, 4, 3, 5, 5, 6],
        [0, 0, 3, 2, 0, 0, 2, 1, 2, 3, 1],
    ])
    edge_attr = torch.tensor(
        [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0])

    src_idx, dst_idx = torch.tensor([2, 3, 5]), torch.tensor([2, 3])
    src_mask, dst_mask = index_to_mask(src_idx, 7), index_to_mask(dst_idx, 4)

    subsets = [
        (src_idx, dst_idx),                    # index tensors
        (src_mask, dst_mask),                  # boolean masks
        (src_idx.tolist(), dst_idx.tolist()),  # plain Python lists
        (src_mask, dst_idx),                   # mask on one side, index on other
    ]
    expected_attr = [3.0, 4.0, 9.0, 10.0]

    for subset in subsets:
        out = bipartite_subgraph(subset, edge_index, edge_attr,
                                 return_edge_mask=True)
        assert out[0].tolist() == [[2, 3, 5, 5], [3, 2, 2, 3]]
        assert out[1].tolist() == expected_attr
        assert out[2].tolist() == [0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0]

        out = bipartite_subgraph(subset, edge_index, edge_attr,
                                 relabel_nodes=True)
        assert out[0].tolist() == [[0, 1, 2, 2], [1, 0, 0, 1]]
        assert out[1].tolist() == expected_attr


@withDevice
@withPackage('pandas')
def test_bipartite_subgraph_large_index(device):
    """Bipartite relabeling must cope with a very large node id.

    Node 50,000,000 is used on both sides. Its single edge is kept and mapped
    to ``(0, 0)``.
    """
    node = 50_000_000
    side = torch.tensor([node], device=device)
    edges = torch.tensor([[node], [node]], device=device)

    relabeled, _ = bipartite_subgraph((side, side), edges,
                                      relabel_nodes=True)
    assert relabeled.tolist() == [[0], [0]]


def test_k_hop_subgraph():
    """Check 2-hop ``k_hop_subgraph`` extraction for one and for two seeds."""
    # Case 1: a single seed (node 6).
    tree = torch.tensor([
        [0, 1, 2, 3, 4, 5],
        [2, 2, 4, 4, 6, 6],
    ])
    subset, sub_edges, mapping, edge_mask = k_hop_subgraph(
        6, 2, tree, relabel_nodes=True)
    assert subset.tolist() == [2, 3, 4, 5, 6]
    assert sub_edges.tolist() == [[0, 1, 2, 3], [2, 2, 4, 4]]
    assert mapping.tolist() == [4]
    assert edge_mask.tolist() == [False, False, True, True, True, True]

    # Case 2: two seeds (nodes 0 and 6) on two disjoint chains.
    chains = torch.tensor([
        [1, 2, 4, 5],
        [0, 1, 5, 6],
    ])
    subset, sub_edges, mapping, edge_mask = k_hop_subgraph(
        [0, 6], 2, chains, relabel_nodes=True)
    assert subset.tolist() == [0, 1, 2, 4, 5, 6]
    assert sub_edges.tolist() == [[1, 2, 3, 4], [0, 1, 4, 5]]
    assert mapping.tolist() == [0, 5]
    assert edge_mask.tolist() == [True, True, True, True]