File: test_trim_to_layer.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (199 lines) | stat: -rw-r--r-- 6,693 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
from typing import List, Optional

import torch
from torch import Tensor

import torch_geometric.typing
from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import GraphConv
from torch_geometric.testing import withPackage
from torch_geometric.typing import SparseTensor
from torch_geometric.utils import trim_to_layer
from torch_geometric.utils._trim_to_layer import trim_sparse_tensor


@withPackage('torch_sparse')
def test_trim_sparse_tensor():
    """Check ``trim_sparse_tensor`` on a small 5x5 adjacency.

    The adjacency is built from the edge index ``[[0, 0, 1, 2], [1, 2, 3,
    4]]``. After trimming to ``size=(3, 3)`` with ``num_seed_nodes=1``, the
    test expects exactly the two entries ``(0, 1)`` and ``(0, 2)`` to remain.
    """
    full_adj = SparseTensor.from_edge_index(
        torch.tensor([[0, 0, 1, 2], [1, 2, 3, 4]]),
        sparse_sizes=[5, 5],
    )

    trimmed = trim_sparse_tensor(full_adj, size=(3, 3), num_seed_nodes=1)

    # Only the coordinates are checked; the value slot is ignored.
    rows, cols, _ = trimmed.coo()
    assert rows.tolist() == [0, 0]
    assert cols.tolist() == [1, 2]


def test_trim_to_layer_basic():
    """Trim a small 4-node graph layer by layer and check every result.

    The input uses the edge index ``[[1, 2, 3], [0, 1, 2]]`` with one node
    and one edge recorded per hop. The assertions expect layer 0 to return
    the inputs unchanged, and each following layer to keep one node and one
    edge fewer. When ``torch_sparse`` is available, the same expectations are
    checked with a transposed :class:`SparseTensor` as ``edge_index``.
    """
    # Per-hop sample counts shared by every ``trim_to_layer`` call below.
    hop_counts = dict(
        num_sampled_nodes_per_hop=[1, 1, 1],
        num_sampled_edges_per_hop=[1, 1, 1],
    )
    sparse_enabled = torch_geometric.typing.WITH_TORCH_SPARSE

    def check_sparse(adj_t, num_nodes, rows, cols):
        # The adjacency was passed in transposed, so transpose it back
        # before comparing against the dense expectations.
        adj = adj_t.t()
        assert adj.sizes() == [num_nodes, num_nodes]
        row, col, value = adj.coo()
        assert row.tolist() == rows
        assert col.tolist() == cols
        assert torch.equal(value, torch.arange(len(rows)))

    feat0 = torch.arange(4)
    index0 = torch.tensor([[1, 2, 3], [0, 1, 2]])
    weight0 = torch.arange(3)

    # Layer 0: the assertions expect nothing to be removed.
    feat1, index1, weight1 = trim_to_layer(
        layer=0, x=feat0, edge_index=index0, edge_attr=weight0, **hop_counts)
    assert torch.equal(feat1, torch.arange(4))
    assert index1.tolist() == [[1, 2, 3], [0, 1, 2]]
    assert torch.equal(weight1, torch.arange(3))

    if sparse_enabled:
        adj_in = SparseTensor.from_edge_index(index0, weight0, (4, 4))
        feat1, adj_t_out, _ = trim_to_layer(
            layer=0, x=feat0, edge_index=adj_in.t(), edge_attr=weight0,
            **hop_counts)
        check_sparse(adj_t_out, 4, rows=[1, 2, 3], cols=[0, 1, 2])
        assert torch.equal(feat1, torch.arange(4))

    # Layer 1: one node and one edge fewer are expected.
    feat2, index2, weight2 = trim_to_layer(
        layer=1, x=feat1, edge_index=index1, edge_attr=weight1, **hop_counts)
    assert torch.equal(feat2, torch.arange(3))
    assert index2.tolist() == [[1, 2], [0, 1]]
    assert torch.equal(weight2, torch.arange(2))

    if sparse_enabled:
        adj_in = SparseTensor.from_edge_index(index1, weight1, (4, 4))
        feat2, adj_t_out, _ = trim_to_layer(
            layer=1, x=feat1, edge_index=adj_in.t(), **hop_counts)
        check_sparse(adj_t_out, 3, rows=[1, 2], cols=[0, 1])
        assert torch.equal(feat2, torch.arange(3))

    # Layer 2: again one node and one edge fewer are expected.
    feat3, index3, weight3 = trim_to_layer(
        layer=2, x=feat2, edge_index=index2, edge_attr=weight2, **hop_counts)
    assert torch.equal(feat3, torch.arange(2))
    assert index3.tolist() == [[1], [0]]
    assert torch.equal(weight3, torch.arange(1))

    if sparse_enabled:
        adj_in = SparseTensor.from_edge_index(index2, weight2, (3, 3))
        feat3, adj_t_out, _ = trim_to_layer(
            layer=2, x=feat2, edge_index=adj_in.t(), **hop_counts)
        check_sparse(adj_t_out, 2, rows=[1], cols=[0])
        assert torch.equal(feat3, torch.arange(2))


