1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162
|
from typing import Optional, Tuple
import torch
from .utils import broadcast
def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    """Sum all values of ``src`` that share the same ``index`` along ``dim``.

    ``index`` is first broadcast against ``src`` with ``broadcast``. When
    ``out`` is ``None``, a zero-filled tensor with ``src``'s dtype and device
    is allocated. Its size along ``dim`` is:

    * ``dim_size`` if given,
    * ``0`` if ``index`` is empty,
    * ``index.max() + 1`` otherwise.

    The sums are accumulated in place with ``Tensor.scatter_add_``, so a
    caller-supplied ``out`` is modified and also returned.
    """
    index = broadcast(index, src, dim)

    if out is None:
        shape = list(src.size())
        if dim_size is None:
            # An empty index has no max(); fall back to a zero-length axis.
            dim_size = 0 if index.numel() == 0 else int(index.max()) + 1
        shape[dim] = dim_size
        out = torch.zeros(shape, dtype=src.dtype, device=src.device)

    return out.scatter_add_(dim, index, src)
def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    """Alias of :func:`scatter_sum`; all arguments are forwarded unchanged."""
    return scatter_sum(src, index, dim=dim, out=out, dim_size=dim_size)
def scatter_mul(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    """Multiplicative scatter reduction (the ``reduce="mul"`` case).

    Delegates to the compiled ``torch.ops.torch_scatter.scatter_mul``
    operator. Arguments are passed positionally in the order the operator
    expects.
    """
    mul_op = torch.ops.torch_scatter.scatter_mul
    return mul_op(src, index, dim, out, dim_size)
def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                 out: Optional[torch.Tensor] = None,
                 dim_size: Optional[int] = None) -> torch.Tensor:
    """Average all values of ``src`` that share the same ``index``.

    The values are first summed with :func:`scatter_sum`, which fills
    ``out`` in place when it is given. The sums are then divided by the
    number of ``index`` entries that hit each position.

    Positions that no entry hits get a count of 1, so they are not divided
    by zero. Floating-point outputs use true division; other dtypes use
    floor division. The divided tensor is returned; it is ``out`` itself
    when ``out`` was provided.
    """
    totals = scatter_sum(src, index, dim, out, dim_size)
    bins = totals.size(dim)

    # Resolve a negative dim against src's rank. Then clamp it to index's
    # last axis, so that a lower-rank (broadcast) index is counted along
    # its own axis.
    count_dim = dim + src.dim() if dim < 0 else dim
    count_dim = min(count_dim, index.dim() - 1)

    ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
    counts = scatter_sum(ones, index, count_dim, None, bins)
    counts.masked_fill_(counts < 1, 1)  # avoid dividing by zero
    counts = broadcast(counts, totals, dim)

    if totals.is_floating_point():
        return totals.true_divide_(counts)
    return totals.div_(counts, rounding_mode='floor')
def scatter_min(
        src: torch.Tensor, index: torch.Tensor, dim: int = -1,
        out: Optional[torch.Tensor] = None,
        dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """Minimum scatter reduction (the ``reduce="min"`` case).

    Delegates to the compiled ``torch.ops.torch_scatter.scatter_min``
    operator and returns its two-tensor result unchanged. The first element
    holds the reduced values. The second is presumably the arg-min indices
    (NOTE(review): confirm against the C++ op).
    """
    min_op = torch.ops.torch_scatter.scatter_min
    return min_op(src, index, dim, out, dim_size)
def scatter_max(
        src: torch.Tensor, index: torch.Tensor, dim: int = -1,
        out: Optional[torch.Tensor] = None,
        dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """Maximum scatter reduction (the ``reduce="max"`` case).

    Delegates to the compiled ``torch.ops.torch_scatter.scatter_max``
    operator and returns its two-tensor result unchanged. The first element
    holds the reduced values. The second is presumably the arg-max indices
    (NOTE(review): confirm against the C++ op).
    """
    max_op = torch.ops.torch_scatter.scatter_max
    return max_op(src, index, dim, out, dim_size)
def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
            out: Optional[torch.Tensor] = None, dim_size: Optional[int] = None,
            reduce: str = "sum") -> torch.Tensor:
    r"""
    |

    .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
            master/docs/source/_figures/add.svg?sanitize=true
            :align: center
            :width: 400px

    |

    Reduces all values from the :attr:`src` tensor into :attr:`out` at the
    indices specified in the :attr:`index` tensor along a given axis
    :attr:`dim`.
    For each value in :attr:`src`, its output index is specified by its index
    in :attr:`src` for dimensions outside of :attr:`dim` and by the
    corresponding value in :attr:`index` for dimension :attr:`dim`.
    The applied reduction is defined via the :attr:`reduce` argument.

    Formally, if :attr:`src` and :attr:`index` are :math:`n`-dimensional
    tensors with size :math:`(x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})`
    and :attr:`dim` = `i`, then :attr:`out` must be an :math:`n`-dimensional
    tensor with size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`.
    Moreover, the values of :attr:`index` must be between :math:`0` and
    :math:`y - 1`, although no specific ordering of indices is required.
    The :attr:`index` tensor supports broadcasting in case its dimensions do
    not match with :attr:`src`.

    For one-dimensional tensors with :obj:`reduce="sum"`, the operation
    computes

    .. math::
        \mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j

    where :math:`\sum_j` is over :math:`j` such that
    :math:`\mathrm{index}_j = i`.

    .. note::

        This operation is implemented via atomic operations on the GPU and is
        therefore **non-deterministic** since the order of parallel operations
        to the same value is undetermined.
        For floating-point variables, this results in a source of variance in
        the result.

    :param src: The source tensor.
    :param index: The indices of elements to scatter.
    :param dim: The axis along which to index. (default: :obj:`-1`)
    :param out: The destination tensor.
    :param dim_size: If :attr:`out` is not given, automatically create output
        with size :attr:`dim_size` at dimension :attr:`dim`.
        If :attr:`dim_size` is not given, a minimal sized output tensor
        according to :obj:`index.max() + 1` is returned.
    :param reduce: The reduce operation (:obj:`"sum"` or its alias
        :obj:`"add"`, :obj:`"mul"`, :obj:`"mean"`, :obj:`"min"` or
        :obj:`"max"`). For :obj:`"min"` and :obj:`"max"` only the reduced
        values are returned. (default: :obj:`"sum"`)

    :rtype: :class:`Tensor`

    :raises ValueError: If :attr:`reduce` is not one of the supported
        operations.

    .. code-block:: python

        from torch_scatter import scatter

        src = torch.randn(10, 6, 64)
        index = torch.tensor([0, 1, 0, 1, 2, 1])

        # Broadcasting in the first and last dim.
        out = scatter(src, index, dim=1, reduce="sum")

        print(out.size())

    .. code-block::

        torch.Size([10, 3, 64])
    """
    if reduce in ('sum', 'add'):
        return scatter_sum(src, index, dim, out, dim_size)
    if reduce == 'mul':
        return scatter_mul(src, index, dim, out, dim_size)
    if reduce == 'mean':
        return scatter_mean(src, index, dim, out, dim_size)
    # scatter_min / scatter_max return a pair of tensors; only the first
    # element (the reduced values) is part of this function's contract.
    if reduce == 'min':
        return scatter_min(src, index, dim, out, dim_size)[0]
    if reduce == 'max':
        return scatter_max(src, index, dim, out, dim_size)[0]
    # Previously a bare `raise ValueError`; name the bad value for callers.
    raise ValueError(
        f"Invalid reduce operation '{reduce}'; expected one of "
        f"'sum', 'add', 'mul', 'mean', 'min' or 'max'")
|