Tensor Indexing API
===================
Indexing a tensor in the PyTorch C++ API works very similarly to the Python API.
All index types such as ``None`` / ``...`` / integer / boolean / slice / tensor
are available in the C++ API, which makes translating Python indexing code to C++
straightforward. The main difference is that, instead of the ``[]`` operator used
in Python, the C++ API provides two indexing methods:

- ``torch::Tensor::index`` (`link <https://pytorch.org/cppdocs/api/classat_1_1_tensor.html#_CPPv4NK2at6Tensor5indexE8ArrayRefIN2at8indexing11TensorIndexEE>`__), which reads from a tensor
- ``torch::Tensor::index_put_`` (`link <https://pytorch.org/cppdocs/api/classat_1_1_tensor.html#_CPPv4N2at6Tensor10index_put_E8ArrayRefIN2at8indexing11TensorIndexEERK6Tensor>`__), which writes into a tensor in place

Note that the index types ``None``, ``Ellipsis``, and ``Slice`` live in the
``torch::indexing`` namespace. It is recommended to put ``using namespace torch::indexing``
before any indexing code so that these types can be used without the ``torch::indexing::`` prefix.

Here are some examples of translating Python indexing code to C++:

Getter
------

+----------------------------------------------------------+--------------------------------------------------------------------------------------+
| Python                                                   | C++ (assuming ``using namespace torch::indexing``)                                   |
+==========================================================+======================================================================================+
| ``tensor[None]``                                         | ``tensor.index({None})``                                                             |
+----------------------------------------------------------+--------------------------------------------------------------------------------------+
| ``tensor[Ellipsis, ...]``                                | ``tensor.index({Ellipsis, "..."})``                                                  |
+----------------------------------------------------------+--------------------------------------------------------------------------------------+
| ``tensor[1, 2]``                                         | ``tensor.index({1, 2})``                                                             |
+----------------------------------------------------------+--------------------------------------------------------------------------------------+
| ``tensor[True, False]``                                  | ``tensor.index({true, false})``                                                      |
+----------------------------------------------------------+--------------------------------------------------------------------------------------+
| ``tensor[1::2]``                                         | ``tensor.index({Slice(1, None, 2)})``                                                |
+----------------------------------------------------------+--------------------------------------------------------------------------------------+
| ``tensor[torch.tensor([1, 2])]``                         | ``tensor.index({torch::tensor({1, 2})})``                                            |
+----------------------------------------------------------+--------------------------------------------------------------------------------------+
| ``tensor[..., 0, True, 1::2, torch.tensor([1, 2])]``     | ``tensor.index({"...", 0, true, Slice(1, None, 2), torch::tensor({1, 2})})``         |
+----------------------------------------------------------+--------------------------------------------------------------------------------------+


Setter
------

+----------------------------------------------------------+--------------------------------------------------------------------------------------+
| Python                                                   | C++ (assuming ``using namespace torch::indexing``)                                   |
+==========================================================+======================================================================================+
| ``tensor[None] = 1``                                     | ``tensor.index_put_({None}, 1)``                                                     |
+----------------------------------------------------------+--------------------------------------------------------------------------------------+
| ``tensor[Ellipsis, ...] = 1``                            | ``tensor.index_put_({Ellipsis, "..."}, 1)``                                          |
+----------------------------------------------------------+--------------------------------------------------------------------------------------+
| ``tensor[1, 2] = 1``                                     | ``tensor.index_put_({1, 2}, 1)``                                                     |
+----------------------------------------------------------+--------------------------------------------------------------------------------------+
| ``tensor[True, False] = 1``                              | ``tensor.index_put_({true, false}, 1)``                                              |
+----------------------------------------------------------+--------------------------------------------------------------------------------------+
| ``tensor[1::2] = 1``                                     | ``tensor.index_put_({Slice(1, None, 2)}, 1)``                                        |
+----------------------------------------------------------+--------------------------------------------------------------------------------------+
| ``tensor[torch.tensor([1, 2])] = 1``                     | ``tensor.index_put_({torch::tensor({1, 2})}, 1)``                                    |
+----------------------------------------------------------+--------------------------------------------------------------------------------------+
| ``tensor[..., 0, True, 1::2, torch.tensor([1, 2])] = 1`` | ``tensor.index_put_({"...", 0, true, Slice(1, None, 2), torch::tensor({1, 2})}, 1)`` |
+----------------------------------------------------------+--------------------------------------------------------------------------------------+


Translating between Python/C++ index types
------------------------------------------

The one-to-one translation between Python and C++ index types is as follows:

+-------------------------+------------------------------------------------------------------------+
| Python                  | C++ (assuming ``using namespace torch::indexing``)                     |
+=========================+========================================================================+
| ``None``                | ``None``                                                               |
+-------------------------+------------------------------------------------------------------------+
| ``Ellipsis``            | ``Ellipsis``                                                           |
+-------------------------+------------------------------------------------------------------------+
| ``...``                 | ``"..."``                                                              |
+-------------------------+------------------------------------------------------------------------+
| ``123``                 | ``123``                                                                |
+-------------------------+------------------------------------------------------------------------+
| ``True``                | ``true``                                                               |
+-------------------------+------------------------------------------------------------------------+
| ``False``               | ``false``                                                              |
+-------------------------+------------------------------------------------------------------------+
| ``:`` or ``::``         | ``Slice()`` or ``Slice(None, None)`` or ``Slice(None, None, None)``    |
+-------------------------+------------------------------------------------------------------------+
| ``1:`` or ``1::``       | ``Slice(1, None)`` or ``Slice(1, None, None)``                         |
+-------------------------+------------------------------------------------------------------------+
| ``:3`` or ``:3:``       | ``Slice(None, 3)`` or ``Slice(None, 3, None)``                         |
+-------------------------+------------------------------------------------------------------------+
| ``::2``                 | ``Slice(None, None, 2)``                                               |
+-------------------------+------------------------------------------------------------------------+
| ``1:3``                 | ``Slice(1, 3)``                                                        |
+-------------------------+------------------------------------------------------------------------+
| ``1::2``                | ``Slice(1, None, 2)``                                                  |
+-------------------------+------------------------------------------------------------------------+
| ``:3:2``                | ``Slice(None, 3, 2)``                                                  |
+-------------------------+------------------------------------------------------------------------+
| ``1:3:2``               | ``Slice(1, 3, 2)``                                                     |
+-------------------------+------------------------------------------------------------------------+
| ``torch.tensor([1, 2])``| ``torch::tensor({1, 2})``                                              |
+-------------------------+------------------------------------------------------------------------+
