File: test_to_dense_adj.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (108 lines) | stat: -rw-r--r-- 3,740 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import torch

from torch_geometric.testing import is_full_test
from torch_geometric.utils import to_dense_adj


def test_to_dense_adj():
    edge_index = torch.tensor([
        [0, 0, 1, 2, 3, 4],
        [0, 1, 0, 3, 4, 2],
    ])
    batch = torch.tensor([0, 0, 1, 1, 1])

    adj = to_dense_adj(edge_index, batch)
    assert adj.size() == (2, 3, 3)
    assert adj[0].tolist() == [[1, 1, 0], [1, 0, 0], [0, 0, 0]]
    assert adj[1].tolist() == [[0, 1, 0], [0, 0, 1], [1, 0, 0]]

    if is_full_test():
        jit = torch.jit.script(to_dense_adj)
        adj = jit(edge_index, batch)
        assert adj.size() == (2, 3, 3)
        assert adj[0].tolist() == [[1, 1, 0], [1, 0, 0], [0, 0, 0]]
        assert adj[1].tolist() == [[0, 1, 0], [0, 0, 1], [1, 0, 0]]

    adj = to_dense_adj(edge_index, batch, max_num_nodes=2)
    assert adj.size() == (2, 2, 2)
    assert adj[0].tolist() == [[1, 1], [1, 0]]
    assert adj[1].tolist() == [[0, 1], [0, 0]]

    adj = to_dense_adj(edge_index, batch, max_num_nodes=5)
    assert adj.size() == (2, 5, 5)
    assert adj[0][:3, :3].tolist() == [[1, 1, 0], [1, 0, 0], [0, 0, 0]]
    assert adj[1][:3, :3].tolist() == [[0, 1, 0], [0, 0, 1], [1, 0, 0]]

    edge_attr = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    adj = to_dense_adj(edge_index, batch, edge_attr)
    assert adj.size() == (2, 3, 3)
    assert adj[0].tolist() == [[1, 2, 0], [3, 0, 0], [0, 0, 0]]
    assert adj[1].tolist() == [[0, 4, 0], [0, 0, 5], [6, 0, 0]]

    adj = to_dense_adj(edge_index, batch, edge_attr, max_num_nodes=5)
    assert adj.size() == (2, 5, 5)
    assert adj[0][:3, :3].tolist() == [[1, 2, 0], [3, 0, 0], [0, 0, 0]]
    assert adj[1][:3, :3].tolist() == [[0, 4, 0], [0, 0, 5], [6, 0, 0]]

    edge_attr = edge_attr.view(-1, 1)
    adj = to_dense_adj(edge_index, batch, edge_attr)
    assert adj.size() == (2, 3, 3, 1)

    edge_attr = edge_attr.view(-1, 1)
    adj = to_dense_adj(edge_index, batch, edge_attr, max_num_nodes=5)
    assert adj.size() == (2, 5, 5, 1)

    adj = to_dense_adj(edge_index)
    assert adj.size() == (1, 5, 5)
    assert adj[0].nonzero(as_tuple=False).t().tolist() == edge_index.tolist()

    adj = to_dense_adj(edge_index, max_num_nodes=10)
    assert adj.size() == (1, 10, 10)
    assert adj[0].nonzero(as_tuple=False).t().tolist() == edge_index.tolist()

    adj = to_dense_adj(edge_index, batch, batch_size=4)
    assert adj.size() == (4, 3, 3)


def test_to_dense_adj_with_empty_edge_index():
    edge_index = torch.tensor([[], []], dtype=torch.long)
    batch = torch.tensor([0, 0, 1, 1, 1])

    adj = to_dense_adj(edge_index)
    assert adj.size() == (1, 0, 0)

    adj = to_dense_adj(edge_index, max_num_nodes=10)
    assert adj.size() == (1, 10, 10) and adj.sum() == 0

    adj = to_dense_adj(edge_index, batch)
    assert adj.size() == (2, 3, 3) and adj.sum() == 0

    adj = to_dense_adj(edge_index, batch, max_num_nodes=10)
    assert adj.size() == (2, 10, 10) and adj.sum() == 0


def test_to_dense_adj_with_duplicate_entries():
    edge_index = torch.tensor([
        [0, 0, 0, 1, 2, 3, 3, 4],
        [0, 0, 1, 0, 3, 4, 4, 2],
    ])
    batch = torch.tensor([0, 0, 1, 1, 1])

    adj = to_dense_adj(edge_index, batch)
    assert adj.size() == (2, 3, 3)
    assert adj[0].tolist() == [[2, 1, 0], [1, 0, 0], [0, 0, 0]]
    assert adj[1].tolist() == [[0, 1, 0], [0, 0, 2], [1, 0, 0]]

    edge_attr = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
    adj = to_dense_adj(edge_index, batch, edge_attr)
    assert adj.size() == (2, 3, 3)
    assert adj[0].tolist() == [
        [3.0, 3.0, 0.0],
        [4.0, 0.0, 0.0],
        [0.0, 0.0, 0.0],
    ]
    assert adj[1].tolist() == [
        [0.0, 5.0, 0.0],
        [0.0, 0.0, 13.0],
        [8.0, 0.0, 0.0],
    ]