File: test_neighbor_sample.py

package info (click to toggle)
pytorch-sparse 0.6.18-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 984 kB
  • sloc: python: 3,646; cpp: 2,444; sh: 54; makefile: 6
file content (42 lines) | stat: -rw-r--r-- 1,374 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
from torch_sparse import SparseTensor

# Short alias for the compiled operator exposed under
# ``torch.ops.torch_sparse`` (available once ``torch_sparse`` is imported
# above). The tests below call it with positional arguments only.
neighbor_sample = torch.ops.torch_sparse.neighbor_sample


def test_neighbor_sample():
    """Sample from a graph holding the single edge (row=0, col=1).

    Each case seeds the sampler at one node and checks the first three
    outputs of ``neighbor_sample``. Presumably these are the sampled node
    ids and the local row/col indices of the sampled edges; confirm against
    the operator's documentation.
    """
    edge_index = torch.tensor([[0], [1]])
    colptr, row, _ = SparseTensor.from_edge_index(edge_index).csc()

    def first_three(seed, num_neighbors):
        # Same positional call as the original; the two trailing booleans
        # are both False.
        result = neighbor_sample(colptr, row, torch.tensor([seed]),
                                 num_neighbors, False, False)
        return [tensor.tolist() for tensor in result[:3]]

    # Seeding at node 0 must not follow the edge in the wrong direction:
    # only the seed itself comes back, with no edges.
    assert first_three(0, [1]) == [[0], [], []]

    # Seeding at node 1 reaches node 0 through the single edge.
    assert first_three(1, [1]) == [[1, 0], [1], [0]]

    # An extra hop adds nothing new on this graph.
    assert first_three(1, [1, 1]) == [[1, 0], [1], [0]]


def test_neighbor_sample_seed():
    """Re-seeding torch's RNG must make ``neighbor_sample`` reproducible.

    The sampler runs twice from the same seed and must return identical
    outputs. The graph is in CSC form: every one of the three columns lists
    rows 0, 1 and 2, self-loops included. The trailing ``True, False``
    arguments are presumably ``(replace, directed)``; confirm against the
    operator schema.
    """
    colptr = torch.tensor([0, 3, 6, 9])
    row = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2])
    input_nodes = torch.tensor([0, 1])

    # Fork the CPU RNG so the fixed seed below does not leak into tests that
    # run after this one; the previous RNG state is restored on exit.
    with torch.random.fork_rng(devices=[]):
        torch.manual_seed(42)
        out1 = neighbor_sample(colptr, row, input_nodes, [1, 1], True, False)

        torch.manual_seed(42)
        out2 = neighbor_sample(colptr, row, input_nodes, [1, 1], True, False)

    # zip() stops at the shorter sequence, so check the arity first to avoid
    # silently skipping outputs.
    assert len(out1) == len(out2)
    for data1, data2 in zip(out1, out2):
        assert data1.tolist() == data2.tolist()