1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50
|
import os.path as osp
import pytest
import torch
from torch_geometric.testing import onlyGraphviz, withPackage
from torch_geometric.visualization import visualize_graph
@onlyGraphviz
@pytest.mark.parametrize('backend', [None, 'graphviz'])
def test_visualize_graph_via_graphviz(tmp_path, backend):
    """Render a small path graph to a PDF file through ``visualize_graph``.

    Parametrized over ``backend=None`` (the library's default selection) and
    an explicit ``'graphviz'`` backend; each run must write a non-empty file.
    """
    # Path graph 0-1-2-3-4 with every edge stored in both directions.
    edge_index = torch.tensor([
        [0, 1, 1, 2, 2, 3, 3, 4],
        [1, 0, 2, 1, 3, 2, 4, 3],
    ])
    # One 0/1 weight per edge. A fixed vector (rather than an unseeded
    # ``torch.rand(...) > 0.5``) keeps the test reproducible and guarantees
    # both weight values are exercised on every run.
    edge_weight = torch.tensor([1., 0., 1., 1., 0., 1., 0., 0.])
    path = osp.join(tmp_path, 'graph.pdf')
    visualize_graph(edge_index, edge_weight, path, backend)
    assert osp.exists(path)
    # ``exists`` alone would accept an empty file; require real output.
    assert osp.getsize(path) > 0


@onlyGraphviz
@pytest.mark.parametrize('backend', [None, 'graphviz'])
def test_visualize_graph_via_graphviz_with_node_labels(tmp_path, backend):
    """Render a labelled path graph to a PDF file through ``visualize_graph``.

    Same graph as the unlabelled Graphviz test, but passes one label per node
    (five nodes, five labels). Parametrized over ``backend=None`` and an
    explicit ``'graphviz'`` backend; each run must write a non-empty file.
    """
    # Path graph 0-1-2-3-4 with every edge stored in both directions.
    edge_index = torch.tensor([
        [0, 1, 1, 2, 2, 3, 3, 4],
        [1, 0, 2, 1, 3, 2, 4, 3],
    ])
    # One 0/1 weight per edge. A fixed vector (rather than an unseeded
    # ``torch.rand(...) > 0.5``) keeps the test reproducible and guarantees
    # both weight values are exercised on every run.
    edge_weight = torch.tensor([1., 0., 1., 1., 0., 1., 0., 0.])
    # Labels indexed by node id 0..4.
    node_labels = ['A', 'B', 'C', 'D', 'E']
    path = osp.join(tmp_path, 'graph.pdf')
    visualize_graph(edge_index, edge_weight, path, backend, node_labels)
    assert osp.exists(path)
    # ``exists`` alone would accept an empty file; require real output.
    assert osp.getsize(path) > 0


@withPackage('networkx', 'matplotlib')
@pytest.mark.parametrize('backend', [None, 'networkx'])
def test_visualize_graph_via_networkx(tmp_path, backend):
    """Render a small path graph to a PDF file through ``visualize_graph``.

    Parametrized over ``backend=None`` (the library's default selection) and
    an explicit ``'networkx'`` backend; each run must write a non-empty file.
    """
    # Path graph 0-1-2-3-4 with every edge stored in both directions.
    edge_index = torch.tensor([
        [0, 1, 1, 2, 2, 3, 3, 4],
        [1, 0, 2, 1, 3, 2, 4, 3],
    ])
    # One 0/1 weight per edge. A fixed vector (rather than an unseeded
    # ``torch.rand(...) > 0.5``) keeps the test reproducible and guarantees
    # both weight values are exercised on every run.
    edge_weight = torch.tensor([1., 0., 1., 1., 0., 1., 0., 0.])
    path = osp.join(tmp_path, 'graph.pdf')
    visualize_graph(edge_index, edge_weight, path, backend)
    assert osp.exists(path)
    # ``exists`` alone would accept an empty file; require real output.
    assert osp.getsize(path) > 0
|