1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
|
import torch
import torch_geometric.typing
from torch_geometric.data import Data
from torch_geometric.transforms import GCNNorm
from torch_geometric.typing import SparseTensor
def test_gcn_norm():
    """Check ``GCNNorm`` on the undirected 3-node path graph 0 - 1 - 2.

    The same graph is pushed through the transform in up to three forms:
    an ``edge_index`` with explicit unit weights, an ``edge_index`` with no
    weights at all, and, when ``torch_sparse`` is installed, a transposed
    ``SparseTensor`` adjacency. Each output is compared against precomputed
    normalized weights (with self-loops) to within ``1e-4``.
    """
    src_dst = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
    ones = torch.ones(src_dst.size(1))

    gcn_norm = GCNNorm()
    assert str(gcn_norm) == 'GCNNorm(add_self_loops=True)'

    # The four original edges keep their order and the self-loops
    # (0, 0), (1, 1), (2, 2) follow them. The weights match
    # D^{-1/2} (A + I) D^{-1/2} with self-loop degrees (2, 3, 2):
    # 1/sqrt(2 * 3) ~= 0.4082 off-diagonal, 1/2 and 1/3 on the diagonal.
    want_index = [[0, 1, 1, 2, 0, 1, 2], [1, 0, 2, 1, 0, 1, 2]]
    want_weight = torch.tensor(
        [0.4082, 0.4082, 0.4082, 0.4082, 0.5000, 0.3333, 0.5000])

    def check_edge_index_output(inp):
        # Apply the transform and verify the COO (edge_index) result. The
        # output must hold exactly edge_index, edge_weight and num_nodes.
        out = gcn_norm(inp)
        assert len(out) == 3
        assert out.num_nodes == 3
        assert out.edge_index.tolist() == want_index
        assert torch.allclose(out.edge_weight, want_weight, atol=1e-4)

    # Explicit unit edge weights:
    check_edge_index_output(
        Data(edge_index=src_dst, edge_weight=ones, num_nodes=3))
    # No edge weights; the transform must still produce `edge_weight`:
    check_edge_index_output(Data(edge_index=src_dst, num_nodes=3))

    # For `SparseTensor`, expected outputs will be sorted:
    if not torch_geometric.typing.WITH_TORCH_SPARSE:
        return

    sorted_rows = [0, 0, 1, 1, 1, 2, 2]
    sorted_cols = [0, 1, 0, 1, 2, 1, 2]
    sorted_weight = torch.tensor(
        [0.5000, 0.4082, 0.4082, 0.3333, 0.4082, 0.4082, 0.5000])

    adj_t = SparseTensor.from_edge_index(src_dst, ones).t()
    out = gcn_norm(Data(adj_t=adj_t))
    assert len(out) == 1

    row, col, value = out.adj_t.coo()
    assert row.tolist() == sorted_rows
    assert col.tolist() == sorted_cols
    assert torch.allclose(value, sorted_weight, atol=1e-4)
|