File: tensor_indexing.rst

Tensor Indexing API
===================

Indexing a tensor in the PyTorch C++ API works very similarly to the Python API.
All index types such as ``None`` / ``...`` / integer / boolean / slice / tensor
are available in the C++ API, making translation from Python indexing code to C++
very simple. The main difference is that, instead of the ``[]`` operator used by
the Python API, the C++ API provides the following indexing methods:

- ``torch::Tensor::index`` (`link <https://pytorch.org/cppdocs/api/classat_1_1_tensor.html#_CPPv4NK2at6Tensor5indexE8ArrayRefIN2at8indexing11TensorIndexEE>`_)
- ``torch::Tensor::index_put_`` (`link <https://pytorch.org/cppdocs/api/classat_1_1_tensor.html#_CPPv4N2at6Tensor10index_put_E8ArrayRefIN2at8indexing11TensorIndexEERK6Tensor>`_)
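
A minimal sketch of the two methods side by side, assuming libtorch is available
through ``<torch/torch.h>``; the helper name ``read_then_write_row`` is hypothetical
and used only for illustration. Integer indices need nothing from ``torch::indexing``:

.. code-block:: cpp

   #include <torch/torch.h>

   #include <utility>

   // Hypothetical helper: reads row 1 of a 3x4 tensor with Tensor::index,
   // then overwrites that row in place with Tensor::index_put_.
   std::pair<torch::Tensor, torch::Tensor> read_then_write_row() {
     // Python: t = torch.arange(12).reshape(3, 4)
     torch::Tensor t = torch::arange(12).reshape({3, 4});

     // Python: row = t[1]   (values 4, 5, 6, 7)
     torch::Tensor row = t.index({1});

     // Python: t[1] = 0
     t.index_put_({1}, 0);

     return {t, row};
   }

As in Python, indexing with a single integer yields a view, so ``row`` also reads
as zeros after the ``index_put_`` call.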

Note that index types such as ``None`` / ``Ellipsis`` / ``Slice`` live in the
``torch::indexing`` namespace. It's recommended to put ``using namespace torch::indexing``
before any indexing code so that these index types can be used without qualification.
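
To illustrate the recommendation above, the two hypothetical functions below select
the same first column (Python ``t[:, 0]``); the only difference is whether ``Slice``
is fully qualified or brought into scope. Scoping the using-directive to a function
body, as in the second version, is one way to keep names such as ``None`` out of a
whole translation unit; this is a style choice, not a requirement.

.. code-block:: cpp

   #include <torch/torch.h>

   // Fully qualified: no using-directive required.
   torch::Tensor first_column_qualified(const torch::Tensor& t) {
     // Python: t[:, 0]
     return t.index({torch::indexing::Slice(), 0});
   }

   // Using-directive scoped to this function only.
   torch::Tensor first_column_short(const torch::Tensor& t) {
     using namespace torch::indexing;
     // Python: t[:, 0]
     return t.index({Slice(), 0});
   }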

Here are some examples of translating Python indexing code to C++:

Getter
------

+----------------------------------------------------------+--------------------------------------------------------------------------------------+
| Python                                                   | C++  (assuming ``using namespace torch::indexing``)                                  |
+==========================================================+======================================================================================+
| ``tensor[None]``                                         | ``tensor.index({None})``                                                             |
+----------------------------------------------------------+--------------------------------------------------------------------------------------+
| ``tensor[Ellipsis, ...]``                                | ``tensor.index({Ellipsis, "..."})``                                                  |
+----------------------------------------------------------+--------------------------------------------------------------------------------------+
| ``tensor[1, 2]``                                         | ``tensor.index({1, 2})``                                                             |
+----------------------------------------------------------+--------------------------------------------------------------------------------------+
| ``tensor[True, False]``                                  | ``tensor.index({true, false})``                                                      |
+----------------------------------------------------------+--------------------------------------------------------------------------------------+
| ``tensor[1::2]``                                         | ``tensor.index({Slice(1, None, 2)})``                                                |
+----------------------------------------------------------+--------------------------------------------------------------------------------------+
| ``tensor[torch.tensor([1, 2])]``                         | ``tensor.index({torch::tensor({1, 2})})``                                            |
+----------------------------------------------------------+--------------------------------------------------------------------------------------+
| ``tensor[..., 0, True, 1::2, torch.tensor([1, 2])]``     | ``tensor.index({"...", 0, true, Slice(1, None, 2), torch::tensor({1, 2})})``         |
+----------------------------------------------------------+--------------------------------------------------------------------------------------+
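
The sketch below applies several of the getter translations above to a tensor of
shape ``[2, 3, 4]`` and collects the result shapes. It assumes libtorch via
``<torch/torch.h>``, and ``getter_shapes`` is a hypothetical helper; each comment gives
the corresponding Python expression and the shape it produces:

.. code-block:: cpp

   #include <torch/torch.h>

   #include <cstdint>
   #include <vector>

   using namespace torch::indexing;

   // Hypothetical helper: applies several getter translations to a
   // [2, 3, 4] tensor and returns the shape of each result.
   std::vector<std::vector<int64_t>> getter_shapes() {
     torch::Tensor t = torch::arange(24).reshape({2, 3, 4});
     std::vector<torch::Tensor> results = {
         t.index({None}),                   // t[None]                 -> [1, 2, 3, 4]
         t.index({1, 2}),                   // t[1, 2]                 -> [4]
         t.index({Slice(1, None, 2)}),      // t[1::2]                 -> [1, 3, 4]
         t.index({"...", 0}),               // t[..., 0]               -> [2, 3]
         t.index({torch::tensor({1, 0})}),  // t[torch.tensor([1, 0])] -> [2, 3, 4]
     };
     std::vector<std::vector<int64_t>> shapes;
     for (const torch::Tensor& r : results) {
       shapes.push_back(r.sizes().vec());
     }
     return shapes;
   }

Comparing ``sizes()`` like this is a quick way to check that a translated expression
matches its Python original before comparing values.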

Setter
------

+----------------------------------------------------------+--------------------------------------------------------------------------------------+
| Python                                                   | C++  (assuming ``using namespace torch::indexing``)                                  |
+==========================================================+======================================================================================+
| ``tensor[None] = 1``                                     | ``tensor.index_put_({None}, 1)``                                                     |
+----------------------------------------------------------+--------------------------------------------------------------------------------------+
| ``tensor[Ellipsis, ...] = 1``                            | ``tensor.index_put_({Ellipsis, "..."}, 1)``                                          |
+----------------------------------------------------------+--------------------------------------------------------------------------------------+
| ``tensor[1, 2] = 1``                                     | ``tensor.index_put_({1, 2}, 1)``                                                     |
+----------------------------------------------------------+--------------------------------------------------------------------------------------+
| ``tensor[True, False] = 1``                              | ``tensor.index_put_({true, false}, 1)``                                              |
+----------------------------------------------------------+--------------------------------------------------------------------------------------+
| ``tensor[1::2] = 1``                                     | ``tensor.index_put_({Slice(1, None, 2)}, 1)``                                        |
+----------------------------------------------------------+--------------------------------------------------------------------------------------+
| ``tensor[torch.tensor([1, 2])] = 1``                     | ``tensor.index_put_({torch::tensor({1, 2})}, 1)``                                    |
+----------------------------------------------------------+--------------------------------------------------------------------------------------+
| ``tensor[..., 0, True, 1::2, torch.tensor([1, 2])] = 1`` | ``tensor.index_put_({"...", 0, true, Slice(1, None, 2), torch::tensor({1, 2})}, 1)`` |
+----------------------------------------------------------+--------------------------------------------------------------------------------------+
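
A minimal sketch of the setter translations, assuming libtorch via ``<torch/torch.h>``;
``setter_demo`` is a hypothetical name. It writes a scalar to a single element,
broadcasts a scalar over a slice, and copies a tensor into a row:

.. code-block:: cpp

   #include <torch/torch.h>

   using namespace torch::indexing;

   // Hypothetical helper: applies setter translations to a zero-filled
   // 3x4 int64 tensor and returns the modified tensor.
   torch::Tensor setter_demo() {
     torch::Tensor t = torch::zeros({3, 4}, torch::kLong);

     // Python: t[1, 2] = 1   (scalar value)
     t.index_put_({1, 2}, 1);

     // Python: t[0, 1::2] = 5   (scalar broadcast over a slice)
     t.index_put_({0, Slice(1, None, 2)}, 5);

     // Python: t[2] = torch.tensor([1, 2, 3, 4])   (tensor value)
     t.index_put_({2}, torch::tensor({1, 2, 3, 4}));

     return t;
   }

The returned tensor is ``[[0, 5, 0, 5], [0, 0, 1, 0], [1, 2, 3, 4]]``.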


Translating between Python/C++ index types
------------------------------------------

The one-to-one translation between Python and C++ index types is as follows:

+-------------------------+------------------------------------------------------------------------+
| Python                  | C++ (assuming ``using namespace torch::indexing``)                     |
+=========================+========================================================================+
| ``None``                | ``None``                                                               |
+-------------------------+------------------------------------------------------------------------+
| ``Ellipsis``            | ``Ellipsis``                                                           |
+-------------------------+------------------------------------------------------------------------+
| ``...``                 | ``"..."``                                                              |
+-------------------------+------------------------------------------------------------------------+
| ``123``                 | ``123``                                                                |
+-------------------------+------------------------------------------------------------------------+
| ``True``                | ``true``                                                               |
+-------------------------+------------------------------------------------------------------------+
| ``False``               | ``false``                                                              |
+-------------------------+------------------------------------------------------------------------+
| ``:`` or ``::``         | ``Slice()`` or ``Slice(None, None)`` or ``Slice(None, None, None)``    |
+-------------------------+------------------------------------------------------------------------+
| ``1:`` or ``1::``       | ``Slice(1, None)`` or ``Slice(1, None, None)``                         |
+-------------------------+------------------------------------------------------------------------+
| ``:3`` or ``:3:``       | ``Slice(None, 3)`` or ``Slice(None, 3, None)``                         |
+-------------------------+------------------------------------------------------------------------+
| ``::2``                 | ``Slice(None, None, 2)``                                               |
+-------------------------+------------------------------------------------------------------------+
| ``1:3``                 | ``Slice(1, 3)``                                                        |
+-------------------------+------------------------------------------------------------------------+
| ``1::2``                | ``Slice(1, None, 2)``                                                  |
+-------------------------+------------------------------------------------------------------------+
| ``:3:2``                | ``Slice(None, 3, 2)``                                                  |
+-------------------------+------------------------------------------------------------------------+
| ``1:3:2``               | ``Slice(1, 3, 2)``                                                     |
+-------------------------+------------------------------------------------------------------------+
| ``torch.tensor([1, 2])``| ``torch::tensor({1, 2})``                                              |
+-------------------------+------------------------------------------------------------------------+
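
The slice rows of the table can be checked on a small 1-D tensor. This sketch assumes
libtorch via ``<torch/torch.h>``; ``to_vec`` and ``slice_demo`` are hypothetical helpers,
and ``to_vec`` calls ``contiguous()`` first because a stepped slice is a strided view
whose elements are not adjacent in memory:

.. code-block:: cpp

   #include <torch/torch.h>

   #include <cstdint>
   #include <vector>

   using namespace torch::indexing;

   // Hypothetical helper: copies a 1-D int64 tensor into a std::vector.
   std::vector<int64_t> to_vec(const torch::Tensor& t) {
     torch::Tensor c = t.contiguous();
     const int64_t* p = c.data_ptr<int64_t>();
     return std::vector<int64_t>(p, p + c.numel());
   }

   // Hypothetical helper: applies each Slice form to [0, 1, 2, 3, 4, 5].
   std::vector<std::vector<int64_t>> slice_demo() {
     torch::Tensor t = torch::arange(6);
     return {
         to_vec(t.index({Slice()})),               // t[:]     -> 0 1 2 3 4 5
         to_vec(t.index({Slice(1, None)})),        // t[1:]    -> 1 2 3 4 5
         to_vec(t.index({Slice(None, 3)})),        // t[:3]    -> 0 1 2
         to_vec(t.index({Slice(None, None, 2)})),  // t[::2]   -> 0 2 4
         to_vec(t.index({Slice(1, 3)})),           // t[1:3]   -> 1 2
         to_vec(t.index({Slice(1, 3, 2)})),        // t[1:3:2] -> 1
     };
   }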