1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147
|
from typing import Optional
import torch
def radius(
    x: torch.Tensor,
    y: torch.Tensor,
    r: float,
    batch_x: Optional[torch.Tensor] = None,
    batch_y: Optional[torch.Tensor] = None,
    max_num_neighbors: int = 32,
    num_workers: int = 1,
    batch_size: Optional[int] = None,
) -> torch.Tensor:
    r"""Finds for each element in :obj:`y` all points in :obj:`x` within
    distance :obj:`r`.

    Empty inputs short-circuit to an empty :obj:`[2, 0]` index tensor.
    One-dimensional inputs are reshaped into a single feature column before
    being handed to the compiled :obj:`torch_cluster.radius` operator.

    Args:
        x (Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
        y (Tensor): Node feature matrix
            :math:`\mathbf{Y} \in \mathbb{R}^{M \times F}`.
        r (float): The radius.
        batch_x (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node to a specific example. :obj:`batch_x` needs to be sorted.
            Its number of elements must equal the number of rows of
            :obj:`x`. (default: :obj:`None`)
        batch_y (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each
            node to a specific example. :obj:`batch_y` needs to be sorted.
            Its number of elements must equal the number of rows of
            :obj:`y`. (default: :obj:`None`)
        max_num_neighbors (int, optional): The maximum number of neighbors to
            return for each element in :obj:`y`.
            If the number of actual neighbors is greater than
            :obj:`max_num_neighbors`, returned neighbors are picked randomly.
            (default: :obj:`32`)
        num_workers (int): Number of workers to use for computation. Has no
            effect in case :obj:`batch_x` or :obj:`batch_y` is not
            :obj:`None`, or the input lies on the GPU. (default: :obj:`1`)
        batch_size (int, optional): The number of examples :math:`B`.
            Automatically calculated from the largest batch index if not
            given. If it is greater than one, both :obj:`batch_x` and
            :obj:`batch_y` are required. (default: :obj:`None`)

    :rtype: :class:`LongTensor`

    Raises an :class:`AssertionError` if a batch vector does not have one
    entry per row of its point set, if :obj:`batch_size` is not positive, or
    if more than one example is requested without both batch vectors.

    .. code-block:: python

        import torch
        from torch_cluster import radius

        x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
        batch_x = torch.tensor([0, 0, 0, 0])
        y = torch.Tensor([[-1, 0], [1, 0]])
        batch_y = torch.tensor([0, 0])
        assign_index = radius(x, y, 1.5, batch_x, batch_y)
    """
    # Nothing to search: answer directly without calling the compiled op.
    if x.numel() == 0 or y.numel() == 0:
        return torch.empty(2, 0, dtype=torch.long, device=x.device)

    # Treat 1-D input as a column of scalar features and make sure the
    # memory layout is contiguous before it reaches the operator.
    x = x.view(-1, 1) if x.dim() == 1 else x
    y = y.view(-1, 1) if y.dim() == 1 else y
    x, y = x.contiguous(), y.contiguous()

    # These checks only read shape metadata, so they are performed even when
    # the caller supplies `batch_size`; a mismatched batch vector would
    # otherwise yield pointers that do not describe the point sets.
    if batch_x is not None:
        assert x.size(0) == batch_x.numel(), \
            "batch_x must contain one entry per row of x"
    if batch_y is not None:
        assert y.size(0) == batch_y.numel(), \
            "batch_y must contain one entry per row of y"

    # Infer the number of examples from the largest batch index only when
    # the caller did not provide it.
    if batch_size is None:
        batch_size = 1
        if batch_x is not None:
            batch_size = int(batch_x.max()) + 1
        if batch_y is not None:
            batch_size = max(batch_size, int(batch_y.max()) + 1)
    assert batch_size > 0, "batch_size must be positive"

    ptr_x: Optional[torch.Tensor] = None
    ptr_y: Optional[torch.Tensor] = None

    if batch_size > 1:
        assert batch_x is not None, \
            "batch_x is required when batch_size > 1"
        assert batch_y is not None, \
            "batch_y is required when batch_size > 1"
        # Bucketizing 0..batch_size against the sorted batch vectors gives,
        # for every example, the offset of its first element (plus the
        # total count as the final entry).
        arange = torch.arange(batch_size + 1, device=x.device)
        ptr_x = torch.bucketize(arange, batch_x)
        ptr_y = torch.bucketize(arange, batch_y)

    return torch.ops.torch_cluster.radius(x, y, ptr_x, ptr_y, r,
                                          max_num_neighbors, num_workers)
def radius_graph(
    x: torch.Tensor,
    r: float,
    batch: Optional[torch.Tensor] = None,
    loop: bool = False,
    max_num_neighbors: int = 32,
    flow: str = 'source_to_target',
    num_workers: int = 1,
    batch_size: Optional[int] = None,
) -> torch.Tensor:
    r"""Builds the edges that connect every point to the points lying within
    a given distance of it.

    The points in :obj:`x` are searched against themselves with
    :func:`radius`; the resulting index pairs are arranged according to
    :obj:`flow` and, unless :obj:`loop` is set, pairs that link a point to
    itself are removed.

    Args:
        x (Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
        r (float): The radius.
        batch (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node to a specific example. :obj:`batch` needs to be sorted.
            (default: :obj:`None`)
        loop (bool, optional): If :obj:`True`, the graph will contain
            self-loops. (default: :obj:`False`)
        max_num_neighbors (int, optional): The maximum number of neighbors to
            return for each element.
            If the number of actual neighbors is greater than
            :obj:`max_num_neighbors`, returned neighbors are picked randomly.
            (default: :obj:`32`)
        flow (string, optional): The flow direction when used in combination
            with message passing (:obj:`"source_to_target"` or
            :obj:`"target_to_source"`). (default: :obj:`"source_to_target"`)
        num_workers (int): Number of workers to use for computation. Has no
            effect in case :obj:`batch` is not :obj:`None`, or the input lies
            on the GPU. (default: :obj:`1`)
        batch_size (int, optional): The number of examples :math:`B`.
            Automatically calculated if not given. (default: :obj:`None`)

    :rtype: :class:`LongTensor`

    .. code-block:: python

        import torch
        from torch_cluster import radius_graph

        x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
        batch = torch.tensor([0, 0, 0, 0])
        edge_index = radius_graph(x, r=1.5, batch=batch, loop=False)
    """
    assert flow in ['source_to_target', 'target_to_source']

    # Without self-loops one additional neighbor is requested, presumably so
    # that the self-match filtered out below does not consume a slot.
    if loop:
        neighbor_budget = max_num_neighbors
    else:
        neighbor_budget = max_num_neighbors + 1

    pairs = radius(x, x, r, batch, batch, neighbor_budget, num_workers,
                   batch_size)

    # 'source_to_target' swaps the two rows returned by `radius`;
    # 'target_to_source' keeps them in their original order.
    if flow == 'source_to_target':
        src, dst = pairs[1], pairs[0]
    else:
        src, dst = pairs[0], pairs[1]

    if not loop:
        # Drop every pair that connects a point to itself.
        keep = src != dst
        src, dst = src[keep], dst[keep]

    return torch.stack([src, dst], dim=0)
|