File: test_gcn_norm.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (47 lines) | stat: -rw-r--r-- 1,838 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torch

import torch_geometric.typing
from torch_geometric.data import Data
from torch_geometric.transforms import GCNNorm
from torch_geometric.typing import SparseTensor


def test_gcn_norm():
    """Check ``GCNNorm`` on the 3-node path graph 0 - 1 - 2 with unit weights.

    The transform is applied to an ``edge_index``-based ``Data`` object both
    with and without an explicit all-ones ``edge_weight``. In both cases the
    output must be the original edges followed by one self-loop per node,
    with weights equal to ``1 / sqrt(d_i * d_j)`` for the degrees
    ``d = (2, 3, 2)`` that include those self-loops. When ``torch-sparse`` is
    available, the same graph is also checked in ``SparseTensor`` (``adj_t``)
    form, where the output entries come back in sorted order.
    """
    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
    edge_weight = torch.ones(edge_index.size(1))

    transform = GCNNorm()
    # Default construction enables self-loop insertion.
    assert str(transform) == 'GCNNorm(add_self_loops=True)'

    # Original edge order is kept; self-loops (0, 0), (1, 1), (2, 2) follow.
    want_index = [
        [0, 1, 1, 2, 0, 1, 2],
        [1, 0, 2, 1, 0, 1, 2],
    ]
    want_weight = torch.tensor(
        [0.4082, 0.4082, 0.4082, 0.4082, 0.5000, 0.3333, 0.5000])

    # Pass 1 supplies the weights explicitly; pass 2 omits them and relies on
    # the transform's default. Both must yield identical results.
    for weight_kwargs in ({'edge_weight': edge_weight}, {}):
        out = transform(
            Data(edge_index=edge_index, **weight_kwargs, num_nodes=3))
        # Expected stored attributes: edge_index, edge_weight, num_nodes.
        assert len(out) == 3
        assert out.num_nodes == 3
        assert out.edge_index.tolist() == want_index
        assert torch.allclose(out.edge_weight, want_weight, atol=1e-4)

    if not torch_geometric.typing.WITH_TORCH_SPARSE:
        return

    # For `SparseTensor`, expected outputs will be sorted (row-major):
    sorted_index = [[0, 0, 1, 1, 1, 2, 2], [0, 1, 0, 1, 2, 1, 2]]
    sorted_weight = torch.tensor(
        [0.5000, 0.4082, 0.4082, 0.3333, 0.4082, 0.4082, 0.5000])

    adj_t = SparseTensor.from_edge_index(edge_index, edge_weight).t()
    out = transform(Data(adj_t=adj_t))
    assert len(out) == 1  # only `adj_t` is stored
    rows, cols, vals = out.adj_t.coo()
    assert rows.tolist() == sorted_index[0]
    assert cols.tolist() == sorted_index[1]
    assert torch.allclose(vals, sorted_weight, atol=1e-4)