from typing import Optional, Tuple, Union

import torch
from torch import Tensor

from torch_geometric.data import Data, HeteroData
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.data.storage import NodeStorage
from torch_geometric.transforms import BaseTransform


@functional_transform('random_node_split')
class RandomNodeSplit(BaseTransform):
r"""Performs a node-level random split by adding :obj:`train_mask`,
:obj:`val_mask` and :obj:`test_mask` attributes to the
:class:`~torch_geometric.data.Data` or
:class:`~torch_geometric.data.HeteroData` object
(functional name: :obj:`random_node_split`).

    Args:
split (str, optional): The type of dataset split (:obj:`"train_rest"`,
:obj:`"test_rest"`, :obj:`"random"`).
If set to :obj:`"train_rest"`, all nodes except those in the
validation and test sets will be used for training (as in the
`"FastGCN: Fast Learning with Graph Convolutional Networks via
Importance Sampling" <https://arxiv.org/abs/1801.10247>`_ paper).
If set to :obj:`"test_rest"`, all nodes except those in the
training and validation sets will be used for test (as in the
`"Pitfalls of Graph Neural Network Evaluation"
<https://arxiv.org/abs/1811.05868>`_ paper).
If set to :obj:`"random"`, train, validation, and test sets will be
randomly generated, according to :obj:`num_train_per_class`,
:obj:`num_val` and :obj:`num_test` (as in the `"Semi-supervised
Classification with Graph Convolutional Networks"
<https://arxiv.org/abs/1609.02907>`_ paper).
(default: :obj:`"train_rest"`)
num_splits (int, optional): The number of splits to add. If bigger
than :obj:`1`, the shape of masks will be
:obj:`[num_nodes, num_splits]`, and :obj:`[num_nodes]` otherwise.
(default: :obj:`1`)
num_train_per_class (int, optional): The number of training samples
per class in case of :obj:`"test_rest"` and :obj:`"random"` split.
(default: :obj:`20`)
num_val (int or float, optional): The number of validation samples.
If float, it represents the ratio of samples to include in the
validation set. (default: :obj:`500`)
num_test (int or float, optional): The number of test samples in case
of :obj:`"train_rest"` and :obj:`"random"` split. If float, it
represents the ratio of samples to include in the test set.
(default: :obj:`1000`)
key (str, optional): The name of the attribute holding ground-truth
labels. By default, will only add node-level splits for node-level
storages in which :obj:`key` is present. (default: :obj:`"y"`).
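
    Example (a minimal usage sketch; it assumes the transform is importable
    from :obj:`torch_geometric.transforms`):

    .. code-block:: python

        import torch

        from torch_geometric.data import Data
        from torch_geometric.transforms import RandomNodeSplit

        # Toy graph with six labeled nodes. Labels are required, since by
        # default only storages holding a `y` attribute are split:
        data = Data(
            x=torch.randn(6, 3),
            edge_index=torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]]),
            y=torch.tensor([0, 0, 1, 1, 2, 2]),
        )

        transform = RandomNodeSplit(split='train_rest', num_val=2, num_test=2)
        data = transform(data)
        # `data` now holds boolean `train_mask`, `val_mask` and `test_mask`
        # attributes of shape [6] (or [6, num_splits] if `num_splits` > 1).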
"""
def __init__(
self,
split: str = "train_rest",
num_splits: int = 1,
num_train_per_class: int = 20,
num_val: Union[int, float] = 500,
num_test: Union[int, float] = 1000,
key: Optional[str] = "y",
) -> None:
assert split in ['train_rest', 'test_rest', 'random']
self.split = split
self.num_splits = num_splits
self.num_train_per_class = num_train_per_class
self.num_val = num_val
self.num_test = num_test
self.key = key

    def forward(
self,
data: Union[Data, HeteroData],
) -> Union[Data, HeteroData]:
        for store in data.node_stores:
            # Only split node storages that hold the ground-truth label
            # attribute (unless this check is disabled via `key=None`):
            if self.key is not None and not hasattr(store, self.key):
                continue

            # Draw `num_splits` independent splits and stack them along the
            # last dimension; a single split is squeezed back to [num_nodes]:
            train_masks, val_masks, test_masks = zip(
                *[self._split(store) for _ in range(self.num_splits)])

            store.train_mask = torch.stack(train_masks, dim=-1).squeeze(-1)
            store.val_mask = torch.stack(val_masks, dim=-1).squeeze(-1)
            store.test_mask = torch.stack(test_masks, dim=-1).squeeze(-1)

return data

    def _split(self, store: NodeStorage) -> Tuple[Tensor, Tensor, Tensor]:
num_nodes = store.num_nodes
assert num_nodes is not None

        train_mask = torch.zeros(num_nodes, dtype=torch.bool)
        val_mask = torch.zeros(num_nodes, dtype=torch.bool)
        test_mask = torch.zeros(num_nodes, dtype=torch.bool)

        # Resolve float ratios into absolute numbers of nodes:
        if isinstance(self.num_val, float):
            num_val = round(num_nodes * self.num_val)
        else:
            num_val = self.num_val
        if isinstance(self.num_test, float):
            num_test = round(num_nodes * self.num_test)
        else:
            num_test = self.num_test

        if self.split == 'train_rest':
            # Shuffle all nodes, carve out the validation and test sets, and
            # assign every remaining node to the training set:
            perm = torch.randperm(num_nodes)
            val_mask[perm[:num_val]] = True
            test_mask[perm[num_val:num_val + num_test]] = True
            train_mask[perm[num_val + num_test:]] = True
        else:
            assert self.key is not None
            y = getattr(store, self.key)
            num_classes = int(y.max().item()) + 1

            # Sample a fixed number of training nodes per class:
            for c in range(num_classes):
                idx = (y == c).nonzero(as_tuple=False).view(-1)
                idx = idx[torch.randperm(idx.size(0))]
                idx = idx[:self.num_train_per_class]
                train_mask[idx] = True

            # Distribute the remaining (non-training) nodes across the
            # validation and test sets:
            remaining = (~train_mask).nonzero(as_tuple=False).view(-1)
            remaining = remaining[torch.randperm(remaining.size(0))]

            val_mask[remaining[:num_val]] = True

            if self.split == 'test_rest':
                test_mask[remaining[num_val:]] = True
            elif self.split == 'random':
                test_mask[remaining[num_val:num_val + num_test]] = True

return train_mask, val_mask, test_mask

    def __repr__(self) -> str:
return f'{self.__class__.__name__}(split={self.split})'