File: test_static_graph.py

package info (click to toggle)
pytorch-geometric 2.7.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 14,172 kB
  • sloc: python: 144,911; sh: 247; cpp: 27; makefile: 18; javascript: 16
file content (27 lines) | stat: -rw-r--r-- 872 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch

from torch_geometric.data import Batch, Data
from torch_geometric.nn import ChebConv, GCNConv, MessagePassing


class MyConv(MessagePassing):
    r"""Minimal message passing layer exercised by the static-graph test.

    It defines no learnable parameters of its own and overrides only
    :meth:`forward`; message construction, aggregation and update are all
    inherited unchanged from :class:`MessagePassing`.
    (NOTE(review): presumably the inherited defaults forward neighbor
    features and sum-aggregate them -- confirm for the installed PyG
    version.)
    """

    def forward(self, x, edge_index):
        # Node features are handed to ``propagate`` as the ``x`` keyword
        # argument; the connectivity goes in positionally.
        propagated = self.propagate(edge_index, x=x)
        return propagated


def test_static_graph():
    """Check that message passing over a static graph matches batching.

    Two random feature matrices ``x1`` and ``x2`` (each ``3 x 8``) share the
    same ``edge_index``. Every layer is evaluated twice:

    1. on the ``Batch`` built from both graphs (6 nodes in total), and
    2. on the stacked ``2 x 3 x 8`` tensor after setting ``node_dim = 1``,
       so that dimension 0 acts as a batch over one shared graph.

    Both evaluations must yield the same node outputs (up to a small
    floating-point tolerance).
    """
    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
    x1, x2 = torch.randn(3, 8), torch.randn(3, 8)

    data1 = Data(edge_index=edge_index, x=x1)
    data2 = Data(edge_index=edge_index, x=x2)
    batch = Batch.from_data_list([data1, data2])

    # Shape (2, 3, 8): graph index, then nodes, then features.
    x = torch.stack([x1, x2], dim=0)
    for conv in [MyConv(), GCNConv(8, 16), ChebConv(8, 16, K=2)]:
        # Reference result: the batched graph with 2 * 3 = 6 nodes.
        out1 = conv(batch.x, batch.edge_index)
        assert out1.size(0) == 6

        # Static-graph result: nodes live on dim 1, dim 0 is the batch.
        # Mutating ``node_dim`` is safe here because every loop iteration
        # uses a freshly constructed layer.
        conv.node_dim = 1
        out2 = conv(x, edge_index)
        assert out2.size()[:2] == (2, 3)

        # The 2-D and 3-D code paths may dispatch to different kernels, so
        # results can differ by float32 rounding; the default atol=1e-8 is
        # flaky for outputs close to zero. ``reshape`` (unlike ``view``)
        # also copes with a non-contiguous ``out2``.
        flat_out2 = out2.reshape(-1, out2.size(-1))
        assert torch.allclose(out1, flat_out2, atol=1e-6)