File: autograd.rst

.. role:: hidden
    :class: hidden-section

Automatic differentiation package - torch.autograd
==================================================

.. automodule:: torch.autograd
.. currentmodule:: torch.autograd

.. autosummary::
    :toctree: generated
    :nosignatures:

    backward
    grad

.. _forward-mode-ad:

Forward-mode Automatic Differentiation
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. warning::
    This API is in beta. Even though the function signatures are very unlikely to change, improved
    operator coverage is planned before we consider this stable.

Please see the `forward-mode AD tutorial <https://pytorch.org/tutorials/intermediate/forward_ad_usage.html>`__
for detailed steps on how to use this API.
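
As a quick illustration, here is a minimal sketch of the dual-number workflow
described in that tutorial::

    import torch
    import torch.autograd.forward_ad as fwAD

    primal = torch.randn(3)
    tangent = torch.randn(3)

    with fwAD.dual_level():
        dual_input = fwAD.make_dual(primal, tangent)
        dual_output = dual_input.sin()
        # ``unpack_dual`` returns a namedtuple; its ``tangent`` field holds the
        # Jacobian-vector product of sin at ``primal`` along ``tangent``
        jvp = fwAD.unpack_dual(dual_output).tangent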

.. autosummary::
    :toctree: generated
    :nosignatures:

    forward_ad.dual_level
    forward_ad.make_dual
    forward_ad.unpack_dual

.. _functional-api:

Functional higher level API
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. warning::
    This API is in beta. Even though the function signatures are very unlikely to change, major
    performance improvements are planned before we consider this stable.

This section contains the higher-level API for autograd that builds on the basic API above
and allows you to compute Jacobians, Hessians, etc.

This API works with user-provided functions that take only Tensors as input and return
only Tensors.
If your function takes other arguments that are not Tensors, or Tensors that don't have ``requires_grad`` set,
you can use a lambda to capture them.
For example, consider a function ``f`` that takes three inputs: a Tensor for which we want the Jacobian,
another Tensor that should be treated as a constant, and a boolean flag, called as ``f(input, constant, flag=flag)``.
You can then use ``functional.jacobian(lambda x: f(x, constant, flag=flag), input)``.
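
A minimal sketch of this pattern (the function ``f`` below is hypothetical)::

    import torch
    from torch.autograd import functional

    def f(x, constant, flag=True):
        # only ``x`` should be differentiated; the other arguments are captured
        return (x * constant).sin() if flag else x * constant

    input = torch.randn(3)
    constant = torch.randn(3)

    # wrap ``f`` in a lambda so that ``jacobian`` only sees the Tensor input
    jac = functional.jacobian(lambda x: f(x, constant, flag=True), input)
    print(jac.shape)  # torch.Size([3, 3])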

.. autosummary::
    :toctree: generated
    :nosignatures:

    functional.jacobian
    functional.hessian
    functional.vjp
    functional.jvp
    functional.vhp
    functional.hvp

.. _locally-disable-grad:

Locally disabling gradient computation
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

See :ref:`locally-disable-grad-doc` for more information on the differences
between no-grad and inference mode as well as other related mechanisms that
may be confused with the two. Also see :ref:`torch-rst-local-disable-grad`
for a list of functions that can be used to locally disable gradients.
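
For instance, both context managers can be used to locally disable gradient
tracking (a minimal sketch; see the linked sections for how they differ)::

    import torch

    x = torch.randn(3, requires_grad=True)

    with torch.no_grad():
        y = x * 2            # y.requires_grad is False; no graph is recorded

    with torch.inference_mode():
        z = x * 2            # like no_grad, but also skips view and version
                             # counter tracking for extra performance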

.. _default-grad-layouts:

Default gradient layouts
^^^^^^^^^^^^^^^^^^^^^^^^

When a non-sparse ``param`` receives a non-sparse gradient during
:func:`torch.autograd.backward` or :func:`torch.Tensor.backward`,
``param.grad`` is accumulated as follows.

If ``param.grad`` is initially ``None``:

1. If ``param``'s memory is non-overlapping and dense, ``.grad`` is
   created with strides matching ``param`` (thus matching ``param``'s
   layout).
2. Otherwise, ``.grad`` is created with rowmajor-contiguous strides.

If ``param`` already has a non-sparse ``.grad`` attribute:

3. If ``create_graph=False``, ``backward()`` accumulates into ``.grad``
   in-place, which preserves its strides.
4. If ``create_graph=True``, ``backward()`` replaces ``.grad`` with a
   new tensor ``.grad + new grad``, which attempts (but does not guarantee)
   matching the preexisting ``.grad``'s strides.

The default behavior (letting ``.grad``\ s be ``None`` before the first
``backward()``, such that their layout is created according to 1 or 2,
and retained over time according to 3 or 4) is recommended for best performance.
Calls to ``model.zero_grad()`` or ``optimizer.zero_grad()`` will not affect ``.grad``
layouts.

In fact, resetting all ``.grad``\ s to ``None`` before each
accumulation phase, e.g.::

    for _ in range(iterations):
        ...
        for param in model.parameters():
            param.grad = None
        loss.backward()

such that they're recreated according to 1 or 2 every time,
is a valid alternative to ``model.zero_grad()`` or ``optimizer.zero_grad()``
that may improve performance for some networks.

Manual gradient layouts
-----------------------

If you need manual control over ``.grad``'s strides,
assign ``param.grad =`` a zeroed tensor with desired strides
before the first ``backward()``, and never reset it to ``None``.
Rule 3 guarantees that your layout is preserved as long as ``create_graph=False``.
Rule 4 indicates that your layout is *likely* preserved even if ``create_graph=True``.
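
A minimal sketch of this pattern (the column-major layout below is just an
illustrative choice)::

    import torch

    param = torch.nn.Parameter(torch.randn(4, 4))
    # pre-assign a zeroed gradient with column-major strides before the
    # first backward(); rule 3 then preserves this layout when accumulating
    param.grad = torch.zeros(4, 4).t().contiguous().t()

    loss = (param * 2).sum()
    loss.backward()
    print(param.grad.stride())  # (1, 4)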

In-place operations on Tensors
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Supporting in-place operations in autograd is difficult, and we discourage
their use in most cases. Autograd's aggressive buffer freeing and reuse makes
it very efficient, and there are very few occasions when in-place operations
actually lower memory usage by any significant amount. Unless you're operating
under heavy memory pressure, you might never need to use them.

In-place correctness checks
---------------------------

All :class:`Tensor`\ s keep track of in-place operations applied to them, and
if the implementation detects that a tensor was saved for backward in one of
the functions, but was modified in-place afterwards, an error will be raised
once the backward pass is started. This ensures that if you're using in-place
functions and not seeing any errors, you can be sure that the computed
gradients are correct.
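
For example, this minimal sketch triggers the check, because ``exp`` saves its
result for the backward pass::

    import torch

    a = torch.randn(3, requires_grad=True)
    b = a.exp()          # exp saves its output for use during backward
    b.add_(1)            # in-place modification of the saved tensor
    b.sum().backward()   # raises a RuntimeError about the in-place change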

Variable (deprecated)
^^^^^^^^^^^^^^^^^^^^^

.. warning::
    The Variable API has been deprecated: Variables are no longer necessary to
    use autograd with tensors. Autograd automatically supports Tensors with
    ``requires_grad`` set to ``True``. Below please find a quick guide on what
    has changed:

    - ``Variable(tensor)`` and ``Variable(tensor, requires_grad)`` still work as expected,
      but they return Tensors instead of Variables.
    - ``var.data`` is the same thing as ``tensor.data``.
    - Methods such as ``var.backward()``, ``var.detach()``, and ``var.register_hook()`` now work on tensors
      with the same method names.

    In addition, one can now create tensors with ``requires_grad=True`` using factory
    methods such as :func:`torch.randn`, :func:`torch.zeros`, :func:`torch.ones`, and others
    like the following:

    ``autograd_tensor = torch.randn((2, 3, 4), requires_grad=True)``

Tensor autograd functions
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autosummary::
    :nosignatures:

    torch.Tensor.grad
    torch.Tensor.requires_grad
    torch.Tensor.is_leaf
    torch.Tensor.backward
    torch.Tensor.detach
    torch.Tensor.detach_
    torch.Tensor.register_hook
    torch.Tensor.retain_grad

:hidden:`Function`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: Function

.. autosummary::
    :toctree: generated
    :nosignatures:

    Function.forward
    Function.backward
    Function.jvp

Context method mixins
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
When creating a new :class:`Function`, the following methods are available to ``ctx``.

.. autosummary::
    :toctree: generated
    :nosignatures:

    function.FunctionCtx.mark_dirty
    function.FunctionCtx.mark_non_differentiable
    function.FunctionCtx.save_for_backward
    function.FunctionCtx.set_materialize_grads
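
A minimal sketch of a custom :class:`Function` that uses these ``ctx`` helpers::

    import torch

    class Exp(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input):
            result = input.exp()
            ctx.save_for_backward(result)   # stash tensors needed by backward
            return result

        @staticmethod
        def backward(ctx, grad_output):
            (result,) = ctx.saved_tensors
            return grad_output * result

    x = torch.randn(3, requires_grad=True)
    y = Exp.apply(x)
    y.sum().backward()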

.. _grad-check:

Numerical gradient checking
^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autosummary::
    :toctree: generated
    :nosignatures:

    gradcheck
    gradgradcheck
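
A minimal sketch of how :func:`gradcheck` is typically invoked; double-precision
inputs keep the finite-difference comparison within tolerance::

    import torch
    from torch.autograd import gradcheck

    inputs = torch.randn(4, 4, dtype=torch.double, requires_grad=True)
    # compares the analytical gradients of sigmoid against finite differences
    assert gradcheck(torch.sigmoid, (inputs,), eps=1e-6, atol=1e-4)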

Profiler
^^^^^^^^

Autograd includes a profiler that lets you inspect the cost of different
operators inside your model, both on the CPU and GPU. There are three modes
implemented at the moment: CPU-only using :class:`~torch.autograd.profiler.profile`,
nvprof-based (registering both CPU and GPU activity) using
:class:`~torch.autograd.profiler.emit_nvtx`,
and VTune-based using
:class:`~torch.autograd.profiler.emit_itt`.
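
A minimal sketch of the CPU-only mode::

    import torch
    from torch.autograd import profiler

    x = torch.randn(100, 100)
    with profiler.profile() as prof:
        for _ in range(10):
            torch.mm(x, x)

    # aggregate results by operator name and print a summary table
    print(prof.key_averages().table(sort_by="cpu_time_total"))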

.. autoclass:: torch.autograd.profiler.profile

.. autosummary::
    :toctree: generated
    :nosignatures:

    profiler.profile.export_chrome_trace
    profiler.profile.key_averages
    profiler.profile.self_cpu_time_total
    profiler.profile.total_average

.. autoclass:: torch.autograd.profiler.emit_nvtx
.. autoclass:: torch.autograd.profiler.emit_itt


.. autosummary::
    :toctree: generated
    :nosignatures:

    profiler.load_nvprof

Anomaly detection
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: detect_anomaly

.. autoclass:: set_detect_anomaly
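
A minimal sketch of how anomaly detection surfaces NaNs produced during the
backward pass (the ``0 / 0`` below is a deliberately bad operation)::

    import torch

    x = torch.tensor([0.0], requires_grad=True)
    with torch.autograd.detect_anomaly():
        y = x / x            # 0 / 0 produces NaN in the forward pass
        # the backward of the division also produces NaN gradients, so anomaly
        # mode raises an error that points back at the forward operation above
        y.backward()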


Saved tensors default hooks
^^^^^^^^^^^^^^^^^^^^^^^^^^^

Some operations need intermediate results to be saved during the forward pass
in order to execute the backward pass.
You can define how these saved tensors should be packed / unpacked using hooks.
A common application is to trade compute for memory by saving those intermediate results
to disk or to the CPU instead of leaving them on the GPU. This is especially useful if you
notice your model fits on the GPU during evaluation, but not during training.
Also see :ref:`saved-tensors-hooks-doc`.
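
For example, a minimal sketch that keeps the tensors saved for backward on the
CPU (the model and input below are hypothetical)::

    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(10, 100), nn.ReLU(), nn.Linear(100, 10))
    inp = torch.randn(4, 10)

    # every tensor saved for backward inside this context is packed into CPU
    # storage and unpacked back to its original device during the backward pass
    with torch.autograd.graph.save_on_cpu():
        loss = model(inp).sum()
    loss.backward()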

.. autoclass:: torch.autograd.graph.saved_tensors_hooks

.. autoclass:: torch.autograd.graph.save_on_cpu

.. autoclass:: torch.autograd.graph.disable_saved_tensors_hooks