# JIT scripting & Autocast

<!-- @import "[TOC]" {cmd="toc" depthFrom=2 depthTo=6 orderedList=false} -->

<!-- code_chunk_output -->

- [Overview](#overview)
- [Usage](#usage)
- [Known limitations](#known-limitations)
    - [Diagnostics](#diagnostics)
    - [Autocast decorators](#autocast-decorators)
    - [Autocast argument must be a compile-time constant](#autocast-argument-must-be-a-compile-time-constant)
    - [Uncommon autocast usage patterns may not be supported](#uncommon-autocast-usage-patterns-may-not-be-supported)
    - [Limited support for promote autocast policy](#limited-support-for-promote-autocast-policy)
    - [Missing autocast policies](#missing-autocast-policies)
    - [Mixing tracing and scripting autocast (script calling traced)](#mixing-tracing-and-scripting-autocast-script-calling-traced)
    - [Mixing tracing and scripting autocast (traced calling script)](#mixing-tracing-and-scripting-autocast-traced-calling-script)
    - [Disabling eager autocast with scripted autocast](#disabling-eager-autocast-with-scripted-autocast)
- [References](#references)

<!-- /code_chunk_output -->

## Overview

[Autocast][2] (also known as Automatic Mixed Precision) is an optimization that
helps take advantage of the storage and performance benefits of narrow types
(float16) while preserving the additional range and numerical precision of
float32.

The JIT support for autocast is subject to different constraints than the
eager mode implementation (mostly because TorchScript is statically typed).
This document lists the known limitations.

## Usage

Explicit `with autocast()` scopes are supported inside scripted functions and
modules (subject to the limitations described below):

```python
import torch
from torch.cuda.amp import autocast

@torch.jit.script
def func(a, b):
    with autocast():
        return torch.mm(a, b)

a_float32 = torch.rand((8, 8), dtype=torch.float32, device="cuda")
b_float32 = torch.rand((8, 8), dtype=torch.float32, device="cuda")
result = func(a_float32, b_float32)
print(result.dtype) # expecting torch.float16
```

## Known limitations

This section documents the current set of known limitations. Ideally this list
will shrink as we advance with the design and implementation, although some of
the limitations are related to fundamental TorchScript aspects that are not easy
to change.

> One important goal is to avoid surprises (e.g. autocast annotations being
> silently ignored) and to report sensible diagnostics when something deviates
> from eager mode behavior.
>
> Please [report](https://github.com/csarofeen/pytorch/issues/new/choose) any
> issues not covered here.

#### Diagnostics

The current Autocast/JIT diagnostics should be improved:
- Some errors are not specific enough or not actionable
- Not all the errors point to the Python source location

#### Autocast decorators

Using `@autocast` as a decorator is not currently supported in script mode (a
diagnostic will be emitted):

```python
@autocast(enabled=True)
def helper(x):
    ...

@torch.jit.script
def foo(x):
    return helper(x) # not supported
```

Another example:

```python
@torch.jit.script
@autocast() # not supported
def foo(a, b, c, d):
    ...
```

#### Autocast argument must be a compile-time constant

```python
@torch.jit.script
def fn(a, b, use_amp: bool):
    # runtime values for autocast enable argument are not supported
    with autocast(enabled=use_amp):
        return torch.mm(a, b)
```

#### Uncommon autocast usage patterns may not be supported

```python
@torch.jit.script
def fn(a, b, c, d):
    with autocast(enabled=True) as autocast_instance: # not supported
        ...
        with autocast_instance:
            ...
```

#### Limited support for promote autocast policy

For some operations, autocast needs to [promote to the widest argument type][3].
When the concrete types are not available, the current implementation will
conservatively inject a promotion even when it may not be needed.
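To make the cost of missing dtype information concrete, here is a toy model of the planning step in plain Python. It is not the actual JIT pass: dtypes are plain strings, `None` stands for a dtype that is unknown when the casts are planned, and `plan_promotion` and `_WIDTH` are hypothetical names used only for this sketch.

```python
from typing import List, Optional

# Hypothetical width table, used only by this sketch.
_WIDTH = {"float16": 16, "float32": 32, "float64": 64}


def plan_promotion(dtypes: List[Optional[str]]) -> List[bool]:
    """Return, for each input, whether a cast to the widest dtype is inserted.

    A dtype of None models an input whose dtype is not known when the
    casts are planned (as can happen when scripting).
    """
    if any(d is None for d in dtypes):
        # Without concrete dtypes we cannot tell which inputs are already the
        # widest, so every input gets a (possibly redundant) promotion.
        return [True] * len(dtypes)
    widest = max(dtypes, key=lambda d: _WIDTH[d])
    # With concrete dtypes, only inputs narrower than the widest one are cast.
    return [_WIDTH[d] < _WIDTH[widest] for d in dtypes]


print(plan_promotion(["float16", "float32"]))  # [True, False]
print(plan_promotion(["float32", "float32"]))  # [False, False]
print(plan_promotion([None, "float32"]))       # [True, True]
```

With concrete dtypes only the narrower input is cast; as soon as one dtype is unknown, every input receives a promotion, and some of those casts turn out to be no-ops at runtime.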

#### Missing autocast policies

Also related to the lack of concrete dtype availability, a few specialized
autocast policies are not yet supported with JIT scripting:
- [CastPolicy::fp32_append_dtype][5]

#### Mixing tracing and scripting autocast (script calling traced)

Calling a traced function from a scripted one mostly works, except when the
traced part uses `autocast(enabled=False)`. Tracing strips the `autocast` scope
from the TorchScript IR, so it is effectively ignored:

> This is one known limitation where we don't have a way to emit a diagnostic!

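```python
import torch
from torch.cuda.amp import autocast

def helper(a, b):
    with autocast(enabled=False):
        return torch.mm(a, b) * 2.0

x = torch.rand((8, 8), dtype=torch.float32, device="cuda")
y = torch.rand((8, 8), dtype=torch.float32, device="cuda")
traced = torch.jit.trace(helper, (x, y))

@torch.jit.script
def fn(a, b):
    with autocast(enabled=True):
        # the autocast(enabled=False) scope inside `traced` was stripped
        # during tracing, so it does not disable autocast here
        return traced(a, b)
```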

#### Mixing tracing and scripting autocast (traced calling script)

Calling a scripted function from a traced function is similar to calling the
scripted function from eager mode:

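```python
import torch
from torch.cuda.amp import autocast

@torch.jit.script
def fn(a, b):
    return torch.mm(a, b)

def traced(a, b):
    with autocast(enabled=True):
        return fn(a, b)

x = torch.rand((8, 8), dtype=torch.float32, device="cuda")
y = torch.rand((8, 8), dtype=torch.float32, device="cuda")

# running TorchScript with Autocast enabled is not supported
torch.jit.trace(traced, (x, y))
```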

#### Disabling eager autocast with scripted autocast

If eager-mode autocast is enabled and we try to disable autocasting from
within a scripted function, autocasting will still occur.

```python
import torch
from torch.cuda.amp import autocast

@torch.jit.script
def fn(a, b):
    with autocast(enabled=False):
        return torch.mm(a, b)

x = torch.rand((2, 2), device='cuda', dtype=torch.float)
y = torch.rand((2, 2), device='cuda', dtype=torch.float)

# this will print half-precision dtype
with autocast(enabled=True):
    print(fn(x, y).dtype)
```
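The difference from eager mode can be summarized as a small truth table. This is a sketch of the observable behavior described in this section and in the Usage example, not of how the JIT implements it; both function names are hypothetical.

```python
def eager_autocast_active(outer_enabled: bool, inner_enabled: bool) -> bool:
    """Eager mode with nested scopes: the innermost scope decides."""
    # outer_enabled is intentionally unused: the inner scope always wins.
    return inner_enabled


def scripted_autocast_active(eager_enabled: bool, script_enabled: bool) -> bool:
    """Eager scope outside, scripted scope inside: an enabled eager scope
    is not overridden by a scripted autocast(enabled=False)."""
    return eager_enabled or script_enabled


for outer in (False, True):
    for inner in (False, True):
        print(
            f"outer={outer}, inner={inner}: "
            f"eager={eager_autocast_active(outer, inner)}, "
            f"scripted={scripted_autocast_active(outer, inner)}"
        )
```

The two only disagree when the outer eager scope is enabled and the inner scripted scope is disabled, which is the case in the example above.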

## References

- [torch.cuda.amp Package][1]
- [Automatic Mixed Precision - Tutorial](https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html)
- [Automatic Mixed Precision - Examples](https://pytorch.org/docs/stable/notes/amp_examples.html)

[1]: https://pytorch.org/docs/stable/amp.html
[2]: https://pytorch.org/blog/accelerating-training-on-nvidia-gpus-with-pytorch-automatic-mixed-precision/
[3]: https://pytorch.org/docs/stable/amp.html#ops-that-promote-to-the-widest-input-type
[4]: https://github.com/csarofeen/pytorch/blob/4d8575604ad9fa5fdfc21037490a041d8d43bcae/aten/src/ATen/autocast_mode.cpp#L94
[5]: https://github.com/csarofeen/pytorch/blob/4d8575604ad9fa5fdfc21037490a041d8d43bcae/aten/src/ATen/autocast_mode.cpp#L99
[6]: https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html#adding-autocast