1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130
|
import os
import os.path as osp
from typing import Callable, List, Optional
from torch_geometric.data import (
Data,
InMemoryDataset,
download_url,
extract_zip,
)
from torch_geometric.io import fs
class NeuroGraphDataset(InMemoryDataset):
    r"""Brain-connectome graph benchmarks introduced in the
    `"NeuroGraph: Benchmarks for Graph Machine Learning in Brain Connectomics"
    <https://arxiv.org/abs/2306.06202>`_ paper.

    :class:`NeuroGraphDataset` bundles five neuroimaging graph learning
    datasets covering several categories of demographics, mental states, and
    cognitive traits. A single dataset is selected per instance via
    :obj:`name`.
    More details can be found in the `documentation
    <https://neurograph.readthedocs.io/en/latest/NeuroGraph.html>`_ and on
    `Github <https://github.com/Anwar-Said/NeuroGraph>`_.

    +--------------------+---------+----------------------+
    | Dataset            | #Graphs | Task                 |
    +====================+=========+======================+
    | :obj:`HCPTask`     | 7,443   | Graph Classification |
    +--------------------+---------+----------------------+
    | :obj:`HCPGender`   | 1,078   | Graph Classification |
    +--------------------+---------+----------------------+
    | :obj:`HCPAge`      | 1,065   | Graph Classification |
    +--------------------+---------+----------------------+
    | :obj:`HCPFI`       | 1,071   | Graph Regression     |
    +--------------------+---------+----------------------+
    | :obj:`HCPWM`       | 1,078   | Graph Regression     |
    +--------------------+---------+----------------------+

    Args:
        root (str): Root directory where the dataset should be saved.
        name (str): The name of the dataset (one of :obj:`"HCPGender"`,
            :obj:`"HCPTask"`, :obj:`"HCPAge"`, :obj:`"HCPFI"`,
            :obj:`"HCPWM"`).
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)
    """
    # Base location of the hosted archives; the per-dataset file name from
    # `filenames` is appended to it in `download()`.
    url = 'https://vanderbilt.box.com/shared/static'

    # Maps every supported dataset name to the zip archive that contains it.
    filenames = {
        'HCPGender': 'r6hlz2arm7yiy6v6981cv2nzq3b0meax.zip',
        'HCPTask': '8wzz4y17wpxg2stip7iybtmymnybwvma.zip',
        'HCPAge': 'lzzks4472czy9f9vc8aikp7pdbknmtfe.zip',
        'HCPWM': 'xtmpa6712fidi94x6kevpsddf9skuoxy.zip',
        'HCPFI': 'g2md9h9snh7jh6eeay02k1kr9m4ido9f.zip',
    }

    def __init__(
        self,
        root: str,
        name: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        force_reload: bool = False,
    ) -> None:
        # Only names with a known archive are accepted.
        assert name in self.filenames

        # `name` has to be stored before the base constructor is invoked: the
        # directory properties below read it, and the base constructor is
        # expected to consult them.
        self.name = name

        super().__init__(
            root,
            transform,
            pre_transform,
            pre_filter,
            force_reload=force_reload,
        )
        self.load(self.processed_paths[0])

    @property
    def raw_dir(self) -> str:
        """Per-dataset directory for raw files: ``<root>/<name>/raw``."""
        dataset_root = osp.join(self.root, self.name)
        return osp.join(dataset_root, 'raw')

    @property
    def raw_file_names(self) -> str:
        """File name of the single raw file that `download()` produces."""
        return 'data.pt'

    @property
    def processed_dir(self) -> str:
        """Per-dataset directory for processed files:
        ``<root>/<name>/processed``.
        """
        dataset_root = osp.join(self.root, self.name)
        return osp.join(dataset_root, 'processed')

    @property
    def processed_file_names(self) -> str:
        """File name of the single processed file that `process()` writes."""
        return 'data.pt'

    def download(self) -> None:
        """Fetch the archive of the selected dataset and normalise its layout.

        The zip file is downloaded and extracted into :obj:`raw_dir`, the
        archive itself is deleted, the bundled
        ``<name>/processed/<name>.pt`` file is moved to ``raw_dir/data.pt``,
        and the remaining extracted ``<name>`` folder is removed.
        """
        raw_root = self.raw_dir
        archive_url = f'{self.url}/{self.filenames[self.name]}'

        archive_path = download_url(archive_url, raw_root)
        extract_zip(archive_path, raw_root)
        os.unlink(archive_path)

        # The archive unpacks into a folder named after the dataset; only the
        # tensor file inside its `processed` sub-folder is kept.
        extracted_root = osp.join(raw_root, self.name)
        bundled_file = osp.join(extracted_root, 'processed',
                                f'{self.name}.pt')
        os.rename(bundled_file, osp.join(raw_root, 'data.pt'))
        fs.rm(extracted_root)

    def process(self) -> None:
        """Rebuild one :obj:`Data` object per graph from the raw file and
        save the resulting list to the processed path.

        The raw file is loaded as a ``(data, slices)`` pair. Consecutive
        entries of ``slices['x']`` and ``slices['edge_index']`` delimit the
        node features and edges of each graph, while the target is read from
        ``data.y`` at the graph's position. Graphs rejected by
        :obj:`pre_filter` are dropped; :obj:`pre_transform` is applied to the
        remaining ones.
        """
        data, slices = fs.torch_load(self.raw_paths[0])

        node_ptr = slices['x']
        edge_ptr = slices['edge_index']
        # `node_ptr` holds one boundary more than there are graphs.
        num_graphs = node_ptr.size(0) - 1

        graphs: List[Data] = []
        for idx in range(num_graphs):
            node_lo, node_hi = node_ptr[idx], node_ptr[idx + 1]
            edge_lo, edge_hi = edge_ptr[idx], edge_ptr[idx + 1]

            graph = Data(
                x=data.x[node_lo:node_hi],
                edge_index=data.edge_index[:, edge_lo:edge_hi],
                y=data.y[idx],
            )

            # Without a filter every graph is kept; otherwise the filter
            # decides.
            keep = self.pre_filter is None or self.pre_filter(graph)
            if not keep:
                continue

            if self.pre_transform is not None:
                graph = self.pre_transform(graph)

            graphs.append(graph)

        self.save(graphs, self.processed_paths[0])
|