File: grid_graph.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (38 lines) | stat: -rw-r--r-- 1,159 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from typing import Optional

import torch

from torch_geometric.data import Data
from torch_geometric.datasets.graph_generator import GraphGenerator
from torch_geometric.utils import grid


class GridGraph(GraphGenerator):
    r"""Graph generator that builds a two-dimensional grid graph of a fixed
    size every time it is called.
    See :meth:`~torch_geometric.utils.grid` for more information on how the
    grid itself is constructed.

    Args:
        height (int): The height of the grid.
        width (int): The width of the grid.
        dtype (:obj:`torch.dtype`, optional): The data type requested for the
            position tensor of the generated graph; it is forwarded
            unchanged to :meth:`~torch_geometric.utils.grid`.
            (default: :obj:`None`)
    """
    def __init__(
        self,
        height: int,
        width: int,
        dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        # Store the generation parameters; they are read on every call.
        self.height, self.width, self.dtype = height, width, dtype

    def __call__(self) -> Data:
        r"""Builds a new grid graph from the stored parameters.

        Returns:
            :class:`~torch_geometric.data.Data`: An object holding the
            :obj:`edge_index` and :obj:`pos` tensors returned by
            :meth:`~torch_geometric.utils.grid`.
        """
        requested_dtype = self.dtype
        connectivity, coordinates = grid(
            height=self.height,
            width=self.width,
            dtype=requested_dtype,
        )
        return Data(edge_index=connectivity, pos=coordinates)

    def __repr__(self) -> str:
        # Only height and width appear in the repr; dtype is not included.
        cls_name = self.__class__.__name__
        return f'{cls_name}(height={self.height}, width={self.width})'