File: _select.py

package info (click to toggle)
pytorch-geometric 2.7.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 14,172 kB
  • sloc: python: 144,911; sh: 247; cpp: 27; makefile: 18; javascript: 16
file content (68 lines) | stat: -rw-r--r-- 2,439 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from typing import Any, List, Union

import torch
from torch import Tensor

from torch_geometric.typing import TensorFrame
from torch_geometric.utils.mask import mask_select
from torch_geometric.utils.sparse import is_torch_sparse_tensor


def select(
    src: Union[Tensor, List[Any], TensorFrame],
    index_or_mask: Tensor,
    dim: int,
) -> Union[Tensor, List[Any], TensorFrame]:
    r"""Selects the input tensor, list or tensor frame according to a given
    index or mask vector.

    A :obj:`index_or_mask` tensor of dtype :obj:`torch.bool` is treated as a
    mask (entries where the mask is :obj:`True` are kept); any other dtype
    is treated as a vector of indices.

    Args:
        src (torch.Tensor or list or TensorFrame): The input tensor, list
            (or tuple) or tensor frame.
        index_or_mask (torch.Tensor): The index or mask vector.
        dim (int): The dimension along which to select. Lists, tuples and
            tensor frames only support :obj:`dim=0`.

    Returns:
        The selected entries. Selecting from a tuple returns a list.

    Raises:
        ValueError: If :obj:`dim` is not ``0`` for a list, tuple or tensor
            frame input, or if :obj:`src` is of an unsupported type.
    """
    if isinstance(src, Tensor):
        if index_or_mask.dtype == torch.bool:
            return mask_select(src, dim, index_or_mask)
        return src.index_select(dim, index_or_mask)

    if isinstance(src, (tuple, list)):
        if dim != 0:
            raise ValueError("Cannot select along dimension other than 0")
        # Convert to Python scalars once: iterating a tensor directly yields
        # a separate 0-dim tensor per element (and, for accelerator tensors,
        # can force a host/device sync per element when tested for truth).
        values = index_or_mask.tolist()
        if index_or_mask.dtype == torch.bool:
            return [src[i] for i, m in enumerate(values) if m]
        return [src[i] for i in values]

    if isinstance(src, TensorFrame):
        # Raise instead of `assert` so the check survives `python -O` and
        # matches the list/tuple branch above.
        if dim != 0:
            raise ValueError("Cannot select along dimension other than 0")
        if index_or_mask.dtype == torch.bool:
            return mask_select(src, dim, index_or_mask)
        return src[index_or_mask]

    raise ValueError(f"Encountered invalid input type (got '{type(src)}')")


def narrow(src: Union[Tensor, List[Any]], dim: int, start: int,
           length: int) -> Union[Tensor, List[Any]]:
    r"""Narrows the input tensor or input list to the specified range.

    Follows :meth:`torch.Tensor.narrow`: the result holds the :obj:`length`
    consecutive entries along :obj:`dim` beginning at index :obj:`start`.
    A negative :obj:`start` counts from the end of :obj:`dim` for every
    supported input type. Lists are narrowed via slicing, so a range running
    past the end of the list is silently truncated rather than raising.

    Args:
        src (torch.Tensor or list): The input tensor or list.
        dim (int): The dimension along which to narrow. Lists only support
            :obj:`dim=0`.
        start (int): The index of the first element to keep. May be
            negative, in which case it is counted from the end of
            :obj:`dim`.
        length (int): The number of elements to keep.

    Raises:
        ValueError: If :obj:`src` is a list and :obj:`dim` is not ``0``, or
            if :obj:`src` is of an unsupported type.
    """
    if isinstance(src, Tensor):
        if is_torch_sparse_tensor(src):
            # TODO Sparse tensors in `torch.sparse` do not yet support
            # `narrow`, so emulate it via `index_select` over an explicit
            # index range. Wrap a negative `start` to a non-negative index
            # first so the range matches `Tensor.narrow` semantics.
            if start < 0:
                start += src.size(dim)
            index = torch.arange(start, start + length, device=src.device)
            return src.index_select(dim, index)
        return src.narrow(dim, start, length)

    if isinstance(src, list):
        if dim != 0:
            raise ValueError("Cannot narrow along dimension other than 0")
        # Wrap a negative `start` so that e.g. `start=-2, length=2` yields the
        # last two items (as `Tensor.narrow` does) instead of `src[-2:0]`,
        # which would be empty.
        if start < 0:
            start += len(src)
        return src[start:start + length]

    raise ValueError(f"Encountered invalid input type (got '{type(src)}')")