def test_trim_to_layer_hetero():
    """Check ``trim_to_layer`` when every input is a dictionary.

    A single node type ``'v'`` and a single edge type ``('v', 'to', 'v')``
    are used. After trimming at ``layer=1``, the test expects three node
    features and the two edges ``[[1, 2], [0, 1]]`` to remain.
    """
    rel = ('v', 'to', 'v')

    feats = {'v': torch.arange(4)}
    edges = {rel: torch.tensor([[1, 2, 3], [0, 1, 2]])}
    weights = {rel: torch.arange(3)}

    feats, edges, weights = trim_to_layer(
        layer=1,
        num_sampled_nodes_per_hop={'v': [1, 1, 1, 1]},
        num_sampled_edges_per_hop={rel: [1, 1, 1]},
        x=feats,
        edge_index=edges,
        edge_attr=weights,
    )

    assert torch.equal(feats['v'], torch.arange(3))
    assert edges[rel].tolist() == [[1, 2], [0, 1]]
    assert torch.equal(weights[rel], torch.arange(2))


class GNN(torch.nn.Module):
    """Stack of ``GraphConv(16, 16)`` layers with optional per-layer trimming.

    Args:
        num_layers: Number of ``GraphConv`` layers to stack.
    """
    def __init__(self, num_layers: int):
        super().__init__()

        layers = [GraphConv(16, 16) for _ in range(num_layers)]
        self.convs = torch.nn.ModuleList(layers)

    def forward(
        self,
        x: Tensor,
        edge_index: Tensor,
        edge_weight: Tensor,
        num_sampled_nodes: Optional[List[int]] = None,
        num_sampled_edges: Optional[List[int]] = None,
    ) -> Tensor:
        """Apply every layer in order and return the final features.

        If ``num_sampled_nodes`` is given, the inputs are passed through
        ``trim_to_layer`` before each layer; otherwise they are forwarded to
        the layers untouched.
        """
        # Trimming is decided solely by whether node counts were supplied.
        should_trim = num_sampled_nodes is not None

        for depth, conv in enumerate(self.convs):
            if should_trim:
                x, edge_index, edge_weight = trim_to_layer(
                    layer=depth,
                    num_sampled_nodes_per_hop=num_sampled_nodes,
                    num_sampled_edges_per_hop=num_sampled_edges,
                    x=x,
                    edge_index=edge_index,
                    edge_attr=edge_weight,
                )
            x = conv(x, edge_index, edge_weight)

        return x


@withPackage('pyg_lib')
def test_trim_to_layer_with_neighbor_loader():
    """Check that per-layer trimming does not change the leading outputs.

    The same 3-layer ``GNN`` is run twice on the first ``NeighborLoader``
    batch: once without the per-hop counts and once with them. The first
    ``batch_size`` output rows (presumably the seed nodes) are expected to
    match within ``atol=1e-6``.
    """
    num_seeds = 2

    feats = torch.randn(14, 16)
    edge_index = torch.tensor([
        [2, 3, 4, 5, 7, 7, 10, 11, 12, 13],
        [0, 1, 2, 3, 2, 3, 7, 7, 7, 7],
    ])
    num_edges = edge_index.size(1)
    graph = Data(
        x=feats,
        edge_index=edge_index,
        edge_weight=torch.rand(num_edges),
    )

    loader = NeighborLoader(
        graph,
        num_neighbors=[1, 2, 4],
        batch_size=num_seeds,
        shuffle=False,
    )
    batch = next(iter(loader))

    model = GNN(num_layers=3)
    graph_args = (batch.x, batch.edge_index, batch.edge_weight)

    # Reference run: no per-hop counts, so the model does not trim.
    untrimmed = model(*graph_args)[:num_seeds]
    assert untrimmed.size() == (num_seeds, 16)

    # Second run: per-hop counts supplied, so each layer trims first.
    trimmed = model(
        *graph_args,
        batch.num_sampled_nodes,
        batch.num_sampled_edges,
    )[:num_seeds]
    assert trimmed.size() == (num_seeds, 16)

    assert torch.allclose(untrimmed, trimmed, atol=1e-6)