---
file_format: mystnb
kernelspec:
  name: python3
mystnb:
  execution_timeout: 30
  execution_show_tb: True
  merge_streams: True
---

```{code-cell}
:tags: [remove-cell]
import torch

import header_code
```

# Use `fullgraph=True` to Identify and Eliminate Graph Breaks

Using `torch.compile(fullgraph=False)` (the default) is a good way to get started with `torch.compile`: it supports all Python programs out of the box via its ability to graph break, and it gives good performance in common cases.

However, if you're trying to get more performance out of your model, you should think explicitly about which regions of code should be compiled:
- We recommend using `torch.compile(fullgraph=True)` to find and eliminate graph breaks in your code.
- If you're a library developer (or testing if your code "works" with `torch.compile`), we recommend testing using `torch.compile(fullgraph=True)`.

`torch.compile(fullgraph=True)` offers a stronger guarantee than `fullgraph=False`:
we will always capture a single FX graph to be compiled (or error if we cannot due to a graph break).
**In particular, you are forced to resolve every graph break that is encountered.**
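
For example (a minimal, hypothetical sketch that uses `torch._dynamo.graph_break()` to force a break), the same function runs under the default `fullgraph=False`, which splits the graph around the break, but raises under `fullgraph=True`:

```{code-cell}
def break_fn(x):
    x = x + 1
    torch._dynamo.graph_break()  # forces a graph break
    return x * 2

# fullgraph=False tolerates the break by compiling two separate graphs
print(torch.compile(break_fn, fullgraph=False)(torch.ones(3)))

# fullgraph=True refuses to split the graph and raises instead
try:
    torch.compile(break_fn, fullgraph=True)(torch.ones(3))
except Exception as e:
    print(e)
```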

There are a number of strategies for resolving a graph break.

## Strategy 1: Rewrite the unsupported code to use features supported by Dynamo

Many graph break error messages will give some suggestions on how to rewrite code to avoid the graph break.
If the graph break is still difficult to resolve, then please move on to the next strategy
or submit an issue to the [PyTorch GitHub repo](https://github.com/pytorch/pytorch/issues).

More graph break examples and how to resolve them can be found in [Common Graph Breaks](programming_model.common_graph_breaks).

Example: Dynamo does not support calling `next` on a `list_iterator` object that was an input to the function being compiled.

```{code-cell}
@torch.compile(fullgraph=True)
def f(xs):
    a = next(xs)
    b = next(xs)
    return a + b

xs = [torch.tensor(1.), torch.tensor(2.)]
try:
    out = f(iter(xs))
except Exception as e:
    print(e)
```

Instead, rewrite the compiled function to accept a list.

```{code-cell}
@torch.compile(fullgraph=True)
def f_rewritten(xs):
    it = iter(xs)
    a = next(it)
    b = next(it)
    return a + b

f_rewritten(xs)
```

## Strategy 2: Pure functions can always be compiled via an escape hatch

**Summary**: The space of all Python functions is vast, so it is impractical for Dynamo to trace
through every Python function without graph breaks. For "pure" Python functions
that Dynamo cannot trace through without graph breaks, we provide escape hatches
that capture these functions anyway:

1. Use `custom_op` or `triton_op` on pure Triton kernels.
2. Use `nonstrict_trace` for pure functions that only use PyTorch Tensor ops.
3. Use `custom_op` for all other pure functions.

A "pure function" is a function with the following properties:

- Determinism. Given the same inputs, the pure function always returns the same output.
- No external side effects. A pure function does not have any externally-visible side effects,
  such as modifying external state or performing I/O operations.
  Side effects that remain internal to the function are allowed (e.g. mutating intermediate tensors).
  One notable exception is that calling mutating `torch.*` ops on the function's input Tensors is generally allowed.
- Explicit input/output. All input data is passed through the function parameters and all outputs are returned from the function.

See [Pure Functions](programming_model.non_strict_tracing_model.pure_functions) for examples.
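
As a rough illustration (hypothetical functions, not taken from the linked page):

```{code-cell}
import random

# Pure: deterministic, no externally-visible side effects, and all data flows
# through the parameters and the return value.
def pure_scale(x, factor):
    return x * factor

call_count = 0

# Impure: non-deterministic and mutates external state.
def impure_scale(x):
    global call_count
    call_count += 1             # externally-visible side effect
    return x * random.random()  # non-deterministic
```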

Dynamo is theoretically able to handle a wide variety of impure functions, but it may lack coverage for specific
Python language features. However, pure functions can always be compiled via an escape hatch.

If you hit a graph break, it may be possible to refactor the surrounding code into a pure function and use an escape hatch that bypasses Dynamo tracing:

1. Use `torch._dynamo.nonstrict_trace` if you want the Tensor operations in the function to show up in the Dynamo output graph (and therefore be optimizable). `nonstrict_trace` tells Dynamo to use **non-strict tracing**.
2. Use custom operators if you want the function to be opaque w.r.t. `torch.compile` (both the Dynamo frontend and the backend).

Note that there is nothing preventing these escape hatches from being applied to impure functions,
but **we do not provide any soundness guarantees**.
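
For instance (a hypothetical sketch; it assumes Dynamo does not guard on state it cannot see inside a non-strict-traced function), applying `nonstrict_trace` to an impure function can silently bake external state into the compiled graph:

```{code-cell}
scale = 2.0

def impure_mul(x):
    return x * scale  # impure: reads a global that Dynamo cannot see

@torch.compile(fullgraph=True)
def f_unsound(x):
    return torch._dynamo.nonstrict_trace(impure_mul)(x)

x = torch.ones(3)
print(f_unsound(x))  # traced with scale == 2.0
scale = 10.0
print(f_unsound(x))  # may silently reuse the stale traced value of scale
```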

Example: If Dynamo doesn't support some Python feature or API that is non-strict traceable (e.g. it uses PyTorch operations), [use `torch._dynamo.nonstrict_trace` to capture it instead](programming_model.dynamo_nonstrict_trace).

```{code-cell}
# This is a function that Dynamo doesn't support (due to the graph_break() call).
def g(x):
    y = x.sin()
    torch._dynamo.graph_break()
    z = y.sin()
    return z

@torch.compile(fullgraph=True)
def f(x):
    w = x.sin()
    return g(w)

x = torch.randn(3)
try:
    f(x)  # Graph Break: there was a call to torch._dynamo.graph_break()
except Exception as e:
    print(e)

@torch.compile(fullgraph=True)
def f_rewritten(x):
    w = x.sin()
    return torch._dynamo.nonstrict_trace(g)(w)
f_rewritten(x)  # works
```

Example: use [custom operators](programming_model.custom_ops) to create functions that are opaque w.r.t. `torch.compile`.

```{code-cell}
from torch.utils.cpp_extension import load_inline

# C++ source code for the square operation
cpp_source = """
torch::Tensor square_cpu(torch::Tensor input) {
    // Check that input is a CPU tensor
    TORCH_CHECK(input.device().is_cpu(), "Input must be a CPU tensor");

    // Create output tensor with same shape and dtype as input
    torch::Tensor output = torch::empty_like(input);

    // Get data pointers
    float* input_data = input.data_ptr<float>();
    float* output_data = output.data_ptr<float>();

    // Get total number of elements
    int64_t numel = input.numel();

    // For loop to compute square of each element
    for (int64_t i = 0; i < numel; i++) {
        output_data[i] = input_data[i] * input_data[i];
    }

    return output;
}
"""

# Load the extension inline
square_module = load_inline(
    name="square_cpu_kernel",
    cpp_sources=cpp_source,
    functions=["square_cpu"],
    verbose=True
)

def square(x):
    return square_module.square_cpu(x)

@torch.compile(fullgraph=True)
def f(x):
    return square(x)

try:
    f(torch.randn(3, 3))  # graph break
except Exception as e:
    print(e)
```

```{code-cell}
# Use torch.library.custom_op to define a new custom operator.
# Custom operators are opaque with respect to torch.compile:
# that is, torch.compile does not peek into them.

@torch.library.custom_op("mylib::square", mutates_args=())
def square(x: torch.Tensor) -> torch.Tensor:
    return square_module.square_cpu(x)

# Use register_fake to add a ``FakeTensor`` kernel for the operator
@square.register_fake
def _(x):
    return x.new_empty(x.size())

print(f(torch.randn(3, 3)))  # no graph break
```

For more information on `triton_op` for custom Triton kernels, see the
[user-defined Triton kernel tutorial](https://docs.pytorch.org/tutorials/recipes/torch_compile_user_defined_triton_kernel_tutorial.html).


## Strategy 3: Don't compile the code

Not all code is amenable to being compiled. `torch.compile` is a compiler for Tensor computation;
it cannot optimize things like disk I/O. Try to refactor the code so that the unsupported
code is not called inside the compiled region.

```{code-cell}
@torch.compile(fullgraph=True)
def f(x):
    y = x ** 2 / 2
    torch.save(y, "foo.pt")
    z = y ** 3 / 6
    return z

x = torch.randn(3)
try:
    f(x)  # Graph Break: torch.save not supported
except Exception as e:
    print(e)
```

```{code-cell}
def f_rewritten(x):
    y = g(x)
    torch.save(y, "foo.pt")
    z = h(y)
    return z

@torch.compile(fullgraph=True)
def g(x):
    y = x ** 2 / 2
    return y

@torch.compile(fullgraph=True)
def h(y):
    z = y ** 3 / 6
    return z

f_rewritten(x)
```

```{code-cell}
:tags: [remove-cell]
import os
os.remove("foo.pt")
```