1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38
|
import torch
from torch_geometric.distributed.utils import remove_duplicates
from torch_geometric.sampler import SamplerOutput
from torch_geometric.testing import onlyDistributedTest
@onlyDistributedTest
def test_remove_duplicates():
    """Seed nodes that reappear in the sampler output are not reported again.

    The seeds ``[0, 1, 2, 3]`` all show up again in the sampled node list
    (``1`` even twice). Only the genuinely new nodes should come back as the
    first return value, and the second return value should be the seeds
    followed by those new nodes.
    """
    seeds = torch.tensor([0, 1, 2, 3])
    sampled = torch.tensor([0, 4, 1, 5, 1, 6, 2, 7, 3, 8])
    sampler_out = SamplerOutput(sampled, None, None, None)

    new_nodes, all_nodes, _, _ = remove_duplicates(sampler_out, seeds)

    expected_new = [4, 5, 6, 7, 8]
    assert new_nodes.tolist() == expected_new
    assert all_nodes.tolist() == [0, 1, 2, 3] + expected_new
@onlyDistributedTest
def test_remove_duplicates_disjoint():
    """Deduplication with ``disjoint=True`` also returns batch assignments.

    Besides the seeds reappearing in the sampled output, node ``6`` is
    sampled twice (both times with batch entry ``2``). Every node must be
    kept exactly once, and the two returned batch vectors must line up
    element-wise with the returned node vectors.
    """
    seeds = torch.tensor([0, 1, 2, 3])
    seed_batch = torch.tensor([0, 1, 2, 3])

    sampled = torch.tensor([0, 4, 1, 5, 1, 6, 2, 6, 7, 3, 8])
    sampled_batch = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3])
    sampler_out = SamplerOutput(sampled, None, None, None, sampled_batch)

    result = remove_duplicates(sampler_out, seeds, seed_batch, disjoint=True)
    new_nodes, all_nodes, new_batch, all_batch = result

    # Node vectors: newly discovered nodes, then seeds + new nodes.
    assert new_nodes.tolist() == [4, 5, 6, 7, 8]
    assert all_nodes.tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8]

    # Batch vectors aligned with the node vectors above.
    assert new_batch.tolist() == [0, 1, 2, 3, 3]
    assert all_batch.tolist() == [0, 1, 2, 3, 0, 1, 2, 3, 3]
|