File: to_dense.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (68 lines) | stat: -rw-r--r-- 2,456 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from typing import Optional

import torch
from torch import Tensor

from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform


@functional_transform('to_dense')
class ToDense(BaseTransform):
    r"""Converts a sparse adjacency matrix to a dense adjacency matrix with
    shape :obj:`[num_nodes, num_nodes, *]` (functional name: :obj:`to_dense`).

    The edge attributes (or all-ones weights if :obj:`edge_attr` is
    :obj:`None`) are written into :obj:`data.adj`; duplicate entries in
    :obj:`edge_index` are summed, following PyTorch's sparse COO semantics.
    Afterwards :obj:`edge_index` and :obj:`edge_attr` are removed from
    :obj:`data`.

    A boolean node :obj:`mask` of length :obj:`num_nodes` is added, which is
    :obj:`True` for the original nodes and :obj:`False` for padding nodes.
    :obj:`x`, :obj:`pos` and node-level :obj:`y` are zero-padded along their
    first dimension to :obj:`num_nodes` rows. Newly created tensors are
    placed on the same device as :obj:`edge_index`.

    Args:
        num_nodes (int, optional): The number of nodes. If set to :obj:`None`,
            the number of nodes will get automatically inferred. Otherwise,
            it must be at least the number of nodes of every transformed
            graph. (default: :obj:`None`)
    """
    def __init__(self, num_nodes: Optional[int] = None) -> None:
        self.num_nodes = num_nodes

    @staticmethod
    def _pad(src: Tensor, num_nodes: int) -> Tensor:
        r"""Appends zero rows to :obj:`src` along dimension 0 until it has
        :obj:`num_nodes` rows (dtype and device are taken from :obj:`src`).
        """
        pad_size = [num_nodes - src.size(0)] + list(src.size())[1:]
        return torch.cat([src, src.new_zeros(pad_size)], dim=0)

    def forward(self, data: Data) -> Data:
        assert data.edge_index is not None

        orig_num_nodes = data.num_nodes
        assert orig_num_nodes is not None

        if self.num_nodes is None:
            num_nodes = orig_num_nodes
        else:
            assert orig_num_nodes <= self.num_nodes, (
                f"'num_nodes' ({self.num_nodes}) must not be smaller than "
                f"the number of nodes in the graph ({orig_num_nodes})")
            num_nodes = self.num_nodes

        # Tensors created here must share the input graph's device:
        # `torch.sparse_coo_tensor` rejects indices and values that live on
        # different devices (e.g. CUDA `edge_index` with CPU weights).
        device = data.edge_index.device

        if data.edge_attr is None:
            edge_attr = torch.ones(data.edge_index.size(1), dtype=torch.float,
                                   device=device)
        else:
            edge_attr = data.edge_attr

        # Trailing edge-feature dimensions are carried over into `adj`.
        size = torch.Size([num_nodes, num_nodes] + list(edge_attr.size())[1:])
        adj = torch.sparse_coo_tensor(data.edge_index, edge_attr, size)
        data.adj = adj.to_dense()
        data.edge_index = None
        data.edge_attr = None

        data.mask = torch.zeros(num_nodes, dtype=torch.bool, device=device)
        data.mask[:orig_num_nodes] = True

        if data.x is not None:
            data.x = self._pad(data.x, num_nodes)

        if data.pos is not None:
            data.pos = self._pad(data.pos, num_nodes)

        # Only pad `y` when it looks node-level (one row per original node).
        # The `dim() > 0` guard keeps 0-dim graph-level labels from raising
        # on `size(0)`.
        y = data.y
        if (isinstance(y, Tensor) and y.dim() > 0
                and y.size(0) == orig_num_nodes):
            data.y = self._pad(y, num_nodes)

        return data

    def __repr__(self) -> str:
        if self.num_nodes is None:
            return f'{self.__class__.__name__}()'
        return f'{self.__class__.__name__}(num_nodes={self.num_nodes})'