File: infection_dataset.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (166 lines) | stat: -rw-r--r-- 7,293 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
from typing import Any, Callable, Dict, List, Optional, Union

import torch

from torch_geometric.data import InMemoryDataset
from torch_geometric.datasets.graph_generator import GraphGenerator
from torch_geometric.explain import Explanation
from torch_geometric.utils import k_hop_subgraph


class InfectionDataset(InMemoryDataset):
    r"""Generates a synthetic infection dataset for evaluating explainability
    algorithms, as described in the `"Explainability Techniques for Graph
    Convolutional Networks" <https://arxiv.org/abs/1905.13686>`__ paper.
    The :class:`~torch_geometric.datasets.InfectionDataset` creates synthetic
    graphs coming from a
    :class:`~torch_geometric.datasets.graph_generator.GraphGenerator` with
    :obj:`num_infected_nodes` randomly assigned infected nodes.
    The dataset describes a node classification task of predicting the length
    of the shortest path to infected nodes, with corresponding ground-truth
    edge-level masks.

    For example, to generate a random Erdos-Renyi (ER) infection graph
    with :obj:`500` nodes and :obj:`0.004` edge probability, write:

    .. code-block:: python

        from torch_geometric.datasets import InfectionDataset
        from torch_geometric.datasets.graph_generator import ERGraph

        dataset = InfectionDataset(
            graph_generator=ERGraph(num_nodes=500, edge_prob=0.004),
            num_infected_nodes=50,
            max_path_length=3,
        )

    Args:
        graph_generator (GraphGenerator or str): The graph generator to be
            used, *e.g.*,
            :class:`torch_geometric.datasets.graph_generator.BAGraph`
            (or any string that automatically resolves to it).
        num_infected_nodes (int or List[int]): The number of randomly
            selected infected nodes in the graph.
            If given as a list, will select a different number of infected
            nodes for different graphs.
        max_path_length (int or List[int]): The maximum shortest path length
            to determine whether a node will be infected.
            If given as a list, will apply different shortest path lengths for
            different graphs.
        num_graphs (int, optional): The number of graphs to generate.
            The number of graphs will be automatically determined by
            :obj:`len(num_infected_nodes)` or :obj:`len(max_path_length)` in
            case either of them is given as a list, and should only be set in
            case one wants to create multiple graphs while
            :obj:`num_infected_nodes` and :obj:`max_path_length` are given as
            an integer. (default: :obj:`None`)
        graph_generator_kwargs (Dict[str, Any], optional): Arguments passed to
            the respective graph generator module in case it gets automatically
            resolved. (default: :obj:`None`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)

    Raises:
        ValueError: If the list lengths of :obj:`num_infected_nodes` or
            :obj:`max_path_length` disagree with the number of graphs, or if
            any of their entries is not strictly positive.
    """
    def __init__(
        self,
        graph_generator: Union[GraphGenerator, str],
        num_infected_nodes: Union[int, List[int]],
        max_path_length: Union[int, List[int]],
        num_graphs: Optional[int] = None,
        graph_generator_kwargs: Optional[Dict[str, Any]] = None,
        transform: Optional[Callable] = None,
    ):
        super().__init__(root=None, transform=transform)

        assert isinstance(num_infected_nodes, (int, list))
        assert isinstance(max_path_length, (int, list))

        # Infer the number of graphs when it is not given explicitly:
        # a list-valued `num_infected_nodes` takes precedence over a
        # list-valued `max_path_length`; two plain integers mean one graph.
        if num_graphs is None:
            if isinstance(num_infected_nodes, list):
                num_graphs = len(num_infected_nodes)
            elif isinstance(max_path_length, list):
                num_graphs = len(max_path_length)
            else:
                num_graphs = 1

        self.graph_generator = GraphGenerator.resolve(
            graph_generator,
            **(graph_generator_kwargs or {}),
        )
        # Keep the user-supplied values (int or list) for `__repr__`.
        self.num_infected_nodes = num_infected_nodes
        self.max_path_length = max_path_length
        self.num_graphs = num_graphs

        # Broadcast scalar settings so that every graph has its own entry.
        if isinstance(num_infected_nodes, int):
            num_infected_nodes = [num_infected_nodes] * num_graphs

        if isinstance(max_path_length, int):
            max_path_length = [max_path_length] * num_graphs

        if len(num_infected_nodes) != num_graphs:
            raise ValueError(f"The length of 'num_infected_nodes' "
                             f"(got {len(num_infected_nodes)}) does not match "
                             f"with the number of graphs (got {num_graphs})")

        if len(max_path_length) != num_graphs:
            raise ValueError(f"The length of 'max_path_length' "
                             f"(got {len(max_path_length)}) does not match "
                             f"with the number of graphs (got {num_graphs})")

        # NOTE: The comparison has to happen per element. `any(values) <= 0`
        # compares a single boolean against zero and therefore lets negative
        # entries (and zeros mixed with positive entries) slip through.
        if any(value <= 0 for value in num_infected_nodes):
            raise ValueError(f"'num_infected_nodes' needs to be positive "
                             f"(got {min(num_infected_nodes)})")

        if any(value <= 0 for value in max_path_length):
            raise ValueError(f"'max_path_length' needs to be positive "
                             f"(got {min(max_path_length)})")

        data_list: List[Explanation] = [
            self.get_graph(N, L)
            for N, L in zip(num_infected_nodes, max_path_length)
        ]

        self.data, self.slices = self.collate(data_list)

    def get_graph(self, num_infected_nodes: int,
                  max_path_length: int) -> Explanation:
        r"""Generates a single infection graph with ground-truth labels.

        A graph is drawn from :obj:`self.graph_generator` and the first
        :obj:`num_infected_nodes` entries of a random node permutation are
        marked as infected.

        Args:
            num_infected_nodes (int): The number of infected nodes to sample.
            max_path_length (int): The largest hop count that is explored
                around the infected nodes.

        Returns:
            Explanation: An object holding

            * :obj:`x`: Two-column node features, where column :obj:`1` is set
              for infected nodes and column :obj:`0` for healthy nodes.
            * :obj:`edge_index`: The edges of the generated graph.
            * :obj:`y`: Node labels. Infected nodes are labeled :obj:`0`,
              nodes found in the :obj:`k`-hop subgraph of the infected nodes
              receive the smallest such :obj:`k`, and all remaining nodes
              keep the label :obj:`max_path_length + 1`.
            * :obj:`edge_mask`: A float mask that is the union of the edge
              masks of all explored :obj:`k`-hop subgraphs.
        """
        data = self.graph_generator()

        assert data.num_nodes is not None
        perm = torch.randperm(data.num_nodes)
        infected = perm[:num_infected_nodes]
        healthy = perm[num_infected_nodes:]

        x = torch.zeros((data.num_nodes, 2))
        x[infected, 1] = 1  # Infected
        x[healthy, 0] = 1  # Healthy

        # Start every node at the "not reached" label; it is lowered below
        # for each node that shows up in a k-hop subgraph.
        y = torch.empty(data.num_nodes, dtype=torch.long)
        y.fill_(max_path_length + 1)
        y[infected] = 0  # Infected nodes have label `0`.

        assert data.edge_index is not None
        edge_mask = torch.zeros(data.num_edges, dtype=torch.bool)
        for num_hops in range(1, max_path_length + 1):
            sub_node_index, _, _, sub_edge_mask = k_hop_subgraph(
                infected, num_hops, data.edge_index,
                num_nodes=data.num_nodes, flow='target_to_source',
                directed=True)

            # Taking the minimum keeps labels assigned by smaller hop counts
            # (including `0` for infected nodes) intact.
            value = torch.full_like(sub_node_index, fill_value=num_hops)
            y[sub_node_index] = torch.min(y[sub_node_index], value)
            edge_mask |= sub_edge_mask

        return Explanation(
            x=x,
            edge_index=data.edge_index,
            y=y,
            edge_mask=edge_mask.to(torch.float),
        )

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({len(self)}, '
                f'graph_generator={self.graph_generator}, '
                f'num_infected_nodes={self.num_infected_nodes}, '
                f'max_path_length={self.max_path_length})')