File: test_remote_backend_utils.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (33 lines) | stat: -rw-r--r-- 1,245 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import pytest
import torch

from torch_geometric.data import HeteroData
from torch_geometric.data.remote_backend_utils import num_nodes, size
from torch_geometric.testing import (
    MyFeatureStore,
    MyGraphStore,
    get_random_edge_index,
)


@pytest.mark.parametrize('FeatureStore', [MyFeatureStore, HeteroData])
@pytest.mark.parametrize('GraphStore', [MyGraphStore, HeteroData])
def test_num_nodes_size(FeatureStore, GraphStore):
    """Check ``num_nodes`` / ``size`` inference over store combinations.

    Each parametrized pair builds a fresh feature store and a fresh graph
    store, then verifies that:

    * the node count of ``'x'`` is inferred from a stored feature tensor,
    * the node count of ``'y'`` and the size of ``('x', 'to', 'y')`` are
      inferred from a stored COO edge index,
    * an unknown node type (``'z'``) raises ``ValueError``.
    """
    fs, gs = FeatureStore(), GraphStore()
    num_x, num_y = 100, 50
    edge_type = ('x', 'to', 'y')

    # Features alone should determine how many 'x' nodes exist:
    fs.put_tensor(torch.arange(num_x), group_name='x', attr_name='x',
                  index=None)
    assert num_nodes(fs, gs, 'x') == num_x

    # The edge index (and its declared size) determines the 'y' node count
    # as well as the (src, dst) size of the edge type:
    edge_index = get_random_edge_index(num_x, num_y, 20)
    gs.put_edge_index(edge_index, edge_type=edge_type, layout='coo',
                      size=(num_x, num_y))
    assert num_nodes(fs, gs, 'y') == num_y
    assert size(fs, gs, edge_type) == (num_x, num_y)

    # Neither store knows about 'z', so inference must fail loudly:
    with pytest.raises(ValueError, match="Unable to accurately infer"):
        num_nodes(fs, gs, 'z')