File: ba2motif_dataset.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (126 lines) | stat: -rw-r--r-- 4,258 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import pickle
from typing import Callable, List, Optional

import torch

from torch_geometric.data import Data, InMemoryDataset, download_url


class BA2MotifDataset(InMemoryDataset):
    r"""Synthetic BA-2motifs benchmark for graph classification, intended for
    evaluating explainability algorithms and introduced in the
    `"Parameterized Explainer for Graph Neural Network"
    <https://arxiv.org/abs/2011.04573>`_ paper.
    :class:`~torch_geometric.datasets.BA2MotifDataset` consists of 1000 random
    Barabasi-Albert (BA) graphs.
    One half of the graphs carries an attached
    :class:`~torch_geometric.datasets.motif_generator.HouseMotif`, while the
    other half carries an attached five-node
    :class:`~torch_geometric.datasets.motif_generator.CycleMotif`.
    The class of a graph is given by the type of motif attached to it.

    The data shipped with this dataset is pre-computed from the official
    implementation. To build your own variations of it, have a look at the
    :class:`~torch_geometric.datasets.ExplainerDataset`:

    .. code-block:: python

        import torch
        from torch_geometric.datasets import ExplainerDataset
        from torch_geometric.datasets.graph_generator import BAGraph
        from torch_geometric.datasets.motif_generator import HouseMotif
        from torch_geometric.datasets.motif_generator import CycleMotif

        dataset1 = ExplainerDataset(
            graph_generator=BAGraph(num_nodes=25, num_edges=1),
            motif_generator=HouseMotif(),
            num_motifs=1,
            num_graphs=500,
        )

        dataset2 = ExplainerDataset(
            graph_generator=BAGraph(num_nodes=25, num_edges=1),
            motif_generator=CycleMotif(5),
            num_motifs=1,
            num_graphs=500,
        )

        dataset = torch.utils.data.ConcatDataset([dataset1, dataset2])

    Args:
        root (str): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)

    **STATS:**

    .. list-table::
        :widths: 10 10 10 10 10
        :header-rows: 1

        * - #graphs
          - #nodes
          - #edges
          - #features
          - #classes
        * - 1000
          - 25
          - ~51.0
          - 10
          - 2
    """
    url = 'https://github.com/flyingdoog/PGExplainer/raw/master/dataset'
    filename = 'BA-2motif.pkl'

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        force_reload: bool = False,
    ) -> None:
        super().__init__(
            root,
            transform,
            pre_transform,
            force_reload=force_reload,
        )
        # Populate this instance from the first processed file.
        self.load(self.processed_paths[0])

    @property
    def raw_file_names(self) -> str:
        """Name of the raw pickle file this dataset is built from."""
        return self.filename

    @property
    def processed_file_names(self) -> str:
        """Name of the file the processed dataset is stored under."""
        return 'data.pt'

    def download(self) -> None:
        """Fetch the raw pickle file from :attr:`url` into the raw folder."""
        source = f'{self.url}/{self.filename}'
        download_url(source, self.raw_dir)

    def process(self) -> None:
        """Convert the raw pickle into a list of graphs and save it.

        Every graph is passed through :obj:`pre_transform` (if given) before
        the whole list is written to the first processed path.
        """
        # NOTE: unpickling can execute arbitrary code, so this is only safe as
        # long as the file obtained via `download()` comes from a trusted
        # source.
        with open(self.raw_paths[0], 'rb') as raw_file:
            adj_np, x_np, y_np = pickle.load(raw_file)

        adjacencies = torch.from_numpy(adj_np)
        features = torch.from_numpy(x_np).float()
        labels = torch.from_numpy(y_np)

        data_list: List[Data] = []
        for idx in range(features.size(0)):
            # Coordinates of the non-zero adjacency entries, transposed so
            # that every column holds the indices of one entry.
            edge_index = torch.nonzero(adjacencies[idx]).t()

            # Position of the non-zero entry in the label row (presumably a
            # one-hot encoding of the class); `int()` requires that exactly
            # one such entry exists.
            target = int(torch.nonzero(labels[idx]))

            graph = Data(x=features[idx], edge_index=edge_index, y=target)

            data_list.append(
                graph if self.pre_transform is None
                else self.pre_transform(graph)
            )

        self.save(data_list, self.processed_paths[0])