from typing import Any, Dict, List

import torch
from torch import Tensor

from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import to_networkx


@functional_transform('node_property_split')
class NodePropertySplit(BaseTransform):
    r"""Creates a node-level split with distributional shift based on a given
    node property, as proposed in the `"Evaluating Robustness and Uncertainty
    of Graph Models Under Structural Distributional Shifts"
    <https://arxiv.org/abs/2302.13875>`__ paper
    (functional name: :obj:`node_property_split`).

    It splits the nodes in a given graph into five non-intersecting parts
    based on their structural properties.
    This can be used for transductive node prediction tasks with
    distributional shifts.
    It considers the in-distribution (ID) and out-of-distribution (OOD)
    subsets of nodes.
    The ID subset includes training, validation and testing parts, while
    the OOD subset includes validation and testing parts.
    As a result, it creates five associated node mask vectors for each graph,
    three of which are for the ID nodes (:obj:`id_train_mask`,
    :obj:`id_val_mask`, :obj:`id_test_mask`), and two of which are for the
    OOD nodes (:obj:`ood_val_mask`, :obj:`ood_test_mask`).
    This class implements three particular strategies for inducing
    distributional shifts in a graph, based on **popularity**, **locality**,
    or **density**.

    Args:
        property_name (str): The name of the node property to be used
            (:obj:`"popularity"`, :obj:`"locality"`, :obj:`"density"`).
        ratios ([float]): A list of five ratio values for the ID training,
            ID validation, ID test, OOD validation and OOD test parts.
            The values must sum to :obj:`1.0`.
        ascending (bool, optional): Whether to sort nodes in ascending order
            of the node property, so that nodes with greater values of the
            property are considered to be OOD. (default: :obj:`True`)

    .. code-block:: python

        from torch_geometric.transforms import NodePropertySplit
        from torch_geometric.datasets.graph_generator import ERGraph

        data = ERGraph(num_nodes=1000, edge_prob=0.4)()

        property_name = 'popularity'
        ratios = [0.3, 0.1, 0.1, 0.3, 0.2]
        transform = NodePropertySplit(property_name, ratios)

        data = transform(data)
    """
    def __init__(
        self,
        property_name: str,
        ratios: List[float],
        ascending: bool = True,
    ):
        if property_name not in {'popularity', 'locality', 'density'}:
            raise ValueError(f"Unexpected 'property_name' "
                             f"(got '{property_name}')")

        if len(ratios) != 5:
            raise ValueError(f"'ratios' must contain 5 values "
                             f"(got {len(ratios)})")

        if sum(ratios) != 1.0:
            raise ValueError(f"'ratios' must sum to 1.0 (got {sum(ratios)})")

        self.property_name = property_name
        self.compute_fn = _property_name_to_compute_fn[property_name]
        self.ratios = ratios
        self.ascending = ascending

    def forward(self, data: Data) -> Data:
        # Compute the chosen structural property on the undirected graph and
        # attach the five boolean split masks to the data object:
        G = to_networkx(data, to_undirected=True, remove_self_loops=True)

        property_values = self.compute_fn(G, self.ascending)
        mask_dict = self._mask_nodes_by_property(property_values, self.ratios)

        for key, mask in mask_dict.items():
            data[key] = mask

        return data

    @staticmethod
    def _compute_popularity_property(G: Any, ascending: bool = True) -> Tensor:
        import networkx.algorithms as A

        # Popularity is measured by PageRank centrality; the sign flip
        # controls which end of the property range ends up in the OOD parts:
        property_values = torch.tensor(list(A.pagerank(G).values()))
        property_values *= -1 if ascending else 1
        return property_values

    @staticmethod
    def _compute_locality_property(G: Any, ascending: bool = True) -> Tensor:
        import networkx.algorithms as A

        # Locality is measured by personalized PageRank, with the restart
        # distribution concentrated on the node that has the highest
        # (standard) PageRank score:
        pagerank_values = torch.tensor(list(A.pagerank(G).values()))

        num_nodes = G.number_of_nodes()
        personalization = dict(zip(range(num_nodes), [0.0] * num_nodes))
        personalization[int(pagerank_values.argmax())] = 1.0

        property_values = torch.tensor(
            list(A.pagerank(G, personalization=personalization).values()))
        property_values *= -1 if ascending else 1
        return property_values

    @staticmethod
    def _compute_density_property(G: Any, ascending: bool = True) -> Tensor:
        import networkx.algorithms as A

        # Density is measured by the local clustering coefficient of each
        # node:
        property_values = torch.tensor(list(A.clustering(G).values()))
        property_values *= -1 if ascending else 1
        return property_values

    @staticmethod
    def _mask_nodes_by_property(
        property_values: Tensor,
        ratios: List[float],
    ) -> Dict[str, Tensor]:
        num_nodes = property_values.size(0)

        # Convert the ratios into integer part sizes and absorb any rounding
        # error into the last (OOD test) part:
        sizes = (num_nodes * torch.tensor(ratios)).round().long()
        sizes[-1] -= sizes.sum() - num_nodes

        # Order the nodes by the property value (with random tie-breaking);
        # the first three parts form the ID subset and are re-shuffled so
        # that the ID training, validation and test parts are random subsets
        # of the ID nodes rather than further ordered by the property:
        perm = torch.randperm(num_nodes)
        id_size = int(sizes[:3].sum())
        perm = perm[property_values[perm].argsort()]
        perm[:id_size] = perm[:id_size][torch.randperm(id_size)]

        node_splits = perm.split(sizes.tolist())
        names = [
            'id_train_mask',
            'id_val_mask',
            'id_test_mask',
            'ood_val_mask',
            'ood_test_mask',
        ]

        # Build one boolean mask per part:
        split_masks = {}
        for name, node_split in zip(names, node_splits):
            split_mask = torch.zeros(num_nodes, dtype=torch.bool)
            split_mask[node_split] = True
            split_masks[name] = split_mask

        return split_masks

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self.property_name})'

_property_name_to_compute_fn = {
    'popularity': NodePropertySplit._compute_popularity_property,
    'locality': NodePropertySplit._compute_locality_property,
    'density': NodePropertySplit._compute_density_property,
}
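

if __name__ == '__main__':
    # Minimal usage sketch, mirroring the docstring example above; it is not
    # part of the original module. It assumes that `networkx` is installed
    # (required by `to_networkx`) and that the `ERGraph` graph generator is
    # available in this version of PyTorch Geometric.
    from torch_geometric.datasets.graph_generator import ERGraph

    data = ERGraph(num_nodes=1000, edge_prob=0.4)()

    transform = NodePropertySplit('popularity', [0.3, 0.1, 0.1, 0.3, 0.2])
    data = transform(data)

    # Each mask is a boolean vector over all nodes; together the five masks
    # cover every node exactly once:
    for name in ['id_train_mask', 'id_val_mask', 'id_test_mask',
                 'ood_val_mask', 'ood_test_mask']:
        print(name, int(data[name].sum()))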