1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132
|
from typing import List, Optional, Sequence, Union
from torch_geometric.data import Data, HeteroData
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.data.storage import BaseStorage
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import index_to_mask, mask_to_index
# Type alias covering both graph container types (`Data` or `HeteroData`).
# NOTE(review): not referenced by the code visible in this chunk; the
# transforms below spell out `Union[Data, HeteroData]` directly.
AnyData = Union[Data, HeteroData]
def get_attrs_with_suffix(
    attrs: Optional[List[str]],
    store: BaseStorage,
    suffix: str,
) -> List[str]:
    r"""Returns the attribute names a transform should operate on.

    Args:
        attrs ([str], optional): Explicitly requested attribute names. When
            given, this very list object is returned unchanged.
        store (BaseStorage): The storage whose keys are scanned when
            :obj:`attrs` is :obj:`None`.
        suffix (str): The suffix a key must end with to be selected.

    Returns:
        [str]: :obj:`attrs` itself, or the keys of :obj:`store` ending with
        :obj:`suffix` (in iteration order).
    """
    if attrs is None:
        matching: List[str] = []
        for name in store.keys():
            if name.endswith(suffix):
                matching.append(name)
        return matching
    return attrs
def get_mask_size(attr: str, store: BaseStorage, size: Optional[int]) -> int:
    r"""Returns the mask length to use for attribute :obj:`attr`.

    An explicitly provided :obj:`size` always wins. Otherwise the length is
    :obj:`store.num_edges` when :obj:`store.is_edge_attr(attr)` holds, and
    :obj:`store.num_nodes` in every other case.
    """
    if size is None:
        if store.is_edge_attr(attr):
            return store.num_edges
        return store.num_nodes
    return size
@functional_transform('index_to_mask')
class IndexToMask(BaseTransform):
    r"""Converts indices to a mask representation
    (functional name: :obj:`index_to_mask`).

    For each selected attribute :obj:`{name}_index`, the resulting mask is
    stored as :obj:`{name}_mask`. Attributes whose name contains
    :obj:`edge_index` are always skipped.

    Args:
        attrs (str, [str], optional): If given, will only perform index to mask
            conversion for the given attributes. If omitted, will infer the
            attributes from the suffix :obj:`_index`. A given attribute that
            does not end with :obj:`_index` is stored as :obj:`{attr}_mask`.
            (default: :obj:`None`)
        sizes (int, [int], optional): The size of the mask. If set to
            :obj:`None`, the size is taken from the storage: the number of
            nodes is used by default, except for edge attributes which will
            use the number of edges as the mask size. If a list is given, it
            must provide one size per selected attribute of every storage
            (skipped :obj:`edge_index` attributes included).
            (default: :obj:`None`)
        replace (bool, optional): if set to :obj:`True` replaces the index
            attributes with mask tensors. (default: :obj:`False`)
    """
    def __init__(
        self,
        attrs: Optional[Union[str, List[str]]] = None,
        sizes: Optional[Union[int, List[int]]] = None,
        replace: bool = False,
    ) -> None:
        # A single attribute name is normalized to a one-element list.
        self.attrs = [attrs] if isinstance(attrs, str) else attrs
        self.sizes = sizes
        self.replace = replace

    @staticmethod
    def _mask_name(attr: str) -> str:
        r"""Returns the key under which the mask for :obj:`attr` is stored."""
        suffix = '_index'
        # Only strip the suffix when it is actually present: slicing it off
        # blindly mangles other names and collapses short ones to '_mask',
        # making several attributes overwrite each other.
        if attr.endswith(suffix):
            return f'{attr[:-len(suffix)]}_mask'
        return f'{attr}_mask'

    def forward(
        self,
        data: Union[Data, HeteroData],
    ) -> Union[Data, HeteroData]:
        for store in data.stores:
            attrs = get_attrs_with_suffix(self.attrs, store, '_index')

            # Resolve one (possibly `None`) size per selected attribute.
            sizes: Sequence[Optional[int]]
            if isinstance(self.sizes, int):
                sizes = [self.sizes] * len(attrs)
            elif isinstance(self.sizes, (list, tuple)):
                if len(attrs) != len(self.sizes):
                    raise ValueError(
                        f"The number of attributes (got {len(attrs)}) must "
                        f"match the number of sizes provided "
                        f"(got {len(self.sizes)})")
                sizes = self.sizes
            else:
                sizes = [None] * len(attrs)

            for attr, size in zip(attrs, sizes):
                # `edge_index`-like attributes are never converted; note the
                # skip happens after pairing with `sizes`.
                if 'edge_index' in attr:
                    continue
                # Requested attributes may be missing from some storages.
                if attr not in store:
                    continue

                size = get_mask_size(attr, store, size)
                mask = index_to_mask(store[attr], size=size)
                store[self._mask_name(attr)] = mask

                if self.replace:
                    del store[attr]

        return data

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(attrs={self.attrs}, '
                f'sizes={self.sizes}, replace={self.replace})')
@functional_transform('mask_to_index')
class MaskToIndex(BaseTransform):
    r"""Converts a mask to an index representation
    (functional name: :obj:`mask_to_index`).

    For each selected attribute :obj:`{name}_mask`, the resulting indices
    are stored as :obj:`{name}_index`.

    Args:
        attrs (str, [str], optional): If given, will only perform mask to index
            conversion for the given attributes. If omitted, will infer the
            attributes from the suffix :obj:`_mask`. A given attribute that
            does not end with :obj:`_mask` is stored as :obj:`{attr}_index`.
            (default: :obj:`None`)
        replace (bool, optional): if set to :obj:`True` replaces the mask
            attributes with index tensors. (default: :obj:`False`)
    """
    def __init__(
        self,
        attrs: Optional[Union[str, List[str]]] = None,
        replace: bool = False,
    ) -> None:
        # A single attribute name is normalized to a one-element list.
        self.attrs = [attrs] if isinstance(attrs, str) else attrs
        self.replace = replace

    @staticmethod
    def _index_name(attr: str) -> str:
        r"""Returns the key under which the indices for :obj:`attr` are
        stored."""
        suffix = '_mask'
        # Only strip the suffix when it is actually present: slicing it off
        # blindly mangles other names and collapses short ones to '_index',
        # making several attributes overwrite each other.
        if attr.endswith(suffix):
            return f'{attr[:-len(suffix)]}_index'
        return f'{attr}_index'

    def forward(
        self,
        data: Union[Data, HeteroData],
    ) -> Union[Data, HeteroData]:
        for store in data.stores:
            attrs = get_attrs_with_suffix(self.attrs, store, '_mask')
            for attr in attrs:
                # Requested attributes may be missing from some storages.
                if attr not in store:
                    continue

                index = mask_to_index(store[attr])
                store[self._index_name(attr)] = index

                if self.replace:
                    del store[attr]

        return data

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(attrs={self.attrs}, '
                f'replace={self.replace})')
|