1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
|
import torch
from torch_geometric.utils import to_torch_coo_tensor
from torch_geometric.utils.num_nodes import (
maybe_num_nodes,
maybe_num_nodes_dict,
)
def test_maybe_num_nodes():
    """Check ``maybe_num_nodes`` on a dense edge index and its sparse COO form.

    An explicitly passed node count must be returned unchanged.  Without one,
    the count is expected to be derived from the edges.  The largest node index
    used here is 2, so 3 nodes are expected.
    """
    edges = torch.tensor([
        [0, 0, 1, 1, 2, 2, 2],
        [1, 2, 0, 2, 0, 1, 1],
    ])
    sparse_adj = to_torch_coo_tensor(edges)

    # The same expectations apply to both graph representations.
    for graph in (edges, sparse_adj):
        # An explicit count takes precedence over anything inferred.
        assert maybe_num_nodes(graph, 4) == 4
        # No count given: expect it to be inferred from the edges.
        assert maybe_num_nodes(graph) == 3
def test_maybe_num_nodes_dict():
    """Check per-key node counts produced by ``maybe_num_nodes_dict``.

    Without overrides, each key's count is expected to be inferred from its
    own edge index: 3 for key ``'1'`` (max index 2) and 5 for key ``'2'``
    (max index 4).  An entry in the optional override dict replaces the
    inferred value for that key only.
    """
    edge_index_dict = {
        '1': torch.tensor([[0, 0, 1, 1, 2, 2, 2], [1, 2, 0, 2, 0, 1, 1]]),
        '2': torch.tensor([[0, 0, 1, 3], [1, 2, 0, 4]]),
    }
    inferred = {'1': 3, '2': 5}

    assert maybe_num_nodes_dict(edge_index_dict) == inferred

    # Overriding key '2' must leave key '1' at its inferred count.
    overrides = {'2': 6}
    assert maybe_num_nodes_dict(edge_index_dict, overrides) == {
        **inferred,
        '2': 6,
    }
|