File: test_sampler_base.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (118 lines) | stat: -rw-r--r-- 4,011 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import pytest
import torch

from torch_geometric.sampler.base import (
    HeteroSamplerOutput,
    NumNeighbors,
    SamplerOutput,
)
from torch_geometric.testing import get_random_edge_index
from torch_geometric.utils import is_undirected


def test_homogeneous_num_neighbors():
    """Exercise `NumNeighbors` built from a plain list (homogeneous case).

    Checks that a list input combined with an explicit ``default`` raises,
    then verifies the string form, the values, the mapped values and the hop
    count, together with the private attributes they are stored under.
    """
    expected = [25, 10]

    # A list of values may not be combined with a non-None default.
    with pytest.raises(ValueError, match="'default' must be set to 'None'"):
        NumNeighbors([25, 10], default=[-1, -1])

    sizes = NumNeighbors([25, 10])
    cached = vars(sizes)  # Same mapping as `sizes.__dict__`.

    assert str(sizes) == 'NumNeighbors(values=[25, 10], default=None)'

    # Every accessor is queried twice; the second read tests caching.
    assert sizes.get_values() == expected
    assert cached['_values'] == expected
    assert sizes.get_values() == expected

    assert sizes.get_mapped_values() == expected
    assert cached['_mapped_values'] == expected
    assert sizes.get_mapped_values() == expected

    assert sizes.num_hops == 2
    assert cached['_num_hops'] == 2
    assert sizes.num_hops == 2


def test_heterogeneous_num_neighbors_list():
    """Check a list-based `NumNeighbors` queried with a list of edge types.

    Every requested edge type is expected to receive the same ``[25, 10]``
    list, keyed either by the edge-type tuple or by its ``A__to__B`` string.
    """
    edge_types = [('A', 'B'), ('B', 'A')]
    num_neighbors = NumNeighbors([25, 10])

    by_edge_type = num_neighbors.get_values(edge_types)
    assert by_edge_type == {('A', 'B'): [25, 10], ('B', 'A'): [25, 10]}

    by_mapped_key = num_neighbors.get_mapped_values(edge_types)
    assert by_mapped_key == {'A__to__B': [25, 10], 'B__to__A': [25, 10]}

    assert num_neighbors.num_hops == 2


def test_heterogeneous_num_neighbors_dict_and_default():
    """Check a dict-based `NumNeighbors` combined with a ``default`` list.

    A default whose length differs from the explicit entry must raise when
    the values are resolved; a default of matching length is used for the
    edge type that has no explicit entry.
    """
    edge_types = [('A', 'B'), ('B', 'A')]
    explicit = {('A', 'B'): [25, 10]}

    # One-hop default vs. two-hop explicit entry: resolving must fail.
    mismatched = NumNeighbors(explicit, default=[-1])
    with pytest.raises(ValueError, match="hops must be the same across all"):
        mismatched.get_values(edge_types)

    matched = NumNeighbors({('A', 'B'): [25, 10]}, default=[-1, -1])

    by_edge_type = matched.get_values(edge_types)
    assert by_edge_type == {('A', 'B'): [25, 10], ('B', 'A'): [-1, -1]}

    by_mapped_key = matched.get_mapped_values(edge_types)
    assert by_mapped_key == {'A__to__B': [25, 10], 'B__to__A': [-1, -1]}

    assert matched.num_hops == 2


def test_heterogeneous_num_neighbors_empty_dict():
    """Check `NumNeighbors` built from an empty dict plus a ``default``.

    With no explicit entries, every requested edge type is expected to
    resolve to the default ``[25, 10]``.
    """
    edge_types = [('A', 'B'), ('B', 'A')]
    num_neighbors = NumNeighbors({}, default=[25, 10])

    by_edge_type = num_neighbors.get_values(edge_types)
    assert by_edge_type == {('A', 'B'): [25, 10], ('B', 'A'): [25, 10]}

    by_mapped_key = num_neighbors.get_mapped_values(edge_types)
    assert by_mapped_key == {'A__to__B': [25, 10], 'B__to__A': [25, 10]}

    assert num_neighbors.num_hops == 2


def test_homogeneous_to_bidirectional():
    """Check that `SamplerOutput.to_bidirectional` yields an undirected graph.

    A random edge index over 10 nodes is wrapped in a `SamplerOutput`, made
    bidirectional, and the resulting ``(row, col)`` pair is verified with
    `is_undirected`.
    """
    edge_index = get_random_edge_index(10, 10, num_edges=20)

    obj = SamplerOutput(
        node=torch.arange(10),
        row=edge_index[0],
        # Use the destination row here. Passing `edge_index[0]` for both
        # `row` and `col` would make every edge a self-loop, so the
        # undirectedness assertion below would pass trivially.
        col=edge_index[1],
        edge=torch.arange(edge_index.size(1)),
    ).to_bidirectional()

    assert is_undirected(torch.stack([obj.row, obj.col], dim=0))


def test_heterogeneous_to_bidirectional():
    """Check `HeteroSamplerOutput.to_bidirectional` on three edge types.

    After the conversion, the sorted rows of ``('v1', 'to', 'v2')`` must
    equal the sorted columns of ``('v2', 'rev_to', 'v1')`` and vice versa,
    and the ``('v1', 'to', 'v1')`` edges must form an undirected graph.
    """
    forward = ('v1', 'to', 'v2')
    backward = ('v2', 'rev_to', 'v1')
    same_type = ('v1', 'to', 'v1')

    # Entries are evaluated in order, so the random draws happen in the
    # sequence forward, backward, same_type.
    edge_indices = {
        forward: get_random_edge_index(10, 5, num_edges=20),
        backward: get_random_edge_index(5, 10, num_edges=20),
        same_type: get_random_edge_index(10, 10, num_edges=20),
    }

    nodes = {'v1': torch.arange(10), 'v2': torch.arange(5)}
    rows = {key: index[0] for key, index in edge_indices.items()}
    cols = {key: index[1] for key, index in edge_indices.items()}

    sampled = HeteroSamplerOutput(node=nodes, row=rows, col=cols, edge={})
    obj = sampled.to_bidirectional()

    forward_row = obj.row[forward].sort().values
    forward_col = obj.col[forward].sort().values
    backward_row = obj.row[backward].sort().values
    backward_col = obj.col[backward].sort().values

    assert torch.equal(forward_row, backward_col)
    assert torch.equal(forward_col, backward_row)

    same_type_edges = torch.stack([obj.row[same_type], obj.col[same_type]], 0)
    assert is_undirected(same_type_edges)