---
file_format: mystnb
kernelspec:
  name: python3
mystnb:
  execution_timeout: 30
  execution_show_tb: True
  merge_streams: True
---
```{code-cell}
:tags: [remove-cell]
import torch
import header_code
torch._logging.set_logs(graph_breaks=True)
```
# Toggling `error_on_graph_break`

**Summary:**

- When `fullgraph=False`, we can use `torch._dynamo.error_on_graph_break()` for more flexibility in dealing with graph breaks.

So far, we have introduced two ways of dealing with graph breaks in `torch.compile`:

1. `fullgraph=True` errors on the first graph break and additionally guarantees that only one graph is traced from the code.
2. `fullgraph=False` continues tracing even when encountering graph breaks.

What if we want to disallow graph breaks for most of the code, but there are a few problematic functions where the graph breaks are hard to remove,
and we are okay with having those graph breaks? We can use `torch._dynamo.error_on_graph_break()` to achieve this.

`torch.compile` has an `error_on_graph_break` setting (initially set to `False`).
If a graph break or compiler error occurs in code while `error_on_graph_break` is set to `False`, then `torch.compile` will attempt to continue compilation after the graph break/error.
If `error_on_graph_break` is set to `True`, then `torch.compile` will abort compilation and propagate the error to user code.

A significant difference between `error_on_graph_break=True` and `fullgraph=True` is that the former **does not guarantee that a single graph will be captured**.
`error_on_graph_break` **can be arbitrarily toggled at compile time** by using the `torch._dynamo.error_on_graph_break()` context manager/decorator.
In comparison, once `fullgraph` is set to `True`, it cannot be set back to `False`.

Finally, `error_on_graph_break` has lower precedence than `fullgraph`: `error_on_graph_break` only takes effect when `fullgraph=False`.
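
As a rough mental model, the precedence rule can be written as a small decision function. `on_graph_break` below is a hypothetical helper, not part of `torch`; it only restates the rules described above in plain Python:

```python
def on_graph_break(fullgraph: bool, error_on_graph_break: bool) -> str:
    """Hypothetical helper (not a torch API): what happens at a graph break
    under the rules described above."""
    if fullgraph:
        # fullgraph=True takes precedence; error_on_graph_break is ignored.
        return "error"
    # With fullgraph=False, the currently active error_on_graph_break decides.
    return "error" if error_on_graph_break else "continue"


for fullgraph in (True, False):
    for eogb in (True, False):
        print(f"fullgraph={fullgraph}, error_on_graph_break={eogb}: "
              f"{on_graph_break(fullgraph, eogb)}")
```
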
## `error_on_graph_break(False)` example
```{code-cell}
@torch._dynamo.error_on_graph_break(False)
def code_with_a_difficult_graph_break(x):
    x = x + 1
    torch._dynamo.graph_break()
    return x + 2

def inner(x):
    return code_with_a_difficult_graph_break(x)

# NOTE: fullgraph=False
@torch._dynamo.error_on_graph_break(True)
@torch.compile
def fn(x):
    return inner(x)

# No error, but there is a graph break
fn(torch.randn(3))
```

Using `error_on_graph_break(False)` under `error_on_graph_break(True)` is helpful when we want to minimize graph breaks (i.e. follow the `fullgraph=True` programming model),
but there are some sections of code with non-performance-critical graph breaks that are difficult to work around.

`error_on_graph_break()` can be used as a context manager as well:

```{code-cell}
# NOTE: fullgraph=False
@torch._dynamo.error_on_graph_break(True)
@torch.compile
def fn(x):
    x = x + 1
    with torch._dynamo.error_on_graph_break(False):
        torch._dynamo.graph_break()  # no error
    return x + 2

# No error, but there is a graph break
fn(torch.randn(3))
```
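
Being usable both as a decorator and as a context manager is a standard Python pattern. The sketch below shows one way to build such an object with `contextlib.ContextDecorator` and a `contextvars.ContextVar`. `toy_error_on_graph_break` and `current_setting` are hypothetical names, and this is not a claim about how `torch._dynamo` implements its version:

```python
import contextlib
import contextvars

# Hypothetical stand-in for the error_on_graph_break setting (not torch internals).
_error_on_graph_break = contextvars.ContextVar("error_on_graph_break", default=False)


class toy_error_on_graph_break(contextlib.ContextDecorator):
    """Sets the flag for the duration of a `with` block or a decorated call."""

    def __init__(self, value: bool):
        self.value = value
        self._tokens = []  # one token per active entry, so re-entry is safe

    def __enter__(self):
        # Remember the previous value so that __exit__ can restore it.
        self._tokens.append(_error_on_graph_break.set(self.value))
        return self

    def __exit__(self, *exc):
        _error_on_graph_break.reset(self._tokens.pop())
        return False  # never swallow exceptions


def current_setting() -> bool:
    return _error_on_graph_break.get()


@toy_error_on_graph_break(True)
def decorated():
    # Inside the decorated call the flag is True...
    with toy_error_on_graph_break(False):
        # ...and a nested `with` block can switch it back to False.
        inner_value = current_setting()
    return current_setting(), inner_value


print(decorated())        # (True, False)
print(current_setting())  # False: the default is restored after the call
```

`ContextDecorator` re-enters the same object on every call of the decorated function, which is why this sketch keeps a stack of tokens rather than a single saved value.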
You can use monkey patching to toggle `error_on_graph_break` for code where you cannot edit the source (e.g. framework code):
```{code-cell}
class ThirdPartyModule(torch.nn.Module):
    def forward(self, x):
        x = x + 1
        torch._dynamo.graph_break()
        return x + 2

tp_mod = ThirdPartyModule()
tp_mod.forward = torch._dynamo.error_on_graph_break(False)(tp_mod.forward)

@torch._dynamo.error_on_graph_break(True)
@torch.compile
def fn(x):
    return tp_mod.forward(x)

# No error, but there is a graph break
fn(torch.randn(3))
```
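
Monkey patching an instance, as above, only affects that one object. The plain-Python sketch below contrasts it with patching the class. `toy_allow_graph_breaks` is a hypothetical decorator that merely tags its wrapper, standing in for `torch._dynamo.error_on_graph_break(False)`, and `ThirdPartyModule` here is a plain class used only for illustration:

```python
import functools


def toy_allow_graph_breaks(fn):
    """Hypothetical stand-in for torch._dynamo.error_on_graph_break(False).

    It only tags the wrapper so that the patch is observable.
    """
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)

    wrapper.allows_graph_breaks = True
    return wrapper


def is_patched(method) -> bool:
    # Bound methods forward attribute lookups to the underlying function.
    return getattr(method, "allows_graph_breaks", False)


class ThirdPartyModule:
    def forward(self, x):
        return x + 1


patched = ThirdPartyModule()
untouched = ThirdPartyModule()

# Instance-level patch: wraps the bound method and stores the wrapper on this
# one object, shadowing the class attribute for it only.
patched.forward = toy_allow_graph_breaks(patched.forward)
instance_patch = (is_patched(patched.forward), is_patched(untouched.forward))
print(instance_patch)  # (True, False)

# Class-level patch: every instance sees the wrapper, including new ones.
ThirdPartyModule.forward = toy_allow_graph_breaks(ThirdPartyModule.forward)
class_patch = (is_patched(untouched.forward), is_patched(ThirdPartyModule().forward))
print(class_patch)  # (True, True)
print(untouched.forward(1))  # 2: behavior is unchanged, only the tag differs
```

The instance-level patch is the narrower choice when only one module object should be exempted; patching the class exempts every instance.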
## `error_on_graph_break(True)` example
```{code-cell}
@torch._dynamo.error_on_graph_break(True)
def inner2(x):
    x = x + 1
    torch._dynamo.graph_break()  # error
    return x + 2

def inner(x):
    return inner2(x)

# fullgraph=False, error_on_graph_break=False
@torch.compile
def fn(x):
    x = x + 4
    torch._dynamo.graph_break()  # no error
    return inner(x)

try:
    fn(torch.randn(3))
except Exception as e:
    print(e)
```

Using `error_on_graph_break(True)` under `error_on_graph_break(False)` is helpful when we want to use `torch.compile` flexibly (i.e. follow the `fullgraph=False` programming model),
but there are some sections of the code that are performance-critical and we want to ensure that those sections do not contain graph breaks.

## `error_on_graph_break` nesting behavior
`torch._dynamo.error_on_graph_break()` affects the `error_on_graph_break` setting of nested calls as well:
```{code-cell}
def inner(x):
    x = x + 1
    torch._dynamo.graph_break()
    return x + 2

def inner2(x):
    with torch._dynamo.error_on_graph_break(False):
        return inner(x)

@torch._dynamo.error_on_graph_break(True)
@torch.compile
def fn(x):
    return inner2(x)

# no error
fn(torch.randn(3))
```
`torch._dynamo.error_on_graph_break()` can be used under another `torch._dynamo.error_on_graph_break()` region:
```{code-cell}
def inner(x):
    x = x + 1
    with torch._dynamo.error_on_graph_break(False):
        torch._dynamo.graph_break()
    return x + 2

def inner2(x):
    with torch._dynamo.error_on_graph_break(True):
        return inner(x)

@torch.compile
def fn(x):
    return inner2(x)

# no error
fn(torch.randn(3))
```
## Interaction with `fullgraph`
`fullgraph=True` takes higher precedence than `error_on_graph_break`:
```{code-cell}
@torch._dynamo.error_on_graph_break(False)
def inner(x):
    x = x + 1
    torch._dynamo.graph_break()
    return x + 2

@torch.compile(fullgraph=True)
def fn(x):
    return inner(x)

try:
    fn(torch.randn(3))
except Exception as e:
    print(e)
```
`fullgraph=True` cannot be toggled back to `fullgraph=False`:
```{code-cell}
@torch.compile(fullgraph=False)
def inner(x):
    x = x + 1
    torch._dynamo.graph_break()
    return x + 2

@torch.compile(fullgraph=True)
def fn(x):
    return inner(x)

try:
    fn(torch.randn(3))
except Exception as e:
    print(e)
```
```{code-cell}
@torch.compile(fullgraph=True)
def inner(x):
    x = x + 1
    torch._dynamo.graph_break()
    return x + 2

@torch.compile(fullgraph=False)
def fn(x):
    return inner(x)

try:
    fn(torch.randn(3))
except Exception as e:
    print(e)
```
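
The two examples above can be modeled by treating `fullgraph` as sticky: once any enclosing region asks for `fullgraph=True`, nested code stays in that mode no matter what it requests. The sketch below uses hypothetical names (`compile_region`, `effective_fullgraph`) and only restates the rule these examples illustrate; it is not how `torch.compile` is implemented:

```python
from contextlib import contextmanager

# Hypothetical model (not torch internals): fullgraph is "sticky".
_fullgraph_stack = [False]


@contextmanager
def compile_region(fullgraph: bool):
    _fullgraph_stack.append(fullgraph)
    try:
        yield
    finally:
        _fullgraph_stack.pop()


def effective_fullgraph() -> bool:
    # True if ANY enclosing region requested fullgraph=True.
    return any(_fullgraph_stack)


# Outer fullgraph=True, inner fullgraph=False (like the first example).
with compile_region(fullgraph=True):
    with compile_region(fullgraph=False):
        outer_true_inner_false = effective_fullgraph()

# Outer fullgraph=False, inner fullgraph=True (like the second example).
with compile_region(fullgraph=False):
    with compile_region(fullgraph=True):
        outer_false_inner_true = effective_fullgraph()

print(outer_true_inner_false, outer_false_inner_true)  # True True
```

Contrast this with `error_on_graph_break`, where the innermost region wins and a nested region can switch the setting in either direction.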
## Summary of `fullgraph=True/False` vs `error_on_graph_break`
Here is a table summarizing the differences between `fullgraph=True/False` and `error_on_graph_break`:

| | `error_on_graph_break=True` | `error_on_graph_break=False` (default) |
| --- | --- | --- |
| `fullgraph=True` | Graph breaks result in errors. Only the first graph break will be reported. **Single-graph guarantee.**<br><br>`fullgraph` cannot be toggled to `False`. `error_on_graph_break` has no effect.<br><br>User code must be fully compatible with `torch.compile`. Guarantees no performance hits from graph breaks (because there are no graph breaks).<br><br>Ideal for code sensitive to graph breaks: framework/library code or cases where getting maximum performance is required. Prevents downstream user code from inadvertently allowing graph breaks. | Same as `fullgraph=True` with `error_on_graph_break=True`, because `error_on_graph_break` has no effect when `fullgraph=True`. |
| `fullgraph=False` (default) | Graph breaks result in errors. Only the first graph break will be reported. **No single-graph guarantee.**<br><br>`error_on_graph_break` can be toggled to `False`.<br><br>User code must be fully compatible with `torch.compile`. Guarantees no performance hits from graph breaks (because there are no graph breaks).<br><br>Ideal for user code sensitive to graph breaks. `error_on_graph_break` can be toggled to `False` to deal with sections that have graph breaks that are difficult to work around. | Will continue to compile after encountering graph breaks. All graph breaks will be reported.<br><br>`error_on_graph_break` can be toggled to `True`.<br><br>Requires few user code changes to work. Performance may be negatively impacted due to graph breaks.<br><br>Ideal for out-of-the-box use cases, on “non-weird” code, or where squeezing maximal performance is not necessary. |