File: custom.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (38 lines) | stat: -rw-r--r-- 1,278 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from typing import Any, Optional

from torch_geometric.data import Data
from torch_geometric.datasets.motif_generator import MotifGenerator
from torch_geometric.utils import from_networkx


class CustomMotif(MotifGenerator):
    r"""Generates a motif based on a custom structure coming from a
    :class:`torch_geometric.data.Data` or :class:`networkx.Graph` object.

    A :class:`networkx.Graph` is converted once, at construction time, via
    :meth:`torch_geometric.utils.from_networkx`; a
    :class:`~torch_geometric.data.Data` object is stored as-is.

    Args:
        structure (torch_geometric.data.Data or networkx.Graph): The structure
            to use as a motif.

    Raises:
        ValueError: If :obj:`structure` is neither a
            :class:`~torch_geometric.data.Data` nor a :class:`networkx.Graph`
            object. This also happens for non-:class:`Data` input when
            :obj:`networkx` is not installed.
    """
    def __init__(self, structure: Any):
        super().__init__()

        self.structure: Optional[Data] = None

        if isinstance(structure, Data):
            self.structure = structure
        else:
            # `networkx` is an optional dependency, so only import it when it
            # is actually needed. The `try` body is limited to the import so
            # that an `ImportError` raised while converting the graph is not
            # silently turned into the misleading type error below.
            try:
                import networkx as nx
            except ImportError:
                pass  # Without networkx, fall through to the `ValueError`.
            else:
                if isinstance(structure, nx.Graph):
                    self.structure = from_networkx(structure)

        if self.structure is None:
            raise ValueError(f"Expected a motif structure of type "
                             f"'torch_geometric.data.Data' or "
                             f"'networkx.Graph' (got {type(structure)})")

    def __call__(self) -> Data:
        r"""Returns the motif structure.

        Note that the stored :class:`~torch_geometric.data.Data` object itself
        is returned on every call (not a copy).
        """
        # `__init__` guarantees `self.structure` is set; the assert narrows
        # `Optional[Data]` to `Data` for static type checkers.
        assert isinstance(self.structure, Data)
        return self.structure