from typing import Optional, Tuple, Union
import numpy as np
import torch
from torch import Tensor
from torch.utils.dlpack import from_dlpack
from torch_geometric.warnings import WarningCache
# Module-level cache so the "install cudf" warning below is emitted at most
# once per process instead of on every `map_index` call.
_warning_cache = WarningCache()
def map_index(
    src: Tensor,
    index: Tensor,
    max_index: Optional[Union[int, Tensor]] = None,
    inclusive: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
    r"""Maps indices in :obj:`src` to the positional value of their
    corresponding occurrence in :obj:`index`.
    Indices must be strictly positive.

    Args:
        src (torch.Tensor): The source tensor to map.
        index (torch.Tensor): The index tensor that denotes the new mapping.
        max_index (int, optional): The maximum index value.
            (default :obj:`None`)
        inclusive (bool, optional): If set to :obj:`True`, it is assumed that
            every entry in :obj:`src` has a valid entry in :obj:`index`.
            Can speed-up computation. (default: :obj:`False`)

    :rtype: (:class:`torch.Tensor`, :class:`torch.BoolTensor`)

    Examples:
        >>> src = torch.tensor([2, 0, 1, 0, 3])
        >>> index = torch.tensor([3, 2, 0, 1])
        >>> map_index(src, index)
        (tensor([1, 2, 3, 2, 0]), tensor([True, True, True, True, True]))

        >>> src = torch.tensor([2, 0, 1, 0, 3])
        >>> index = torch.tensor([3, 2, 0])
        >>> map_index(src, index)
        (tensor([1, 2, 2, 0]), tensor([True, True, False, True, True]))

    .. note::

        If inputs are on GPU and :obj:`cudf` is available, consider using RMM
        for significant speed boosts.
        Proceed with caution as RMM may conflict with other allocators or
        fragments.

        .. code-block:: python

            import rmm
            rmm.reinitialize(pool_allocator=True)
            torch.cuda.memory.change_current_allocator(rmm.rmm_torch_allocator)
    """
    if src.is_floating_point():
        raise ValueError(f"Expected 'src' to be an index (got '{src.dtype}')")
    if index.is_floating_point():
        raise ValueError(f"Expected 'index' to be an index (got "
                         f"'{index.dtype}')")
    if src.device != index.device:
        raise ValueError(f"Both 'src' and 'index' must be on the same device "
                         f"(got '{src.device}' and '{index.device}')")

    # Gracefully handle empty inputs: `Tensor.max()` raises on empty tensors,
    # so short-circuit before computing `max_index` below:
    if src.numel() == 0:
        mask: Optional[Tensor] = None
        if not inclusive:
            mask = torch.empty(0, dtype=torch.bool, device=src.device)
        return src.new_empty(0), mask
    if index.numel() == 0:
        if inclusive:
            raise ValueError("Found invalid entries in 'src' that do not have "
                             "a corresponding entry in 'index'. Set "
                             "`inclusive=False` to ignore these entries.")
        # Nothing can be mapped; everything in `src` is invalid:
        return src.new_empty(0), torch.zeros_like(src, dtype=torch.bool)

    if max_index is None:
        max_index = torch.maximum(src.max(), index.max())

    # If the `max_index` is in a reasonable range, we can accelerate this
    # operation by creating a helper vector to perform the mapping.
    # NOTE This will potentially consume a large chunk of memory
    # (max_index=10 million => ~75MB), so we cap it at a reasonable size:
    THRESHOLD = 40_000_000 if src.is_cuda else 10_000_000
    if max_index <= THRESHOLD:
        return _map_index_dense(src, index, int(max_index), inclusive)

    # Otherwise, fall back to a hash-join, via `cudf` on GPU if available,
    # or via `pandas` on CPU:
    if src.is_cuda:
        try:
            import cudf  # noqa: F401
        except ImportError:
            _warning_cache.warn("Using CPU-based processing within "
                                "'map_index' which may cause slowdowns and "
                                "device synchronization. Consider installing "
                                "'cudf' to accelerate computation")
        else:
            return _map_index_cudf(src, index, inclusive)
    return _map_index_pandas(src, index, inclusive)


def _map_index_dense(
    src: Tensor,
    index: Tensor,
    max_index: int,
    inclusive: bool,
) -> Tuple[Tensor, Optional[Tensor]]:
    # Map via a dense association vector of size `max_index + 1` that stores
    # the position of every `index` entry; lookup is then a single gather.
    if inclusive:
        # Every `src` entry is assumed to hit a written slot, so the helper
        # vector can stay uninitialized:
        assoc = src.new_empty(max_index + 1)
    else:
        # `-1` marks entries of `src` without a counterpart in `index`:
        assoc = src.new_full((max_index + 1, ), -1)
    assoc[index] = torch.arange(index.numel(), dtype=src.dtype,
                                device=src.device)
    out = assoc[src]

    if inclusive:
        return out, None
    mask = out != -1
    return out[mask], mask


def _map_index_pandas(
    src: Tensor,
    index: Tensor,
    inclusive: bool,
) -> Tuple[Tensor, Optional[Tensor]]:
    # Map via a CPU-side left join. Missing matches surface as `NaN`, which
    # forces the merged column into a floating point dtype.
    import pandas as pd

    left_ser = pd.Series(src.cpu().numpy(), name='left_ser')
    right_ser = pd.Series(
        index=index.cpu().numpy(),
        data=pd.RangeIndex(0, index.size(0)),
        name='right_ser',
    )
    result = pd.merge(left_ser, right_ser, how='left', left_on='left_ser',
                      right_index=True)

    out_numpy = result['right_ser'].values
    if (index.device.type == 'mps'  # MPS does not support `float64`
            and issubclass(out_numpy.dtype.type, np.floating)):
        out_numpy = out_numpy.astype(np.float32)
    out = torch.from_numpy(out_numpy).to(index.device)

    # A floating point result implies `NaN` entries, i.e. unmatched values:
    if out.is_floating_point() and inclusive:
        raise ValueError("Found invalid entries in 'src' that do not have "
                         "a corresponding entry in 'index'. Set "
                         "`inclusive=False` to ignore these entries.")

    if out.is_floating_point():
        mask = torch.isnan(out).logical_not_()
        out = out[mask].to(index.dtype)
        return out, mask

    if inclusive:
        return out, None
    mask = out != -1
    return out[mask], mask


def _map_index_cudf(
    src: Tensor,
    index: Tensor,
    inclusive: bool,
) -> Tuple[Tensor, Optional[Tensor]]:
    # Map via a GPU-side left join. `sort=True` changes the row order, which
    # is undone afterwards via a double `argsort`.
    import cudf

    left_ser = cudf.Series(src, name='left_ser')
    right_ser = cudf.Series(
        index=index,
        data=cudf.RangeIndex(0, index.size(0)),
        name='right_ser',
    )
    result = cudf.merge(left_ser, right_ser, how='left',
                        left_on='left_ser', right_index=True, sort=True)

    if inclusive:
        try:
            out = from_dlpack(result['right_ser'].to_dlpack())
        except ValueError as e:
            # `to_dlpack` fails on nullable columns, i.e. when `src` holds
            # entries without a counterpart in `index`:
            raise ValueError(
                "Found invalid entries in 'src' that do not "
                "have a corresponding entry in 'index'. Set "
                "`inclusive=False` to ignore these entries.") from e
    else:
        out = from_dlpack(result['right_ser'].fillna(-1).to_dlpack())

    out = out[src.argsort().argsort()]  # Restore original order.

    if inclusive:
        return out, None
    mask = out != -1
    return out[mask], mask