1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42
|
import torch
from torch_sparse import SparseTensor
# Module-level handle to the `neighbor_sample` operator exposed under the
# `torch_sparse` namespace of `torch.ops` (presumably registered as a side
# effect of importing `torch_sparse` above — confirm).
neighbor_sample = torch.ops.torch_sparse.neighbor_sample
def test_neighbor_sample():
    """Exercise ``neighbor_sample`` on a graph with the single edge (0, 1).

    Every case calls the op with a list of seed nodes and a per-hop
    neighbor-count list, then compares the first three returned tensors
    against hard-coded expectations (presumably the sampled node ids followed
    by the row/col indices of the sampled edges — confirm against the op).
    """
    edge_index = torch.tensor([[0], [1]])
    colptr, row, _ = SparseTensor.from_edge_index(edge_index).csc()

    # Each entry: (seed nodes, neighbors per hop, expected out[0], out[1], out[2]).
    cases = [
        # Sampling in a non-directed way should not sample in wrong direction:
        ([0], [1], [0], [], []),
        # Sampling should work:
        ([1], [1], [1, 0], [1], [0]),
        # Sampling with more hops:
        ([1], [1, 1], [1, 0], [1], [0]),
    ]

    for seeds, num_neighbors, exp_node, exp_row, exp_col in cases:
        # The two trailing booleans are both False in every case (presumably
        # the `replace` and `directed` flags — confirm against the op schema).
        out = neighbor_sample(
            colptr, row, torch.tensor(seeds), num_neighbors, False, False,
        )
        assert out[0].tolist() == exp_node
        assert out[1].tolist() == exp_row
        assert out[2].tolist() == exp_col
def test_neighbor_sample_seed():
    """Two seeded ``neighbor_sample`` calls must produce identical outputs.

    Both calls use the same graph, seeds and arguments, and each is preceded
    by ``torch.manual_seed(42)``. With the fifth argument set to ``True``
    (assumed to be the ``replace`` flag — confirm against the op schema),
    this checks that any randomness in sampling is driven by the torch RNG.
    """
    # Compressed-column graph: each of the 3 columns lists rows 0, 1 and 2, so
    # every seed has several candidate neighbors and sampling is not trivial.
    colptr = torch.tensor([0, 3, 6, 9])
    row = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2])
    input_nodes = torch.tensor([0, 1])

    torch.manual_seed(42)
    out1 = neighbor_sample(colptr, row, input_nodes, [1, 1], True, False)

    torch.manual_seed(42)
    out2 = neighbor_sample(colptr, row, input_nodes, [1, 1], True, False)

    # `zip` stops at the shorter sequence; compare lengths first so that a
    # differing number of returned tensors cannot go unnoticed.
    assert len(out1) == len(out2)
    for data1, data2 in zip(out1, out2):
        assert data1.tolist() == data2.tolist()
|