File: node_property_split.py

from typing import Any, Dict, List

import torch
from torch import Tensor

from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import to_networkx


@functional_transform('node_property_split')
class NodePropertySplit(BaseTransform):
    r"""Creates a node-level split with distributional shift based on a given
    node property, as proposed in the `"Evaluating Robustness and Uncertainty
    of Graph Models Under Structural Distributional Shifts"
    <https://arxiv.org/abs/2302.13875>`__ paper
    (functional name: :obj:`node_property_split`).

    It splits the nodes in a given graph into five non-intersecting parts
    based on their structural properties.
    This can be used for transductive node prediction tasks with distributional
    shifts.
    It considers the in-distribution (ID) and out-of-distribution (OOD) subsets
    of nodes.
    The ID subset includes training, validation and testing parts, while
    the OOD subset includes validation and testing parts.
    As a result, it creates five associated node mask vectors for each graph:
    three for the ID nodes (:obj:`id_train_mask`, :obj:`id_val_mask`,
    :obj:`id_test_mask`), and two for the OOD nodes (:obj:`ood_val_mask`,
    :obj:`ood_test_mask`).

    This class implements three particular strategies for inducing
    distributional shifts in a graph — based on **popularity**, **locality**
    or **density**.

    Args:
        property_name (str): The name of the node property to be used
            (:obj:`"popularity"`, :obj:`"locality"`, :obj:`"density"`).
        ratios ([float]): A list of five ratio values for ID training,
            ID validation, ID test, OOD validation and OOD test parts.
            The values must sum to :obj:`1.0`.
        ascending (bool, optional): Whether to sort nodes in ascending order
            of the node property, so that nodes with greater values of the
            property are considered to be OOD (default: :obj:`True`)

    .. code-block:: python

        from torch_geometric.transforms import NodePropertySplit
        from torch_geometric.datasets.graph_generator import ERGraph

        data = ERGraph(num_nodes=1000, edge_prob=0.4)()

        property_name = 'popularity'
        ratios = [0.3, 0.1, 0.1, 0.3, 0.2]
        transform = NodePropertySplit(property_name, ratios)

        data = transform(data)
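
        # The transform attaches five boolean node masks, for example:
        train_mask = data['id_train_mask']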
    """
    def __init__(
        self,
        property_name: str,
        ratios: List[float],
        ascending: bool = True,
    ):
        if property_name not in {'popularity', 'locality', 'density'}:
            raise ValueError(f"Unexpected 'property_name' "
                             f"(got '{property_name}')")

        if len(ratios) != 5:
            raise ValueError(f"'ratios' must contain 5 values "
                             f"(got {len(ratios)})")

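        # NOTE: This is an exact floating-point comparison; ratios that only
        # approximately sum to 1.0 may be rejected.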
        if sum(ratios) != 1.0:
            raise ValueError(f"'ratios' must sum to 1.0 (got {sum(ratios)})")

        self.property_name = property_name
        self.compute_fn = _property_name_to_compute_fn[property_name]
        self.ratios = ratios
        self.ascending = ascending

    def forward(self, data: Data) -> Data:
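        # Convert to an undirected NetworkX graph with self-loops removed,
        # compute the selected structural property for every node, and attach
        # the five boolean split masks to the data object.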
        G = to_networkx(data, to_undirected=True, remove_self_loops=True)
        property_values = self.compute_fn(G, self.ascending)
        mask_dict = self._mask_nodes_by_property(property_values, self.ratios)

        for key, mask in mask_dict.items():
            data[key] = mask

        return data

    @staticmethod
    def _compute_popularity_property(G: Any, ascending: bool = True) -> Tensor:
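        # Popularity is measured by (unpersonalized) PageRank.  Negating the
        # values when `ascending` is set flips the ordering used downstream
        # when nodes are sorted by the property.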
        import networkx.algorithms as A

        property_values = torch.tensor(list(A.pagerank(G).values()))
        property_values *= -1 if ascending else 1
        return property_values

    @staticmethod
    def _compute_locality_property(G: Any, ascending: bool = True) -> Tensor:
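        # Locality is measured by personalized PageRank, with the restart
        # (teleport) distribution concentrated on the single node that has the
        # highest unpersonalized PageRank score.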
        import networkx.algorithms as A

        pagerank_values = torch.tensor(list(A.pagerank(G).values()))

        num_nodes = G.number_of_nodes()
        personalization = dict(zip(range(num_nodes), [0.0] * num_nodes))
        personalization[int(pagerank_values.argmax())] = 1.0

        property_values = torch.tensor(
            list(A.pagerank(G, personalization=personalization).values()))
        property_values *= -1 if ascending else 1
        return property_values

    @staticmethod
    def _compute_density_property(G: Any, ascending: bool = True) -> Tensor:
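        # Density is measured by the local clustering coefficient of each node.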
        import networkx.algorithms as A

        property_values = torch.tensor(list(A.clustering(G).values()))
        property_values *= -1 if ascending else 1
        return property_values

    @staticmethod
    def _mask_nodes_by_property(
        property_values: Tensor,
        ratios: List[float],
    ) -> Dict[str, Tensor]:

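        # Target size of each of the five splits, obtained by rounding the
        # ratios; any rounding mismatch is absorbed by the last (OOD test)
        # split so the sizes sum to `num_nodes`.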
        num_nodes = property_values.size(0)
        sizes = (num_nodes * torch.tensor(ratios)).round().long()
        sizes[-1] -= sizes.sum() - num_nodes

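        # Order nodes by the (possibly negated) property value; the first
        # `id_size` nodes form the ID pool and the remaining nodes form the
        # OOD pool.  The initial random permutation breaks ties randomly, and
        # the ID pool is re-shuffled so that its train/val/test assignment is
        # random rather than property-ordered.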
        perm = torch.randperm(num_nodes)
        id_size = int(sizes[:3].sum())
        perm = perm[property_values[perm].argsort()]
        perm[:id_size] = perm[:id_size][torch.randperm(id_size)]

        node_splits = perm.split(sizes.tolist())
        names = [
            'id_train_mask',
            'id_val_mask',
            'id_test_mask',
            'ood_val_mask',
            'ood_test_mask',
        ]

        split_masks = {}
        for name, node_split in zip(names, node_splits):
            split_mask = torch.zeros(num_nodes, dtype=torch.bool)
            split_mask[node_split] = True
            split_masks[name] = split_mask
        return split_masks

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self.property_name})'


_property_name_to_compute_fn = {
    'popularity': NodePropertySplit._compute_popularity_property,
    'locality': NodePropertySplit._compute_locality_property,
    'density': NodePropertySplit._compute_density_property,
}