1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166
|
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from torch_geometric.data import InMemoryDataset
from torch_geometric.datasets.graph_generator import GraphGenerator
from torch_geometric.explain import Explanation
from torch_geometric.utils import k_hop_subgraph
class InfectionDataset(InMemoryDataset):
    r"""Generates a synthetic infection dataset for evaluating explainability
    algorithms, as described in the `"Explainability Techniques for Graph
    Convolutional Networks" <https://arxiv.org/abs/1905.13686>`__ paper.
    The :class:`~torch_geometric.datasets.InfectionDataset` creates synthetic
    graphs coming from a
    :class:`~torch_geometric.datasets.graph_generator.GraphGenerator` with
    :obj:`num_infected_nodes` randomly assigned infected nodes.
    The dataset describes a node classification task of predicting the length
    of the shortest path to infected nodes, with corresponding ground-truth
    edge-level masks.

    For example, to generate a random Erdos-Renyi (ER) infection graph
    with :obj:`500` nodes and :obj:`0.004` edge probability, write:

    .. code-block:: python

        from torch_geometric.datasets import InfectionDataset
        from torch_geometric.datasets.graph_generator import ERGraph

        dataset = InfectionDataset(
            graph_generator=ERGraph(num_nodes=500, edge_prob=0.004),
            num_infected_nodes=50,
            max_path_length=3,
        )

    Args:
        graph_generator (GraphGenerator or str): The graph generator to be
            used, *e.g.*,
            :class:`torch_geometric.datasets.graph_generator.BAGraph`
            (or any string that automatically resolves to it).
        num_infected_nodes (int or List[int]): The number of randomly
            selected infected nodes in the graph.
            If given as a list, will select a different number of infected
            nodes for different graphs.
        max_path_length (int or List[int]): The maximum shortest path length
            to determine whether a node will be infected.
            If given as a list, will apply different shortest path lengths for
            different graphs.
        num_graphs (int, optional): The number of graphs to generate.
            The number of graphs will be automatically determined by
            :obj:`len(num_infected_nodes)` or :obj:`len(max_path_length)` in
            case either of them is given as a list, and should only be set in
            case one wants to create multiple graphs while
            :obj:`num_infected_nodes` and :obj:`max_path_length` are given as
            an integer. (default: :obj:`None`)
        graph_generator_kwargs (Dict[str, Any], optional): Arguments passed to
            the respective graph generator module in case it gets automatically
            resolved. (default: :obj:`None`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)

    Raises:
        ValueError: If the list lengths do not match :obj:`num_graphs`, if
            :obj:`num_graphs` is not positive, or if any entry of
            :obj:`num_infected_nodes` or :obj:`max_path_length` is not
            positive.
    """
    def __init__(
        self,
        graph_generator: Union[GraphGenerator, str],
        num_infected_nodes: Union[int, List[int]],
        max_path_length: Union[int, List[int]],
        num_graphs: Optional[int] = None,
        graph_generator_kwargs: Optional[Dict[str, Any]] = None,
        transform: Optional[Callable] = None,
    ):
        super().__init__(root=None, transform=transform)

        assert isinstance(num_infected_nodes, (int, list))
        assert isinstance(max_path_length, (int, list))

        # Infer the number of graphs: a single graph when both settings are
        # scalars, otherwise the length of whichever setting is a list
        # (`num_infected_nodes` takes precedence).
        if (num_graphs is None and isinstance(num_infected_nodes, int)
                and isinstance(max_path_length, int)):
            num_graphs = 1

        if num_graphs is None and isinstance(num_infected_nodes, list):
            num_graphs = len(num_infected_nodes)

        if num_graphs is None and isinstance(max_path_length, list):
            num_graphs = len(max_path_length)

        assert num_graphs is not None

        self.graph_generator = GraphGenerator.resolve(
            graph_generator,
            **(graph_generator_kwargs or {}),
        )
        # Keep the user-provided (possibly scalar) values for `__repr__`.
        self.num_infected_nodes = num_infected_nodes
        self.max_path_length = max_path_length
        self.num_graphs = num_graphs

        # Broadcast scalar settings to one entry per graph.
        if isinstance(num_infected_nodes, int):
            num_infected_nodes = [num_infected_nodes] * num_graphs

        if isinstance(max_path_length, int):
            max_path_length = [max_path_length] * num_graphs

        if len(num_infected_nodes) != num_graphs:
            raise ValueError(f"The length of 'num_infected_nodes' "
                             f"(got {len(num_infected_nodes)}) does not match "
                             f"with the number of graphs (got {num_graphs})")

        if len(max_path_length) != num_graphs:
            raise ValueError(f"The length of 'max_path_length' "
                             f"(got {len(max_path_length)}) does not match "
                             f"with the number of graphs (got {num_graphs})")

        # An empty dataset cannot be generated; fail early with a clear
        # message instead of on an empty per-graph list further below.
        if num_graphs <= 0:
            raise ValueError(f"'num_graphs' needs to be positive "
                             f"(got {num_graphs})")

        # Validate every entry individually. `any(values) <= 0` would compare
        # a boolean against zero and let non-positive entries slip through.
        if any(value <= 0 for value in num_infected_nodes):
            raise ValueError(f"'num_infected_nodes' needs to be positive "
                             f"(got {min(num_infected_nodes)})")

        if any(value <= 0 for value in max_path_length):
            raise ValueError(f"'max_path_length' needs to be positive "
                             f"(got {min(max_path_length)})")

        data_list: List[Explanation] = []
        for N, L in zip(num_infected_nodes, max_path_length):
            data_list.append(self.get_graph(N, L))

        self.data, self.slices = self.collate(data_list)

    def get_graph(self, num_infected_nodes: int,
                  max_path_length: int) -> Explanation:
        r"""Generates a single infection graph.

        Args:
            num_infected_nodes (int): The number of nodes to randomly mark as
                infected.
            max_path_length (int): The largest hop count that is expanded
                around the infected nodes. Nodes not reached within this
                many hops keep the label :obj:`max_path_length + 1`.

        Returns:
            Explanation: An object holding two-dimensional one-hot node
            features :obj:`x` (column :obj:`0`: healthy, column :obj:`1`:
            infected), the generated :obj:`edge_index`, node labels :obj:`y`,
            and a floating-point :obj:`edge_mask`.
        """
        data = self.graph_generator()

        assert data.num_nodes is not None
        # The first `num_infected_nodes` entries of a random permutation are
        # the infected nodes, all remaining nodes are healthy.
        perm = torch.randperm(data.num_nodes)
        x = torch.zeros((data.num_nodes, 2))
        x[perm[:num_infected_nodes], 1] = 1  # Infected
        x[perm[num_infected_nodes:], 0] = 1  # Healthy

        # Start every node at the "not reached" label `max_path_length + 1`.
        y = torch.empty(data.num_nodes, dtype=torch.long)
        y.fill_(max_path_length + 1)
        y[perm[:num_infected_nodes]] = 0  # Infected nodes have label `0`.

        assert data.edge_index is not None
        edge_mask = torch.zeros(data.num_edges, dtype=torch.bool)
        # Grow the neighborhood around the infected nodes one hop at a time.
        # The element-wise minimum keeps the smallest hop count at which a
        # node was returned, and the edge masks of all hops are accumulated.
        for num_hops in range(1, max_path_length + 1):
            sub_node_index, _, _, sub_edge_mask = k_hop_subgraph(
                perm[:num_infected_nodes], num_hops, data.edge_index,
                num_nodes=data.num_nodes, flow='target_to_source',
                directed=True)

            value = torch.full_like(sub_node_index, fill_value=num_hops)
            y[sub_node_index] = torch.min(y[sub_node_index], value)
            edge_mask |= sub_edge_mask

        return Explanation(
            x=x,
            edge_index=data.edge_index,
            y=y,
            edge_mask=edge_mask.to(torch.float),
        )

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({len(self)}, '
                f'graph_generator={self.graph_generator}, '
                f'num_infected_nodes={self.num_infected_nodes}, '
                f'max_path_length={self.max_path_length})')
|