File: amp.rst

.. role:: hidden
    :class: hidden-section

Automatic Mixed Precision package - torch.amp
=============================================

.. The modules below are missing doc entries. Adding them here for now.
.. This does not add anything to the rendered page
.. py:module:: torch.cpu
.. py:module:: torch.cpu.amp
.. py:module:: torch.cuda.amp

.. automodule:: torch.amp
.. currentmodule:: torch.amp

:mod:`torch.amp` provides convenience methods for mixed precision,
where some operations use the ``torch.float32`` (``float``) datatype and other operations
use a lower precision floating point datatype (``lower_precision_fp``): ``torch.float16`` (``half``) or ``torch.bfloat16``. Some ops, like linear layers and convolutions,
are much faster in ``lower_precision_fp``. Other ops, like reductions, often require the dynamic
range of ``float32``.  Mixed precision tries to match each op to its appropriate datatype.

Ordinarily, "automatic mixed precision training" with datatype of ``torch.float16`` uses :class:`torch.autocast` and
:class:`torch.cuda.amp.GradScaler` together, as shown in the :ref:`CUDA Automatic Mixed Precision examples<amp-examples>`
and `CUDA Automatic Mixed Precision recipe <https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html>`_.
However, :class:`torch.autocast` and :class:`torch.cuda.amp.GradScaler` are modular, and may be used separately if desired.
As shown in the CPU example section of :class:`torch.autocast`, "automatic mixed precision training/inference" on CPU with
the ``torch.bfloat16`` datatype uses only :class:`torch.autocast`.
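
A minimal sketch of that pattern (``model``, ``optimizer``, ``loss_fn``, and
``data_loader`` are assumed to be defined elsewhere; see the linked examples
for complete, runnable versions)::

    scaler = torch.cuda.amp.GradScaler()

    for inputs, targets in data_loader:
        optimizer.zero_grad()

        # Run the forward pass under autocast so eligible ops use float16.
        with torch.autocast("cuda", dtype=torch.float16):
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)

        # Scale the loss, backpropagate, then step the optimizer and update
        # the scale factor for the next iteration.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()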

For CUDA and CPU, APIs are also provided separately:

* ``torch.autocast("cuda", args...)`` is equivalent to ``torch.cuda.amp.autocast(args...)``.
* ``torch.autocast("cpu", args...)`` is equivalent to ``torch.cpu.amp.autocast(args...)``. For CPU, only lower precision floating point datatype of ``torch.bfloat16`` is supported for now.

.. contents:: :local:

.. _autocasting:

Autocasting
^^^^^^^^^^^
.. currentmodule:: torch

.. autoclass:: autocast
    :members:

.. currentmodule:: torch.cuda.amp

.. autoclass:: autocast
    :members:

.. autofunction::  custom_fwd

.. autofunction::  custom_bwd

.. currentmodule:: torch.cpu.amp

.. autoclass:: autocast
    :members:

.. _gradient-scaling:

Gradient Scaling
^^^^^^^^^^^^^^^^

If the forward pass for a particular op has ``float16`` inputs, the backward pass for
that op will produce ``float16`` gradients.
Gradient values with small magnitudes may not be representable in ``float16``.
These values will flush to zero ("underflow"), so the update for the corresponding parameters will be lost.

To prevent underflow, "gradient scaling" multiplies the network's loss(es) by a scale factor and
invokes a backward pass on the scaled loss(es).  Gradients flowing backward through the network are
then scaled by the same factor.  In other words, gradient values have a larger magnitude,
so they don't flush to zero.

Each parameter's gradient (``.grad`` attribute) should be unscaled before the optimizer
updates the parameters, so the scale factor does not interfere with the learning rate.
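
:class:`torch.cuda.amp.GradScaler` performs this unscaling for you inside its
``step`` method. If you need to inspect or modify gradients between the
backward pass and the optimizer step (for gradient clipping, say), unscale
them explicitly first. A sketch, assuming ``scaler``, ``optimizer``, ``model``,
and ``loss`` come from a training loop like the one above::

    scaler.scale(loss).backward()

    # Unscale the gradients of the optimizer's assigned params in-place so
    # that clipping sees gradients at their true magnitude.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    # step() skips its internal unscale because unscale_ was already called,
    # and skips the optimizer step entirely if any gradient is inf/NaN.
    scaler.step(optimizer)
    scaler.update()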

.. currentmodule:: torch.cuda.amp

.. autoclass:: GradScaler
    :members:

.. _autocast-op-reference:

Autocast Op Reference
^^^^^^^^^^^^^^^^^^^^^

.. _autocast-eligibility:

Op Eligibility
--------------
Ops that run in ``float64`` or non-floating-point dtypes are not eligible, and will
run in these types whether or not autocast is enabled.

Only out-of-place ops and Tensor methods are eligible.
In-place variants and calls that explicitly supply an ``out=...`` Tensor
are allowed in autocast-enabled regions, but won't go through autocasting.
For example, in an autocast-enabled region ``a.addmm(b, c)`` can autocast,
but ``a.addmm_(b, c)`` and ``a.addmm(b, c, out=d)`` cannot.
For best performance and stability, prefer out-of-place ops in autocast-enabled
regions.

Ops called with an explicit ``dtype=...`` argument are not eligible,
and will produce output that respects the ``dtype`` argument.
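
A small illustration of these rules, assuming ``float32`` CUDA inputs (the
shapes are arbitrary; only the output dtypes matter here)::

    a = torch.randn(4, 4, device="cuda")
    b = torch.randn(4, 4, device="cuda")
    c = torch.randn(4, 4, device="cuda")

    with torch.autocast("cuda"):
        # Out-of-place and eligible: goes through autocasting.
        y = a.addmm(b, c)
        assert y.dtype is torch.float16

        # In-place variant: allowed, but does not go through autocasting,
        # so it runs in the inputs' dtype.
        a.addmm_(b, c)
        assert a.dtype is torch.float32

        # Explicit dtype argument: not eligible; the output respects the
        # requested dtype even though cumsum would otherwise autocast to float32.
        s = torch.cumsum(b, dim=0, dtype=torch.float16)
        assert s.dtype is torch.float16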

.. _autocast-cuda-op-reference:

CUDA Op-Specific Behavior
-------------------------
The following lists describe the behavior of eligible ops in autocast-enabled regions.
These ops always go through autocasting whether they are invoked as part of a :class:`torch.nn.Module`,
as a function, or as a :class:`torch.Tensor` method. If functions are exposed in multiple namespaces,
they go through autocasting regardless of the namespace.

Ops not listed below do not go through autocasting.  They run in the type
defined by their inputs.  However, autocasting may still change the type
in which unlisted ops run if they're downstream from autocasted ops.

If an op is unlisted, we assume it's numerically stable in ``float16``.
If you believe an unlisted op is numerically unstable in ``float16``,
please file an issue.

CUDA Ops that can autocast to ``float16``
"""""""""""""""""""""""""""""""""""""""""

``__matmul__``,
``addbmm``,
``addmm``,
``addmv``,
``addr``,
``baddbmm``,
``bmm``,
``chain_matmul``,
``multi_dot``,
``conv1d``,
``conv2d``,
``conv3d``,
``conv_transpose1d``,
``conv_transpose2d``,
``conv_transpose3d``,
``GRUCell``,
``linear``,
``LSTMCell``,
``matmul``,
``mm``,
``mv``,
``prelu``,
``RNNCell``

CUDA Ops that can autocast to ``float32``
"""""""""""""""""""""""""""""""""""""""""

``__pow__``,
``__rdiv__``,
``__rpow__``,
``__rtruediv__``,
``acos``,
``asin``,
``binary_cross_entropy_with_logits``,
``cosh``,
``cosine_embedding_loss``,
``cdist``,
``cosine_similarity``,
``cross_entropy``,
``cumprod``,
``cumsum``,
``dist``,
``erfinv``,
``exp``,
``expm1``,
``group_norm``,
``hinge_embedding_loss``,
``kl_div``,
``l1_loss``,
``layer_norm``,
``log``,
``log_softmax``,
``log10``,
``log1p``,
``log2``,
``margin_ranking_loss``,
``mse_loss``,
``multilabel_margin_loss``,
``multi_margin_loss``,
``nll_loss``,
``norm``,
``normalize``,
``pdist``,
``poisson_nll_loss``,
``pow``,
``prod``,
``reciprocal``,
``rsqrt``,
``sinh``,
``smooth_l1_loss``,
``soft_margin_loss``,
``softmax``,
``softmin``,
``softplus``,
``sum``,
``renorm``,
``tan``,
``triplet_margin_loss``

CUDA Ops that promote to the widest input type
""""""""""""""""""""""""""""""""""""""""""""""
These ops don't require a particular dtype for stability, but take multiple inputs
and require that the inputs' dtypes match.  If all of the inputs are
``float16``, the op runs in ``float16``.  If any of the inputs is ``float32``,
autocast casts all inputs to ``float32`` and runs the op in ``float32``.

``addcdiv``,
``addcmul``,
``atan2``,
``bilinear``,
``cross``,
``dot``,
``grid_sample``,
``index_put``,
``scatter_add``,
``tensordot``

Some ops not listed here (e.g., binary ops like ``add``) natively promote
inputs without autocasting's intervention.  If inputs are a mixture of ``float16``
and ``float32``, these ops run in ``float32`` and produce ``float32`` output,
regardless of whether autocast is enabled.
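
A short sketch of the difference between a listed promote op and native
promotion (``atan2`` is in the list above, ``add`` is not)::

    x16 = torch.randn(4, device="cuda", dtype=torch.float16)
    x32 = torch.randn(4, device="cuda")

    with torch.autocast("cuda"):
        # Listed promote op: autocast casts both inputs to the widest type.
        y = torch.atan2(x16, x32)
        assert y.dtype is torch.float32

        # Unlisted binary op: native type promotion applies, with or
        # without autocast.
        z = x16 + x32
        assert z.dtype is torch.float32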

Prefer ``binary_cross_entropy_with_logits`` over ``binary_cross_entropy``
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
The backward passes of :func:`torch.nn.functional.binary_cross_entropy` (and :class:`torch.nn.BCELoss`, which wraps it)
can produce gradients that aren't representable in ``float16``.  In autocast-enabled regions, the forward input
may be ``float16``, which means the backward gradient must be representable in ``float16`` (autocasting ``float16``
forward inputs to ``float32`` doesn't help, because that cast must be reversed in backward).
Therefore, ``binary_cross_entropy`` and ``BCELoss`` raise an error in autocast-enabled regions.

Many models use a sigmoid layer right before the binary cross entropy layer.
In this case, combine the two layers using :func:`torch.nn.functional.binary_cross_entropy_with_logits`
or :class:`torch.nn.BCEWithLogitsLoss`.  ``binary_cross_entropy_with_logits`` and ``BCEWithLogitsLoss``
are safe to autocast.
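
For example, rather than applying a sigmoid and then ``binary_cross_entropy``,
pass the raw logits to the combined loss (a minimal sketch; ``model``,
``inputs``, and ``targets`` are placeholders)::

    loss_fn = torch.nn.BCEWithLogitsLoss()

    with torch.autocast("cuda"):
        logits = model(inputs)        # raw scores, no sigmoid applied
        # Safe under autocast: the sigmoid and the loss are combined in a
        # numerically stable way and autocast to float32.
        loss = loss_fn(logits, targets)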

.. _autocast-cpu-op-reference:

CPU Op-Specific Behavior
------------------------
The following lists describe the behavior of eligible ops in autocast-enabled regions.
These ops always go through autocasting whether they are invoked as part of a :class:`torch.nn.Module`,
as a function, or as a :class:`torch.Tensor` method. If functions are exposed in multiple namespaces,
they go through autocasting regardless of the namespace.

Ops not listed below do not go through autocasting.  They run in the type
defined by their inputs.  However, autocasting may still change the type
in which unlisted ops run if they're downstream from autocasted ops.

If an op is unlisted, we assume it's numerically stable in ``bfloat16``.
If you believe an unlisted op is numerically unstable in ``bfloat16``,
please file an issue.

CPU Ops that can autocast to ``bfloat16``
"""""""""""""""""""""""""""""""""""""""""

``conv1d``,
``conv2d``,
``conv3d``,
``bmm``,
``mm``,
``baddbmm``,
``addmm``,
``addbmm``,
``linear``,
``matmul``,
``_convolution``

CPU Ops that can autocast to ``float32``
""""""""""""""""""""""""""""""""""""""""

``conv_transpose1d``,
``conv_transpose2d``,
``conv_transpose3d``,
``avg_pool3d``,
``binary_cross_entropy``,
``grid_sampler``,
``grid_sampler_2d``,
``_grid_sampler_2d_cpu_fallback``,
``grid_sampler_3d``,
``polar``,
``prod``,
``quantile``,
``nanquantile``,
``stft``,
``cdist``,
``trace``,
``view_as_complex``,
``cholesky``,
``cholesky_inverse``,
``cholesky_solve``,
``inverse``,
``lu_solve``,
``orgqr``,
``ormqr``,
``pinverse``,
``max_pool3d``,
``max_unpool2d``,
``max_unpool3d``,
``adaptive_avg_pool3d``,
``reflection_pad1d``,
``reflection_pad2d``,
``replication_pad1d``,
``replication_pad2d``,
``replication_pad3d``,
``mse_loss``,
``ctc_loss``,
``kl_div``,
``multilabel_margin_loss``,
``fft_fft``,
``fft_ifft``,
``fft_fft2``,
``fft_ifft2``,
``fft_fftn``,
``fft_ifftn``,
``fft_rfft``,
``fft_irfft``,
``fft_rfft2``,
``fft_irfft2``,
``fft_rfftn``,
``fft_irfftn``,
``fft_hfft``,
``fft_ihfft``,
``linalg_matrix_norm``,
``linalg_cond``,
``linalg_matrix_rank``,
``linalg_solve``,
``linalg_cholesky``,
``linalg_svdvals``,
``linalg_eigvals``,
``linalg_eigvalsh``,
``linalg_inv``,
``linalg_householder_product``,
``linalg_tensorinv``,
``linalg_tensorsolve``,
``fake_quantize_per_tensor_affine``,
``eig``,
``geqrf``,
``lstsq``,
``_lu_with_info``,
``qr``,
``solve``,
``svd``,
``symeig``,
``triangular_solve``,
``fractional_max_pool2d``,
``fractional_max_pool3d``,
``adaptive_max_pool3d``,
``multilabel_margin_loss_forward``,
``linalg_qr``,
``linalg_cholesky_ex``,
``linalg_svd``,
``linalg_eig``,
``linalg_eigh``,
``linalg_lstsq``,
``linalg_inv_ex``

CPU Ops that promote to the widest input type
"""""""""""""""""""""""""""""""""""""""""""""
These ops don't require a particular dtype for stability, but take multiple inputs
and require that the inputs' dtypes match.  If all of the inputs are
``bfloat16``, the op runs in ``bfloat16``.  If any of the inputs is ``float32``,
autocast casts all inputs to ``float32`` and runs the op in ``float32``.

``cat``,
``stack``,
``index_copy``

Some ops not listed here (e.g., binary ops like ``add``) natively promote
inputs without autocasting's intervention.  If inputs are a mixture of ``bfloat16``
and ``float32``, these ops run in ``float32`` and produce ``float32`` output,
regardless of whether autocast is enabled.
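
A CPU analog of the promotion sketch above (``x32`` defaults to ``float32``)::

    x16 = torch.randn(4, dtype=torch.bfloat16)
    x32 = torch.randn(4)

    with torch.autocast("cpu", dtype=torch.bfloat16):
        # Listed promote op: autocast runs it in the widest input type.
        y = torch.cat([x16, x32])
        assert y.dtype is torch.float32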