File: test_sampler.py

package info (click to toggle)
pytorch-cluster 1.6.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 648 kB
  • sloc: cpp: 2,076; python: 1,081; sh: 53; makefile: 8
file content (16 lines) | stat: -rw-r--r-- 390 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch

from torch_cluster import neighbor_sampler


def test_neighbor_sampler():
    """Check the exact ids returned by ``neighbor_sampler`` under a fixed seed.

    The sampler is called twice on the same two-element ``start`` tensor and
    three-element ``cumdeg`` tensor: first with a float ``size`` of ``1.0``,
    then with an integer ``size`` of ``3``. Each result is compared against a
    hard-coded list.
    """
    # Seed once, before any sampling; the calls below must stay in this order.
    # NOTE(review): the expected lists presumably depend on this seed and on
    # the sampler's RNG usage -- confirm before changing either.
    torch.manual_seed(1234)

    seeds = torch.tensor([0, 1])
    offsets = torch.tensor([0, 3, 7])

    # (size argument, expected ids), in the order the calls are made.
    cases = (
        (1.0, [0, 2, 1, 5, 6, 3, 4]),
        (3, [1, 0, 2, 4, 5, 6]),
    )
    for size, expected in cases:
        sampled = neighbor_sampler(seeds, offsets, size=size)
        assert sampled.tolist() == expected