File: map.py

package info (click to toggle)
pytorch-geometric 2.7.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 14,172 kB
  • sloc: python: 144,911; sh: 247; cpp: 27; makefile: 18; javascript: 16
file content (168 lines) | stat: -rw-r--r-- 6,018 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
from typing import Optional, Tuple, Union

import numpy as np
import torch
from torch import Tensor
from torch.utils.dlpack import from_dlpack

from torch_geometric.warnings import WarningCache

# Module-level cache shared across calls of `map_index`; it is used to emit the
# "CPU-based processing" warning when CUDA inputs are given but `cudf` cannot
# be imported.
# NOTE(review): presumably `WarningCache` de-duplicates repeated warnings so
# the message is shown only once per process — confirm in
# `torch_geometric.warnings`.
_warning_cache = WarningCache()


def map_index(
    src: Tensor,
    index: Tensor,
    max_index: Optional[Union[int, Tensor]] = None,
    inclusive: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
    r"""Maps indices in :obj:`src` to the positional value of their
    corresponding occurrence in :obj:`index`.
    Indices must be non-negative.

    Args:
        src (torch.Tensor): The source tensor to map.
        index (torch.Tensor): The index tensor that denotes the new mapping.
        max_index (int or torch.Tensor, optional): An (inclusive) upper bound
            on the values stored in :obj:`src` and :obj:`index`. If set to
            :obj:`None`, it is computed from the inputs.
            (default: :obj:`None`)
        inclusive (bool, optional): If set to :obj:`True`, it is assumed that
            every entry in :obj:`src` has a valid entry in :obj:`index`.
            Can speed-up computation. (default: :obj:`False`)

    Returns:
        A tuple ``(out, mask)``. ``out`` holds, for every entry of
        :obj:`src` that occurs in :obj:`index`, the position of that
        occurrence in :obj:`index`. ``mask`` is a boolean tensor marking
        which entries of :obj:`src` were found, or :obj:`None` if
        :obj:`inclusive=True`.

    Raises:
        ValueError: If :obj:`src` or :obj:`index` holds floating point
            values, if both tensors live on different devices, or if
            :obj:`inclusive=True` but some entries of :obj:`src` have no
            match in :obj:`index` (only checked on the pandas/cudf code
            paths; the lookup-table path does not validate this).

    :rtype: (:class:`torch.Tensor`, :class:`torch.BoolTensor` or
        :obj:`None`)

    Examples:
        >>> src = torch.tensor([2, 0, 1, 0, 3])
        >>> index = torch.tensor([3, 2, 0, 1])

        >>> map_index(src, index)
        (tensor([1, 2, 3, 2, 0]), tensor([True, True, True, True, True]))

        >>> src = torch.tensor([2, 0, 1, 0, 3])
        >>> index = torch.tensor([3, 2, 0])

        >>> map_index(src, index)
        (tensor([1, 2, 2, 0]), tensor([True, True, False, True, True]))

    .. note::

        If inputs are on GPU and :obj:`cudf` is available, consider using RMM
        for significant speed boosts.
        Proceed with caution as RMM may conflict with other allocators or
        fragments.

        .. code-block:: python

            import rmm
            rmm.reinitialize(pool_allocator=True)
            torch.cuda.memory.change_current_allocator(rmm.rmm_torch_allocator)
    """
    if src.is_floating_point():
        raise ValueError(f"Expected 'src' to be an index (got '{src.dtype}')")
    if index.is_floating_point():
        raise ValueError(f"Expected 'index' to be an index (got "
                         f"'{index.dtype}')")
    if src.device != index.device:
        raise ValueError(f"Both 'src' and 'index' must be on the same device "
                         f"(got '{src.device}' and '{index.device}')")

    if max_index is None:
        # `Tensor.max()` raises on empty tensors, so only reduce over the
        # non-empty inputs (and fall back to `0` if both are empty):
        if src.numel() > 0 and index.numel() > 0:
            max_index = torch.maximum(src.max(), index.max())
        elif src.numel() > 0:
            max_index = src.max()
        elif index.numel() > 0:
            max_index = index.max()
        else:
            max_index = 0

    # If the `max_index` is in a reasonable range, we can accelerate this
    # operation by creating a helper vector to perform the mapping.
    # NOTE This will potentially consume a large chunk of memory
    # (max_index=10 million => ~75MB), so we cap it at a reasonable size:
    THRESHOLD = 40_000_000 if src.is_cuda else 10_000_000
    if max_index <= THRESHOLD:
        if inclusive:
            assoc = src.new_empty(max_index + 1)  # type: ignore
        else:
            assoc = src.new_full((max_index + 1, ), -1)  # type: ignore
        assoc[index] = torch.arange(index.numel(), dtype=src.dtype,
                                    device=src.device)
        out = assoc[src]

        if inclusive:
            return out, None
        else:
            mask = out != -1
            return out[mask], mask

    WITH_CUDF = False
    if src.is_cuda:
        try:
            import cudf
            WITH_CUDF = True
        except ImportError:
            import pandas as pd
            _warning_cache.warn("Using CPU-based processing within "
                                "'map_index' which may cause slowdowns and "
                                "device synchronization. Consider installing "
                                "'cudf' to accelerate computation")
    else:
        import pandas as pd

    if not WITH_CUDF:
        left_ser = pd.Series(src.cpu().numpy(), name='left_ser')
        right_ser = pd.Series(
            index=index.cpu().numpy(),
            data=pd.RangeIndex(0, index.size(0)),
            name='right_ser',
        )

        # A left merge keeps the order of `src`; entries without a match in
        # `index` become `NaN`, which upcasts the result to floating point:
        result = pd.merge(left_ser, right_ser, how='left', left_on='left_ser',
                          right_index=True)
        out_numpy = result['right_ser'].values

        if issubclass(out_numpy.dtype.type, np.floating):
            if inclusive:
                raise ValueError("Found invalid entries in 'src' that do not "
                                 "have a corresponding entry in 'index'. Set "
                                 "`inclusive=False` to ignore these entries.")
            # Resolve missing entries and convert back to integers on the CPU
            # *before* moving to the target device. Transferring the float64
            # positions instead would require a lossy float32 cast on MPS
            # (which lacks float64 support), corrupting positions > 2**24:
            mask_numpy = ~np.isnan(out_numpy)
            out_numpy = out_numpy[mask_numpy].astype(np.int64)
            out = torch.from_numpy(out_numpy).to(device=index.device,
                                                 dtype=index.dtype)
            mask = torch.from_numpy(mask_numpy).to(index.device)
            return out, mask

        out = torch.from_numpy(out_numpy).to(index.device)

        if inclusive:
            return out, None
        else:
            mask = out != -1
            return out[mask], mask

    else:
        left_ser = cudf.Series(src, name='left_ser')
        right_ser = cudf.Series(
            index=index,
            data=cudf.RangeIndex(0, index.size(0)),
            name='right_ser',
        )

        # `sort=True` orders the result by the join key, so the original
        # order of `src` is restored below via an inverse permutation:
        result = cudf.merge(left_ser, right_ser, how='left',
                            left_on='left_ser', right_index=True, sort=True)

        if inclusive:
            try:
                out = from_dlpack(result['right_ser'].to_dlpack())
            except ValueError as e:
                raise ValueError(
                    "Found invalid entries in 'src' that do not "
                    "have a corresponding entry in 'index'. Set "
                    "`inclusive=False` to ignore these entries.") from e
        else:
            out = from_dlpack(result['right_ser'].fillna(-1).to_dlpack())

        out = out[src.argsort().argsort()]  # Restore original order.

        if inclusive:
            return out, None
        else:
            mask = out != -1
            return out[mask], mask