File: create_dataset.rst

Creating Graph Datasets
=======================

Although :pyg:`PyG` already contains a lot of useful datasets, you may wish to create your own dataset with self-recorded or non-publicly available data.

Implementing datasets yourself is straightforward, and you may want to take a look at the source code to find out how the various datasets are implemented.
However, we give a brief introduction on what is needed to set up your own dataset.

We provide two abstract classes for datasets: :class:`torch_geometric.data.Dataset` and :class:`torch_geometric.data.InMemoryDataset`.
:class:`torch_geometric.data.InMemoryDataset` inherits from :class:`torch_geometric.data.Dataset` and should be used if the whole dataset fits into CPU memory.

Following the :obj:`torchvision` convention, each dataset gets passed a root folder which indicates where the dataset should be stored.
We split up the root folder into two folders: the :obj:`raw_dir`, where the dataset gets downloaded to, and the :obj:`processed_dir`, where the processed dataset is saved.

In addition, each dataset can be passed a :obj:`transform`, a :obj:`pre_transform` and a :obj:`pre_filter` function, which are :obj:`None` by default.
The :obj:`transform` function dynamically transforms the data object before each access (so it is best used for data augmentation).
The :obj:`pre_transform` function applies the transformation before saving the data objects to disk (so it is best used for heavy precomputation which only needs to be done once).
The :obj:`pre_filter` function can manually filter out data objects before saving.
Use cases may involve restricting data objects to a specific class.
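
As a quick illustration, these arguments are passed directly to the dataset constructor.
The following minimal sketch uses the bundled :class:`~torch_geometric.datasets.TUDataset`; the chosen transforms and the filter condition are only examples:

.. code-block:: python

    import torch_geometric.transforms as T
    from torch_geometric.datasets import TUDataset

    dataset = TUDataset(
        root='data/ENZYMES',
        name='ENZYMES',
        transform=T.NormalizeFeatures(),              # Applied on every access.
        pre_transform=T.ToUndirected(),               # Applied once before saving to disk.
        pre_filter=lambda data: data.num_nodes > 10,  # Example: keep only graphs with more than 10 nodes.
    )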

Creating "In Memory Datasets"
-----------------------------

In order to create a :class:`torch_geometric.data.InMemoryDataset`, you need to implement four fundamental methods:

* :func:`torch_geometric.data.InMemoryDataset.raw_file_names`: A list of files in the :obj:`raw_dir` which need to be found in order to skip the download.

* :func:`torch_geometric.data.InMemoryDataset.processed_file_names`: A list of files in the :obj:`processed_dir` which need to be found in order to skip the processing.

* :func:`torch_geometric.data.InMemoryDataset.download`: Downloads raw data into :obj:`raw_dir`.

* :func:`torch_geometric.data.InMemoryDataset.process`: Processes raw data and saves it into the :obj:`processed_dir`.

You can find helpful methods to download and extract data in :mod:`torch_geometric.data`.
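
For example, a :meth:`~torch_geometric.data.InMemoryDataset.download` implementation may combine :func:`~torch_geometric.data.download_url` and :func:`~torch_geometric.data.extract_zip` (a minimal sketch; the class name and URL below are placeholders):

.. code-block:: python

    from torch_geometric.data import InMemoryDataset, download_url, extract_zip


    class MyZipDataset(InMemoryDataset):
        # Hypothetical dataset whose raw data ships as a single zip archive.

        def download(self):
            # Download the archive into `self.raw_dir` and unpack it there.
            path = download_url('https://example.com/my_dataset.zip', self.raw_dir)
            extract_zip(path, self.raw_dir)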

The real magic happens in the body of :meth:`~torch_geometric.data.InMemoryDataset.process`.
Here, we need to read and create a list of :class:`~torch_geometric.data.Data` objects and save it into the :obj:`processed_dir`.
Because saving a huge Python list is quite slow, we collate the list into one huge :class:`~torch_geometric.data.Data` object via :meth:`torch_geometric.data.InMemoryDataset.collate` before saving.
The collated data object concatenates all examples into one big data object and, in addition, returns a :obj:`slices` dictionary to reconstruct single examples from this object.
Finally, we need to load these two objects into the properties :obj:`self.data` and :obj:`self.slices` in the constructor.
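
Written out explicitly, this collation step looks roughly as follows (a minimal sketch with toy data; the note below describes how newer :pyg:`PyG` versions bundle these steps):

.. code-block:: python

    import torch
    from torch_geometric.data import Data, InMemoryDataset

    # A toy list of graphs standing in for your real data:
    data_list = [Data(x=torch.randn(4, 16)), Data(x=torch.randn(7, 16))]

    # Concatenate all examples into one big `Data` object and a `slices`
    # dictionary describing how to recover the individual examples:
    data, slices = InMemoryDataset.collate(data_list)

    torch.save((data, slices), 'data.pt')  # Saving (pre-2.4 style).
    data, slices = torch.load('data.pt')   # Loading into `self.data`/`self.slices` (pre-2.4 style).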

.. note::

    From :pyg:`null` **PyG >= 2.4**, the functionalities of :meth:`torch.save` and :meth:`torch_geometric.data.InMemoryDataset.collate` are unified and implemented behind :meth:`torch_geometric.data.InMemoryDataset.save`.
    Additionally, :obj:`self.data` and :obj:`self.slices` are implicitly loaded via :meth:`torch_geometric.data.InMemoryDataset.load`.

Let's see this process in a simplified example:

.. code-block:: python

    import torch
    from torch_geometric.data import InMemoryDataset, download_url


    class MyOwnDataset(InMemoryDataset):
        def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
            super().__init__(root, transform, pre_transform, pre_filter)
            self.load(self.processed_paths[0])
            # For PyG<2.4:
            # self.data, self.slices = torch.load(self.processed_paths[0])

        @property
        def raw_file_names(self):
            return ['some_file_1', 'some_file_2', ...]

        @property
        def processed_file_names(self):
            return ['data.pt']

        def download(self):
            # Download to `self.raw_dir`.
            download_url(url, self.raw_dir)
            ...

        def process(self):
            # Read data into huge `Data` list.
            data_list = [...]

            if self.pre_filter is not None:
                data_list = [data for data in data_list if self.pre_filter(data)]

            if self.pre_transform is not None:
                data_list = [self.pre_transform(data) for data in data_list]

            self.save(data_list, self.processed_paths[0])
            # For PyG<2.4:
            # torch.save(self.collate(data_list), self.processed_paths[0])
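
Once implemented, the dataset behaves like any other :pyg:`PyG` dataset and can be combined with a data loader (a minimal usage sketch referring to the example class above; the root path is just an example):

.. code-block:: python

    from torch_geometric.loader import DataLoader

    dataset = MyOwnDataset(root='data/my_dataset')
    loader = DataLoader(dataset, batch_size=32, shuffle=True)

    for batch in loader:
        ...  # `batch` holds up to 32 graphs collated into a single object.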

Creating "Larger" Datasets
--------------------------

For creating datasets which do not fit into memory, the :class:`torch_geometric.data.Dataset` class can be used, which closely follows the concepts of the :obj:`torchvision` datasets.
It expects the following methods to be implemented in addition:

* :func:`torch_geometric.data.Dataset.len`: Returns the number of examples in your dataset.

* :func:`torch_geometric.data.Dataset.get`: Implements the logic to load a single graph.

Internally, :meth:`torch_geometric.data.Dataset.__getitem__` gets data objects from :meth:`torch_geometric.data.Dataset.get` and optionally transforms them according to :obj:`transform`.

Let's see this process in a simplified example:

.. code-block:: python

    import os.path as osp

    import torch
    from torch_geometric.data import Data, Dataset, download_url


    class MyOwnDataset(Dataset):
        def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
            super().__init__(root, transform, pre_transform, pre_filter)

        @property
        def raw_file_names(self):
            return ['some_file_1', 'some_file_2', ...]

        @property
        def processed_file_names(self):
            return ['data_1.pt', 'data_2.pt', ...]

        def download(self):
            # Download to `self.raw_dir`.
            path = download_url(url, self.raw_dir)
            ...

        def process(self):
            idx = 0
            for raw_path in self.raw_paths:
                # Read data from `raw_path`.
                data = Data(...)

                if self.pre_filter is not None and not self.pre_filter(data):
                    continue

                if self.pre_transform is not None:
                    data = self.pre_transform(data)

                torch.save(data, osp.join(self.processed_dir, f'data_{idx}.pt'))
                idx += 1

        def len(self):
            return len(self.processed_file_names)

        def get(self, idx):
            data = torch.load(osp.join(self.processed_dir, f'data_{idx}.pt'))
            return data

Here, each graph data object gets saved individually in :meth:`~torch_geometric.data.Dataset.process`, and is manually loaded in :meth:`~torch_geometric.data.Dataset.get`.
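
As a quick check, indexing into the dataset now goes through exactly these methods (a minimal sketch; the root path is just an example):

.. code-block:: python

    dataset = MyOwnDataset(root='data/my_dataset')

    print(len(dataset))  # Calls `len()` and returns the number of saved graphs.
    data = dataset[0]    # Calls `get(0)` and applies `transform` (if one was given).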

Frequently Asked Questions
--------------------------

#. **How can I skip the execution of** :meth:`download` **and/or** :meth:`process` **?**

    You can skip downloading and/or processing by just not overriding the :meth:`download()` and :meth:`process()` methods:

    .. code-block:: python

        class MyOwnDataset(Dataset):
            def __init__(self, transform=None, pre_transform=None):
                super().__init__(None, transform, pre_transform)

#. **Do I really need to use these dataset interfaces?**

    No! Just as in regular :pytorch:`PyTorch`, you do not have to use datasets, *e.g.*, when you want to create synthetic data on the fly without saving them explicitly to disk.
    In this case, simply create a regular Python list holding :class:`torch_geometric.data.Data` objects and pass it to :class:`torch_geometric.loader.DataLoader`:

    .. code-block:: python

        from torch_geometric.data import Data
        from torch_geometric.loader import DataLoader

        data_list = [Data(...), ..., Data(...)]
        loader = DataLoader(data_list, batch_size=32)

Exercises
---------

Consider the following :class:`~torch_geometric.data.InMemoryDataset` constructed from a list of :class:`~torch_geometric.data.Data` objects:

.. code-block:: python

    class MyDataset(InMemoryDataset):
        def __init__(self, root, data_list, transform=None):
            self.data_list = data_list
            super().__init__(root, transform)
            self.load(self.processed_paths[0])

        @property
        def processed_file_names(self):
            return 'data.pt'

        def process(self):
            self.save(self.data_list, self.processed_paths[0])

1. What is the output of :obj:`self.processed_paths[0]`?

2. What does :meth:`~torch_geometric.data.InMemoryDataset.save` do?