torch.nested
============

.. automodule:: torch.nested

Introduction
++++++++++++

.. warning::

  The PyTorch API of nested tensors is in prototype stage and will change in the near future.

NestedTensor allows the user to pack a list of Tensors into a single, efficient data structure.

The only constraint on the input Tensors is that their dimension must match.

This enables more efficient metadata representations and access to purpose-built kernels.

One application of NestedTensors is to express sequential data in various domains.
While the conventional approach is to pad variable-length sequences, NestedTensor
enables users to bypass padding. The API for calling operations on a nested tensor is no different
from that of a regular ``torch.Tensor``, which should allow seamless integration with existing models,
with the main difference being :ref:`construction of the inputs <construction>`.

As this is a prototype feature, the :ref:`operations supported <supported operations>` are still
limited. However, we welcome issues, feature requests and contributions. More information on
contributing can be found `on this wiki <https://github.com/pytorch/pytorch/wiki/NestedTensor-Backend>`_.
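
For instance, padding two variable-length sequences wastes memory and compute on filler
values, which a NestedTensor avoids entirely (a minimal sketch with arbitrary shapes):

>>> seqs = [torch.randn(3, 8), torch.randn(5, 8)]
>>> padded = torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True)  # pads the shorter sequence with zeros
>>> padded.shape
torch.Size([2, 5, 8])
>>> nt = torch.nested.nested_tensor(seqs)  # stores both sequences without padding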

.. _construction:

Construction
++++++++++++

Construction is straightforward and involves passing a list of Tensors to the
``torch.nested.nested_tensor`` constructor.

>>> a, b = torch.arange(3), torch.arange(5) + 3
>>> a
tensor([0, 1, 2])
>>> b
tensor([3, 4, 5, 6, 7])
>>> nt = torch.nested.nested_tensor([a, b])
>>> nt
nested_tensor([
  tensor([0, 1, 2]),
  tensor([3, 4, 5, 6, 7])
])

Data type, device and whether gradients are required can be chosen via the usual keyword arguments.

>>> nt = torch.nested.nested_tensor([a, b], dtype=torch.float32, device="cuda", requires_grad=True)
>>> nt
nested_tensor([
  tensor([0., 1., 2.], device='cuda:0', requires_grad=True),
  tensor([3., 4., 5., 6., 7.], device='cuda:0', requires_grad=True)
], device='cuda:0', requires_grad=True)

In the vein of ``torch.as_tensor``, ``torch.nested.as_nested_tensor`` can be used to preserve autograd
history from the tensors passed to the constructor. For more information, refer to the section on
:ref:`constructor functions`.
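
For instance, here is a minimal sketch (with arbitrary random inputs) of how the autograd
history carries over, so that the resulting nested tensor takes part in the same graph as
the Tensors it was built from:

>>> a = torch.randn(2, 3, requires_grad=True)
>>> b = torch.randn(4, 3, requires_grad=True)
>>> nt = torch.nested.as_nested_tensor([a, b])
>>> nt.requires_grad
True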

In order to form a valid NestedTensor, all the passed Tensors need to match in dimension, but none of the other attributes need to.

>>> a = torch.randn(3, 50, 70) # image 1
>>> b = torch.randn(3, 128, 64) # image 2
>>> nt = torch.nested.nested_tensor([a, b], dtype=torch.float32)
>>> nt.dim()
4

If one of the dimensions doesn't match, the constructor throws an error.

>>> a = torch.randn(50, 128) # text 1
>>> b = torch.randn(3, 128, 64) # image 2
>>> nt = torch.nested.nested_tensor([a, b], dtype=torch.float32)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: All Tensors given to nested_tensor must have the same dimension. Found dimension 3 for Tensor at index 1 and dimension 2 for Tensor at index 0.

Note that the passed Tensors are being copied into a contiguous piece of memory. The resulting
NestedTensor allocates new memory to store them and does not keep a reference.
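
A minimal sketch of these copy semantics: mutating an input after construction leaves the
NestedTensor unchanged (exact repr formatting may vary across versions):

>>> a = torch.tensor([1., 2.])
>>> nt = torch.nested.nested_tensor([a])
>>> a.fill_(0.)
tensor([0., 0.])
>>> nt
nested_tensor([
  tensor([1., 2.])
])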

At this moment we only support one level of nesting, i.e. a simple, flat list of Tensors. In the future
we can add support for multiple levels of nesting, such as a list that consists entirely of lists of Tensors.
Note that for this extension it is important to maintain an even level of nesting across entries so that the resulting NestedTensor
has a well-defined dimension. If you have a need for this feature, please feel encouraged to open a feature request so that
we can track it and plan accordingly.

size
+++++++++++++++++++++++++

Even though a NestedTensor does not support ``.size()`` (or ``.shape``), it supports ``.size(i)`` if dimension ``i`` is regular.

>>> a = torch.randn(50, 128) # text 1
>>> b = torch.randn(32, 128) # text 2
>>> nt = torch.nested.nested_tensor([a, b], dtype=torch.float32)
>>> nt.size(0)
2
>>> nt.size(1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Given dimension 1 is irregular and does not have a size.
>>> nt.size(2)
128

If all dimensions are regular, the NestedTensor is intended to be semantically indistinguishable from a regular ``torch.Tensor``.

>>> a = torch.randn(20, 128) # text 1
>>> nt = torch.nested.nested_tensor([a, a], dtype=torch.float32)
>>> nt.size(0)
2
>>> nt.size(1)
20
>>> nt.size(2)
128
>>> torch.stack(nt.unbind()).size()
torch.Size([2, 20, 128])
>>> torch.stack([a, a]).size()
torch.Size([2, 20, 128])
>>> torch.equal(torch.stack(nt.unbind()), torch.stack([a, a]))
True

In the future we might make it easier to detect this condition and convert seamlessly.
Please open a feature request if you have a need for this (or any other related feature for that matter).

unbind
+++++++++++++++++++++++++

``unbind`` allows you to retrieve a view of the constituents.

>>> import torch
>>> a = torch.randn(2, 3)
>>> b = torch.randn(3, 4)
>>> nt = torch.nested.nested_tensor([a, b], dtype=torch.float32)
>>> nt
nested_tensor([
  tensor([[ 1.2286, -1.2343, -1.4842],
          [-0.7827,  0.6745,  0.0658]]),
  tensor([[-1.1247, -0.4078, -1.0633,  0.8083],
          [-0.2871, -0.2980,  0.5559,  1.9885],
          [ 0.4074,  2.4855,  0.0733,  0.8285]])
])
>>> nt.unbind()
(tensor([[ 1.2286, -1.2343, -1.4842],
        [-0.7827,  0.6745,  0.0658]]), tensor([[-1.1247, -0.4078, -1.0633,  0.8083],
        [-0.2871, -0.2980,  0.5559,  1.9885],
        [ 0.4074,  2.4855,  0.0733,  0.8285]]))
>>> nt.unbind()[0] is not a
True
>>> nt.unbind()[0].mul_(3)
tensor([[ 3.6858, -3.7030, -4.4525],
        [-2.3481,  2.0236,  0.1975]])
>>> nt
nested_tensor([
  tensor([[ 3.6858, -3.7030, -4.4525],
          [-2.3481,  2.0236,  0.1975]]),
  tensor([[-1.1247, -0.4078, -1.0633,  0.8083],
          [-0.2871, -0.2980,  0.5559,  1.9885],
          [ 0.4074,  2.4855,  0.0733,  0.8285]])
])

Note that ``nt.unbind()[0]`` is not a copy, but rather a slice of the underlying memory, which represents the first entry or constituent of the NestedTensor.

.. _constructor functions:

Nested tensor constructor and conversion functions
++++++++++++++++++++++++++++++++++++++++++++++++++

The following functions are related to nested tensors:

.. currentmodule:: torch.nested

.. autofunction:: nested_tensor

.. autofunction:: as_nested_tensor

.. autofunction:: to_padded_tensor
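
As a quick sketch of the conversion path, ``to_padded_tensor`` materializes the ragged
entries into a regular dense tensor, filling the trailing positions with the given padding
value:

>>> a, b = torch.arange(3, dtype=torch.float32), torch.arange(5, dtype=torch.float32)
>>> nt = torch.nested.nested_tensor([a, b])
>>> torch.nested.to_padded_tensor(nt, 0.0)
tensor([[0., 1., 2., 0., 0.],
        [0., 1., 2., 3., 4.]])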

.. _supported operations:

Supported operations
++++++++++++++++++++++++++

In this section, we summarize the operations that are currently supported on
NestedTensor and any constraints they have.

.. csv-table::
   :header: "PyTorch operation", "Constraints"
   :widths: 30, 55
   :delim: ;

   :func:`torch.matmul`; "Supports matrix multiplication between two (>= 3d) nested tensors where
   the last two dimensions are matrix dimensions and the leading (batch) dimensions have the same size
   (i.e. no broadcasting support for batch dimensions yet)."
   :func:`torch.bmm`; "Supports batch matrix multiplication of two 3-d nested tensors."
   :func:`torch.nn.Linear`; "Supports 3-d nested input and a dense 2-d weight matrix."
   :func:`torch.nn.functional.softmax`; "Supports softmax along all dims except ``dim=0``."
   :func:`torch.nn.Dropout`; "Behavior is the same as on regular tensors."
   :func:`torch.relu`; "Behavior is the same as on regular tensors."
   :func:`torch.gelu`; "Behavior is the same as on regular tensors."
   :func:`torch.add`; "Supports elementwise addition of two nested tensors.
   Supports addition of a scalar to a nested tensor."
   :func:`torch.mul`; "Supports elementwise multiplication of two nested tensors.
   Supports multiplication of a nested tensor by a scalar."
   :func:`torch.select`; "Supports selecting along ``dim=0`` only (analogously ``nt[i]``)."
   :func:`torch.clone`; "Behavior is the same as on regular tensors."
   :func:`torch.detach`; "Behavior is the same as on regular tensors."
   :func:`torch.unbind`; "Supports unbinding along ``dim=0`` only."
   :func:`torch.reshape`; "Supports reshaping with the size of ``dim=0`` preserved (i.e. the number of nested tensors cannot be changed).
   Unlike regular tensors, a size of ``-1`` here means that the existing size is inherited.
   In particular, the only valid size for a ragged dimension is ``-1``.
   Size inference is not implemented yet and hence for new dimensions the size cannot be ``-1``."
   :func:`torch.Tensor.reshape_as`; "Similar constraint as for ``reshape``."
   :func:`torch.transpose`; "Supports transposing of all dims except ``dim=0``."
   :func:`torch.Tensor.view`; "Rules for the new shape are similar to that of ``reshape``."
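
To make some of these constraints concrete, here is a small sketch (shapes chosen
arbitrarily; exact coverage may change while the feature is in prototype):

>>> nt = torch.nested.nested_tensor([torch.randn(2, 6), torch.randn(5, 6)])
>>> out = torch.nn.functional.softmax(nt, dim=2)  # allowed: any dim except 0
>>> out.size(0)
2
>>> nt.transpose(1, 2).size(1)  # allowed: dim 0 is untouched
6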