1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33
|
from typing import Dict, List, Optional, Tuple
from torch import Tensor
from torch_geometric.data import EdgeAttr, GraphStore
from torch_geometric.typing import EdgeTensorType
class MyGraphStore(GraphStore):
    """Dictionary-backed ``GraphStore`` that keeps edge indices in memory.

    Each edge index is filed under the tuple produced by :meth:`key`, so
    two ``EdgeAttr`` objects address the same entry only when their edge
    type, layout value, sortedness flag and size all compare equal.
    """

    def __init__(self) -> None:
        """Initialise the base class and start with an empty mapping."""
        super().__init__()
        # Maps the tuple built by ``key`` to the edge index stored under it.
        self.store: Dict[Tuple, Tuple[Tensor, Tensor]] = {}

    @staticmethod
    def key(attr: EdgeAttr) -> Tuple:
        """Return the dictionary key that identifies ``attr``.

        The key is ``(edge_type, layout.value, is_sorted, size)``; the
        layout contributes its ``.value`` rather than the layout object.
        """
        edge_type = attr.edge_type
        layout_value = attr.layout.value
        sorted_flag = attr.is_sorted
        size = attr.size
        return (edge_type, layout_value, sorted_flag, size)

    def _put_edge_index(
        self,
        edge_index: EdgeTensorType,
        edge_attr: EdgeAttr,
    ) -> bool:
        """Store ``edge_index`` under the key of ``edge_attr``.

        An entry already present under the same key is overwritten.

        Returns:
            Always ``True``.
        """
        store_key = self.key(edge_attr)
        self.store[store_key] = edge_index
        return True

    def _get_edge_index(self, edge_attr: EdgeAttr) -> Optional[EdgeTensorType]:
        """Return the edge index stored for ``edge_attr``, or ``None``.

        ``None`` is returned when nothing is stored under the key.
        """
        store_key = self.key(edge_attr)
        return self.store.get(store_key)

    def _remove_edge_index(self, edge_attr: EdgeAttr) -> bool:
        """Drop the entry stored for ``edge_attr``, if there is one.

        Returns:
            ``True`` when a non-``None`` entry was removed, else ``False``.
        """
        store_key = self.key(edge_attr)
        removed = self.store.pop(store_key, None)
        return removed is not None

    def get_all_edge_attrs(self) -> List[EdgeAttr]:
        """Rebuild an ``EdgeAttr`` for every key currently in the store."""
        attrs = []
        for stored_key in self.store:
            # Key fields are passed positionally in the order ``key`` emits
            # them; assumes EdgeAttr accepts (edge_type, layout, is_sorted,
            # size) in that order and takes the layout's value -- TODO
            # confirm against the installed torch_geometric version.
            attrs.append(EdgeAttr(*stored_key))
        return attrs
|