1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
|
import pytest
import torch
from torch_geometric import seed_everything
from torch_geometric.datasets import ExplainerDataset
from torch_geometric.datasets.graph_generator import BAGraph
from torch_geometric.datasets.motif_generator import HouseMotif
@pytest.mark.parametrize('graph_generator', [
    pytest.param(BAGraph(num_nodes=80, num_edges=5), id='BAGraph'),
])
@pytest.mark.parametrize('motif_generator', [
    pytest.param(HouseMotif(), id='HouseMotif'),
    'house',
])
def test_explainer_dataset_ba_house(graph_generator, motif_generator):
    """Validate a BA base graph decorated with two house motifs.

    The motif generator is supplied both as a ``HouseMotif()`` instance and
    as the ``'house'`` string; both forms must produce a dataset with the
    same repr and the same node/edge/mask layout.
    """
    base_nodes = 80  # num_nodes given to BAGraph
    motifs = 2  # num_motifs given to ExplainerDataset
    nodes_per_house = 5  # node_mask entries contributed by one motif
    edges_per_house = 12  # edge_mask entries contributed by one motif

    dataset = ExplainerDataset(graph_generator, motif_generator,
                               num_motifs=motifs)

    expected_repr = ('ExplainerDataset(1, graph_generator='
                     'BAGraph(num_nodes=80, num_edges=5), '
                     'motif_generator=HouseMotif(), num_motifs=2)')
    assert str(dataset) == expected_repr
    assert len(dataset) == 1

    data = dataset[0]
    # Four stored attributes: edge_index, y, node_mask, edge_mask.
    assert len(data) == 4

    num_nodes = data.num_nodes
    assert num_nodes == base_nodes + motifs * nodes_per_house

    # Every edge endpoint must reference an existing node.
    assert data.edge_index.min() >= 0
    assert data.edge_index.max() < num_nodes

    # Node labels span exactly the range 0..3.
    assert data.y.min() == 0
    assert data.y.max() == 3

    assert data.node_mask.size() == (num_nodes, )
    assert data.edge_mask.size() == (data.num_edges, )

    # Masks are binary and select exactly the motif nodes / motif edges.
    assert data.node_mask.min() == 0
    assert data.node_mask.max() == 1
    assert data.node_mask.sum() == motifs * nodes_per_house
    assert data.edge_mask.min() == 0
    assert data.edge_mask.max() == 1
    assert data.edge_mask.sum() == motifs * edges_per_house
def test_explainer_dataset_reproducibility():
    """Check that seeding makes ExplainerDataset generation reproducible.

    Two datasets are generated from identical generator configurations,
    each preceded by ``seed_everything(12345)``. Every generated attribute
    must match: the graph topology, and also the node labels and the
    ground-truth node/edge masks.
    """
    def make_data():
        # Re-seed and build fresh generators each time, so the two runs
        # start from identical state and share no objects.
        seed_everything(12345)
        dataset = ExplainerDataset(BAGraph(num_nodes=80, num_edges=5),
                                   HouseMotif(), num_motifs=2)
        return dataset[0]

    data1 = make_data()
    data2 = make_data()

    # Comparing edge_index alone would miss divergence in the labels or
    # in the explanation masks, so check every generated attribute.
    assert torch.equal(data1.edge_index, data2.edge_index)
    assert torch.equal(data1.y, data2.y)
    assert torch.equal(data1.node_mask, data2.node_mask)
    assert torch.equal(data1.edge_mask, data2.edge_mask)
|