File: test_random_node_loader.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (47 lines) | stat: -rw-r--r-- 1,726 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torch

from torch_geometric.data import Data, HeteroData
from torch_geometric.loader import RandomNodeLoader
from torch_geometric.testing import get_random_edge_index


def test_random_node_loader():
    """RandomNodeLoader on a homogeneous graph.

    Builds a 100-node / 500-edge graph with a ``node_id`` attribute that
    records each node's original index. Checks that the loader yields
    ``num_parts`` batches, and that each batch keeps node features and edge
    attributes aligned with the original graph. Also checks that the batches
    partition the node set (every node is emitted exactly once).
    """
    data = Data()
    data.x = torch.randn(100, 128)
    data.node_id = torch.arange(100)
    data.edge_index = get_random_edge_index(100, 100, 500)
    data.edge_attr = torch.randn(500, 32)

    loader = RandomNodeLoader(data, num_parts=4, shuffle=True)
    assert len(loader) == 4

    seen_node_ids = []
    for batch in loader:
        # x, node_id, edge_index, edge_attr survive subgraph extraction.
        assert len(batch) == 4
        assert batch.node_id.min() >= 0
        assert batch.node_id.max() < 100
        assert batch.edge_index.size(1) == batch.edge_attr.size(0)
        # node_id maps each batch row back to its row in the full graph.
        assert torch.allclose(batch.x, data.x[batch.node_id])
        batch.validate()
        seen_node_ids.append(batch.node_id)

    # The batches must partition the nodes: no node dropped or duplicated.
    all_ids = torch.cat(seen_node_ids).sort().values
    assert torch.equal(all_ids, torch.arange(100))


def test_heterogeneous_random_node_loader():
    """RandomNodeLoader on a heterogeneous graph.

    Builds a graph with ``paper`` (100) and ``author`` (200) nodes and three
    edge types. Each edge type has an ``edge_attr`` with one row per edge.

    Checks that every batch:

    * keeps all node/edge types;
    * keeps per-type node features aligned with the original graph through
      ``node_id``;
    * keeps ``edge_index``/``edge_attr`` aligned for every edge type.

    Also checks that, for each node type, the batches together cover every
    node exactly once.
    """
    data = HeteroData()
    data['paper'].x = torch.randn(100, 128)
    data['paper'].node_id = torch.arange(100)
    data['author'].x = torch.randn(200, 128)
    data['author'].node_id = torch.arange(200)
    data['paper', 'author'].edge_index = get_random_edge_index(100, 200, 500)
    data['paper', 'author'].edge_attr = torch.randn(500, 32)
    data['author', 'paper'].edge_index = get_random_edge_index(200, 100, 400)
    data['author', 'paper'].edge_attr = torch.randn(400, 32)
    data['paper', 'paper'].edge_index = get_random_edge_index(100, 100, 600)
    data['paper', 'paper'].edge_attr = torch.randn(600, 32)

    loader = RandomNodeLoader(data, num_parts=4, shuffle=True)
    assert len(loader) == 4

    seen_node_ids = {node_type: [] for node_type in data.node_types}
    for batch in loader:
        assert len(batch) == 4
        assert batch.node_types == data.node_types
        assert batch.edge_types == data.edge_types

        for node_type in data.node_types:
            node_id = batch[node_type].node_id
            # Use elementwise `.all()` instead of `.min()`/`.max()`: a
            # shuffled batch may (rarely) contain no nodes of some type,
            # and reductions like `.min()` raise on empty tensors.
            assert bool((node_id >= 0).all())
            assert bool((node_id < data[node_type].num_nodes).all())
            assert torch.allclose(batch[node_type].x,
                                  data[node_type].x[node_id])
            seen_node_ids[node_type].append(node_id)

        for edge_type in data.edge_types:
            edge_store = batch[edge_type]
            assert edge_store.edge_index.size(1) == edge_store.edge_attr.size(0)

        batch.validate()

    # Per node type, the batches must partition the nodes exactly.
    for node_type, chunks in seen_node_ids.items():
        all_ids = torch.cat(chunks).sort().values
        assert torch.equal(all_ids, torch.arange(data[node_type].num_nodes))