File: svd_feature_reduction.py

package info (click to toggle)
pytorch-geometric 2.6.1-7
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 12,904 kB
  • sloc: python: 127,155; sh: 338; cpp: 27; makefile: 18; javascript: 16
file content (30 lines) | stat: -rw-r--r-- 1,010 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch

from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform


@functional_transform('svd_feature_reduction')
class SVDFeatureReduction(BaseTransform):
    r"""Dimensionality reduction of node features via Singular Value
    Decomposition (SVD) (functional name: :obj:`svd_feature_reduction`).

    The node feature matrix :math:`\mathbf{X} \in \mathbb{R}^{N \times F}` is
    factorized as :math:`\mathbf{X} = \mathbf{U} \mathbf{S} \mathbf{V}^{\top}`
    and replaced by :math:`\mathbf{U}_{:, :k} \, \mathrm{diag}(\mathbf{S}_{:k})`
    with :math:`k` = :obj:`out_channels`. This equals the projection of
    :math:`\mathbf{X}` onto its top-:math:`k` right singular vectors.
    If :obj:`data.x` already has at most :obj:`out_channels` columns, it is
    left unchanged.

    .. note::
        Only :math:`\min(N, F)` singular values exist, so when there are
        fewer nodes than :obj:`out_channels` the reduced features have
        :math:`N` (not :obj:`out_channels`) columns.

    Args:
        out_channels (int): The dimensionality of node features after
            reduction.
    """
    def __init__(self, out_channels: int):
        self.out_channels = out_channels

    def forward(self, data: Data) -> Data:
        assert data.x is not None

        if data.x.size(-1) > self.out_channels:
            k = self.out_channels
            # Reduced SVD: `U` has shape [..., N, min(N, F)] instead of the
            # full [..., N, N]. Only the first `k` columns of `U` are kept
            # (and `k < F` here), so the full factor would only waste
            # O(N^2) memory and time, which is prohibitive on large graphs.
            U, S, _ = torch.linalg.svd(data.x, full_matrices=False)
            # Scale each kept column of `U` by its singular value.
            # Broadcasting equals `U_k @ diag(S_k)` without materializing
            # the k x k diagonal matrix.
            data.x = U[..., :k] * S[..., :k].unsqueeze(-2)
        return data

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self.out_channels})'