1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40
|
import torch
from torch import Tensor
from torch_geometric.utils import is_undirected, to_undirected
def test_is_undirected():
row = torch.tensor([0, 1, 0])
col = torch.tensor([1, 0, 0])
sym_weight = torch.tensor([0, 0, 1])
asym_weight = torch.tensor([0, 1, 1])
assert is_undirected(torch.stack([row, col], dim=0))
assert is_undirected(torch.stack([row, col], dim=0), sym_weight)
assert not is_undirected(torch.stack([row, col], dim=0), asym_weight)
row = torch.tensor([0, 1, 1])
col = torch.tensor([1, 0, 2])
assert not is_undirected(torch.stack([row, col], dim=0))
@torch.jit.script
def jit(edge_index: Tensor) -> bool:
return is_undirected(edge_index)
assert not jit(torch.stack([row, col], dim=0))
def test_to_undirected():
row = torch.tensor([0, 1, 1])
col = torch.tensor([1, 0, 2])
edge_index = to_undirected(torch.stack([row, col], dim=0))
assert edge_index.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]
@torch.jit.script
def jit(edge_index: Tensor) -> Tensor:
return to_undirected(edge_index)
assert torch.equal(jit(torch.stack([row, col], dim=0)), edge_index)
|