# Nested Graph Breaks
Summary:

- Graph breaks in nested functions can result in hard-to-understand compiler behavior, which we document below
- A nested graph break at nesting depth {math}`N` is traced {math}`\mathcal O(N)` times, which shows up as duplicate graph breaks

Recall that when `torch.compile` is applied to a function, any nested function calls are also traced.
A **nested graph break** refers to any graph break that happens in a nested function call.
```python
def inner(x):
    ...
    torch._dynamo.graph_break()  # nested graph break
    ...

@torch.compile
def outer(x):
    ...
    y = inner(x)
    ...
```
The resumption semantics around nested graph breaks can be confusing, so we describe the behavior here.

Recall that with `fullgraph=False`, [graph breaks are handled](programming_model.dynamo_core_concepts.graph_breaks) by compiling the FX graph that has been determined so far,
running the unsupported code in regular Python, then resuming tracing after the unsupported code with a new FX graph.
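
The compile, run eagerly, resume cycle described above can be pictured with a small toy model. This is only an illustrative sketch in plain Python: ops are represented as strings, `"BREAK"` stands for unsupported code, and `split_at_graph_breaks` is a hypothetical helper, not a `torch.compile` API.

```python
# Toy model only: ops are strings and "BREAK" marks unsupported code.
# split_at_graph_breaks is a hypothetical name, not part of torch.compile.
def split_at_graph_breaks(ops, unsupported="BREAK"):
    """Group consecutive supported ops into graphs, split at each unsupported op."""
    segments = []
    current = []
    for op in ops:
        if op == unsupported:
            if current:
                # "Compile" the ops traced so far as one graph.
                segments.append(("graph", current))
                current = []
            # The unsupported code runs in regular Python.
            segments.append(("eager", [op]))
        else:
            # Supported ops accumulate into the graph being traced.
            current.append(op)
    if current:
        # Ops traced after the last break form a new graph.
        segments.append(("graph", current))
    return segments


segments = split_at_graph_breaks(["add", "mul", "BREAK", "sub"])
```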
Resuming a function is a fairly complicated technical feat, so resuming tracing is only supported on top-level functions.
Under this restriction, tracing resumes after a nested graph break as follows.

First, consider the example below, where `torch.compile` starts tracing at `f` and continues until it encounters the
graph break in `inner1`.
```python
import torch

def inner1(x):
    x = x + 1
    torch._dynamo.graph_break()  # stop tracing due to graph break
    return x + 2

def inner2(x):
    x = x + 4
    x = inner1(x)
    return x + 8

@torch.compile
def f(x):
    # start tracing from here
    x = x + 16
    x = inner2(x)
    return x + 32

f(torch.randn(3))
```
Since we can only resume from top-level functions, we graph break on the `inner2` call in `f`.
```python
# The semantics of torch.compile(f)(x) is roughly this:
def compiled_f_semantics(x):
    y = x + 16
    z = inner2(y)
    return torch.compile(resume_f_semantics)(z)

def resume_f_semantics(x):
    return x + 32

compiled_f_semantics(torch.randn(3))
```
`inner2` is then automatically compiled as a top-level function.
We trace all the way until the graph break in `inner1` is encountered again.
```python
def inner1(x):
    x = x + 1
    torch._dynamo.graph_break()  # stop tracing due to graph break
    return x + 2

# this torch.compile is automatically applied
@torch.compile
def inner2(x):
    # start tracing from here
    x = x + 4
    x = inner1(x)
    return x + 8

def compiled_f_semantics(x):
    y = x + 16
    z = inner2(y)
    return torch.compile(resume_f_semantics)(z)

def resume_f_semantics(x):
    return x + 32

compiled_f_semantics(torch.randn(3))
```
Then we graph break on the `inner1` call in `inner2`.
```python
def compiled_inner2_semantics(x):
    y = x + 4
    z = inner1(y)
    return torch.compile(resume_inner2_semantics)(z)

def resume_inner2_semantics(x):
    return x + 8
```
`inner1` is then automatically compiled as a top-level function.
The graph break now occurs directly in the top-level function `inner1`, so it is handled normally.
```python
# this torch.compile is automatically applied
@torch.compile
def inner1(x):
    # start tracing from here
    x = x + 1
    torch._dynamo.graph_break()  # stop tracing due to graph break
    return x + 2

def compiled_f_semantics(x):
    y = x + 16
    z = compiled_inner2_semantics(y)
    return torch.compile(resume_f_semantics)(z)

def resume_f_semantics(x):
    return x + 32

def compiled_inner2_semantics(x):
    y = x + 4
    z = inner1(y)
    return torch.compile(resume_inner2_semantics)(z)

def resume_inner2_semantics(x):
    return x + 8

compiled_f_semantics(torch.randn(3))
```
`inner1` is handled normally:
```python
def compiled_inner1_semantics(x):
    y = x + 1
    torch._dynamo.graph_break()
    return torch.compile(resume_inner1_semantics)(y)

def resume_inner1_semantics(x):
    return x + 2
```
So the initial code is semantically equivalent to
```python
def compiled_f_semantics(x):
    y = x + 16
    z = compiled_inner2_semantics(y)
    return torch.compile(resume_f_semantics)(z)

def resume_f_semantics(x):
    return x + 32

def compiled_inner2_semantics(x):
    y = x + 4
    z = compiled_inner1_semantics(y)
    return torch.compile(resume_inner2_semantics)(z)

def resume_inner2_semantics(x):
    return x + 8

def compiled_inner1_semantics(x):
    y = x + 1
    torch._dynamo.graph_break()
    return torch.compile(resume_inner1_semantics)(y)

def resume_inner1_semantics(x):
    return x + 2

compiled_f_semantics(torch.randn(3))
```
Note in particular that we traced 3 top-level functions, and that we traced the same graph break 3 times.
**This explains why you may encounter duplicate graph breaks when using `torch.compile`.**
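
As a sanity check on the claimed equivalence, the decomposition can be replayed in plain Python on ordinary numbers. This sketch makes two assumptions for illustration: `graph_break` is a hypothetical no-op stand-in for `torch._dynamo.graph_break()`, and the resume functions are called directly instead of through `torch.compile`, so it checks only that the values agree, not how tracing behaves.

```python
def graph_break():
    """Hypothetical stand-in for torch._dynamo.graph_break(); does nothing here."""


# Original program, written with explicit returns.
def inner1(x):
    x = x + 1
    graph_break()
    return x + 2


def inner2(x):
    x = x + 4
    x = inner1(x)
    return x + 8


def f(x):
    x = x + 16
    x = inner2(x)
    return x + 32


# Decomposed program: each function is split at its graph break into a
# prefix and a resume function, mirroring the *_semantics functions above.
def resume_inner1_semantics(x):
    return x + 2


def compiled_inner1_semantics(x):
    y = x + 1
    graph_break()
    return resume_inner1_semantics(y)


def resume_inner2_semantics(x):
    return x + 8


def compiled_inner2_semantics(x):
    y = x + 4
    z = compiled_inner1_semantics(y)
    return resume_inner2_semantics(z)


def resume_f_semantics(x):
    return x + 32


def compiled_f_semantics(x):
    y = x + 16
    z = compiled_inner2_semantics(y)
    return resume_f_semantics(z)


original = f(0)
decomposed = compiled_f_semantics(0)
```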
In summary, nested graph breaks are handled by:

- Tracing from the top-level function all the way to the nested graph break
- Graph breaking on the top-level function at the call to the second-level function
- Compiling the PyTorch ops tracked so far and running the compiled graph
- Calling the second-level function, which gets automatically compiled as a top-level function
- Resuming tracing after the second-level function call

Note that the runtime of handling this graph break is {math}`\mathcal O(NK)`, where {math}`N` is the nesting depth
and {math}`K` is the number of instructions from the top-level function to the graph break.
We end up tracing {math}`\mathcal O(N^2)` frames, and we trace the same graph break {math}`\mathcal O(N)` times.
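
The frame and graph break counts above can be reproduced with a simplified counting model. The sketch assumes that tracing attempt {math}`i` starts at the function on level {math}`i` and inlines every deeper call until it reaches the graph break on level {math}`N`. It ignores the tracing of resume functions, and `simulate_nested_graph_break` is a hypothetical helper, not part of PyTorch.

```python
def simulate_nested_graph_break(depth):
    """Count traced frames and graph break hits for a break at nesting `depth`.

    Simplified model: attempt i starts at the function on level i and inlines
    every deeper call until it reaches the graph break on level `depth`.
    Tracing of resume functions is ignored.
    """
    frames_traced = 0
    graph_break_hits = 0
    for start_level in range(1, depth + 1):
        # Frames on levels start_level..depth are traced in this attempt.
        frames_traced += depth - start_level + 1
        # Every attempt runs into the same graph break once.
        graph_break_hits += 1
    return frames_traced, graph_break_hits


frames, hits = simulate_nested_graph_break(3)
```

For the three-level example (`f`, `inner2`, `inner1`), this model gives 6 traced frames and 3 encounters of the same graph break. In general it gives {math}`N(N+1)/2` frames, which is {math}`\mathcal O(N^2)`, and {math}`N` graph break encounters.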