File: mask.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (132 lines) | stat: -rw-r--r-- 4,813 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from typing import List, Optional, Sequence, Union

from torch_geometric.data import Data, HeteroData
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.data.storage import BaseStorage
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import index_to_mask, mask_to_index

# Alias for "either data container type": the same `Union[Data, HeteroData]`
# that the transforms in this file accept and return from `forward`.
AnyData = Union[Data, HeteroData]


def get_attrs_with_suffix(
    attrs: Optional[List[str]],
    store: BaseStorage,
    suffix: str,
) -> List[str]:
    r"""Resolves the attribute names a transform should operate on.

    Args:
        attrs ([str], optional): Explicitly requested attribute names. When
            not :obj:`None`, this list is returned unchanged, without being
            checked against :obj:`store` or :obj:`suffix`.
        store (BaseStorage): The storage whose keys are inspected when
            :obj:`attrs` is :obj:`None`.
        suffix (str): The suffix a key of :obj:`store` must end with in order
            to be selected.

    Returns:
        Either :obj:`attrs` itself, or a new list holding every key of
        :obj:`store` that ends with :obj:`suffix`.
    """
    if attrs is None:
        # Nothing was requested explicitly: discover candidates by suffix.
        return [name for name in store.keys() if name.endswith(suffix)]
    return attrs


def get_mask_size(attr: str, store: BaseStorage, size: Optional[int]) -> int:
    r"""Determines the length of the mask to build for :obj:`attr`.

    Args:
        attr (str): Name of the attribute being converted.
        store (BaseStorage): The storage queried for a fallback size. It is
            only accessed when :obj:`size` is :obj:`None`.
        size (int, optional): An explicit mask size, which takes precedence
            when it is not :obj:`None`.

    Returns:
        :obj:`size` if it is given; otherwise :obj:`store.num_edges` when
        :obj:`store.is_edge_attr(attr)` is true, and :obj:`store.num_nodes`
        in all other cases.
    """
    if size is None:
        # No explicit size: derive it from what kind of attribute this is.
        if store.is_edge_attr(attr):
            return store.num_edges
        return store.num_nodes
    return size


@functional_transform('index_to_mask')
class IndexToMask(BaseTransform):
    r"""Converts indices to a mask representation
    (functional name: :obj:`index_to_mask`).

    The mask computed for an attribute :obj:`<name>_index` is stored as
    :obj:`<name>_mask`. An explicitly given attribute whose name does not end
    in :obj:`_index` keeps its full name and is stored as :obj:`<attr>_mask`.
    Attributes whose name contains :obj:`edge_index`, and attributes that are
    missing from a store, are skipped.

    Args:
        attrs (str, [str], optional): If given, will only perform index to mask
            conversion for the given attributes. If omitted, will infer the
            attributes from the suffix :obj:`_index`. (default: :obj:`None`)
        sizes (int, [int], optional): The size of the mask. If set to
            :obj:`None`, an automatically sized tensor is returned. The number
            of nodes will be used by default, except for edge attributes which
            will use the number of edges as the mask size. A single integer
            is applied to every attribute, while a list or tuple must hold
            exactly one size per attribute. (default: :obj:`None`)
        replace (bool, optional): if set to :obj:`True` replaces the index
            attributes with mask tensors. (default: :obj:`False`)
    """
    def __init__(
        self,
        attrs: Optional[Union[str, List[str]]] = None,
        sizes: Optional[Union[int, List[int]]] = None,
        replace: bool = False,
    ) -> None:
        # A single attribute name is normalized to a one-element list.
        self.attrs = [attrs] if isinstance(attrs, str) else attrs
        self.sizes = sizes
        self.replace = replace

    def forward(
        self,
        data: Union[Data, HeteroData],
    ) -> Union[Data, HeteroData]:
        r"""Adds a mask attribute for every selected index attribute in
        each store of :obj:`data`, and returns the modified :obj:`data`.

        Raises:
            ValueError: If :obj:`sizes` is a list or tuple whose length
                differs from the number of attributes selected for a store.
        """
        suffix = '_index'

        for store in data.stores:
            attrs = get_attrs_with_suffix(self.attrs, store, suffix)

            # Build one (possibly `None`) size per attribute so that both
            # sequences can be walked in lockstep below.
            sizes: Sequence[Optional[int]]
            if isinstance(self.sizes, int):
                sizes = [self.sizes] * len(attrs)
            elif isinstance(self.sizes, (list, tuple)):
                if len(attrs) != len(self.sizes):
                    raise ValueError(
                        f"The number of attributes (got {len(attrs)}) must "
                        f"match the number of sizes provided "
                        f"(got {len(self.sizes)})")
                sizes = self.sizes
            else:
                sizes = [None] * len(attrs)

            for attr, size in zip(attrs, sizes):
                # Anything named like `edge_index` is never converted
                # (presumably because it holds connectivity, not an index
                # set -- confirm against the storage conventions).
                if 'edge_index' in attr:
                    continue
                if attr not in store:
                    continue
                size = get_mask_size(attr, store, size)
                mask = index_to_mask(store[attr], size=size)
                # Strip the suffix only when it is really there. Blindly
                # cutting the last characters would mangle the name of an
                # explicitly given attribute such as `train_idx`.
                if attr.endswith(suffix):
                    stem = attr[:-len(suffix)]
                else:
                    stem = attr
                store[f'{stem}_mask'] = mask
                if self.replace:
                    del store[attr]

        return data

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(attrs={self.attrs}, '
                f'sizes={self.sizes}, replace={self.replace})')


@functional_transform('mask_to_index')
class MaskToIndex(BaseTransform):
    r"""Converts a mask to an index representation
    (functional name: :obj:`mask_to_index`).

    The index computed for an attribute :obj:`<name>_mask` is stored as
    :obj:`<name>_index`. An explicitly given attribute whose name does not end
    in :obj:`_mask` keeps its full name and is stored as :obj:`<attr>_index`.
    Attributes that are missing from a store are skipped.

    Args:
        attrs (str, [str], optional): If given, will only perform mask to index
            conversion for the given attributes.  If omitted, will infer the
            attributes from the suffix :obj:`_mask` (default: :obj:`None`)
        replace (bool, optional): if set to :obj:`True` replaces the mask
            attributes with index tensors. (default: :obj:`False`)
    """
    def __init__(
        self,
        attrs: Optional[Union[str, List[str]]] = None,
        replace: bool = False,
    ) -> None:
        # A single attribute name is normalized to a one-element list.
        self.attrs = [attrs] if isinstance(attrs, str) else attrs
        self.replace = replace

    def forward(
        self,
        data: Union[Data, HeteroData],
    ) -> Union[Data, HeteroData]:
        r"""Adds an index attribute for every selected mask attribute in
        each store of :obj:`data`, and returns the modified :obj:`data`.
        """
        suffix = '_mask'

        for store in data.stores:
            attrs = get_attrs_with_suffix(self.attrs, store, suffix)

            for attr in attrs:
                if attr not in store:
                    continue
                index = mask_to_index(store[attr])
                # Strip the suffix only when it is really there. Blindly
                # cutting the last characters would mangle the name of an
                # explicitly given attribute that lacks the `_mask` suffix.
                if attr.endswith(suffix):
                    stem = attr[:-len(suffix)]
                else:
                    stem = attr
                store[f'{stem}_index'] = index
                if self.replace:
                    del store[attr]

        return data

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(attrs={self.attrs}, '
                f'replace={self.replace})')