---
file_format: mystnb
kernelspec:
  name: python3
mystnb:
  execution_timeout: 30
  execution_show_tb: True
  merge_streams: True
---
```{code-cell}
:tags: [remove-cell]
import torch
import header_code
```

# Use `torch._dynamo.nonstrict_trace`

**Summary:**
- Use `nonstrict_trace` to trace a function with non-strict tracing inside of a `torch.compile`'d region.
  You may wish to do this because Dynamo hits a graph break on something inside of the function
  and you are sure that the function is non-strict traceable.

Consider the following scenario:

```{code-cell}
def get_magic_num():
    # This explicit graph break call is meant to emulate any kind of Dynamo
    # graph break, e.g., the function is implemented in C, or uses some Python
    # language feature Dynamo doesn't yet support.
    torch._dynamo.graph_break()
    return torch.tensor([42])

@torch.compile(fullgraph=True)
def func(x):
    n = get_magic_num()
    return x + n

try:
    func(torch.rand(10))
except Exception as e:
    print(e)
```

If we run the code above, we'll get an error from Dynamo, because it sees a graph break while the user specified `fullgraph=True`.

In these situations, if a user still wants to keep `fullgraph=True`, they typically have several options:

1. The graph break is due to a language feature Dynamo doesn't yet support.
   In this case, the user either rewrites their code, or files an issue on GitHub.
2. The graph break is due to a call to a function implemented in C.
   In this case, the user can try to use a custom op.
   The user could also try providing a polyfill (a reference implementation in Python)
   so that Dynamo can trace through it.
3. Worst case scenario: an internal compiler error. In this case, the user likely has to file an issue on GitHub.

In addition to these options, PyTorch provides an alternative, `torch._dynamo.nonstrict_trace`, which applies if the function call that induced the graph break satisfies the following requirements:

- The requirements of [general non-strict tracing](programming_model.non_strict_tracing_model).
- The inputs and outputs must contain either basic types (e.g., `int`, `float`, `list`, `dict`, `torch.Tensor`),
  or user-defined types that are registered to `torch.utils._pytree`.
- The function must be defined outside the `torch.compile`'d region.
- Any non-input values read by the function will be treated as constants
  (e.g., a global tensor), and will not be guarded on.

When tracing through a call to a `torch._dynamo.nonstrict_trace`'d function, `torch.compile` switches to [non-strict tracing](programming_model.non_strict_tracing_model),
and the FX graph will eventually contain all the relevant tensor operations which happened inside that function.

For the `get_magic_num` example above, we can use `torch._dynamo.nonstrict_trace` to eliminate the graph break:

```{code-cell}
@torch._dynamo.nonstrict_trace
def get_magic_num():
    # This explicit graph break call is meant to emulate any kind of Dynamo
    # graph break, e.g., the function is implemented in C, or uses some Python
    # language feature Dynamo doesn't yet support.
    torch._dynamo.graph_break()
    return torch.tensor([42])

@torch.compile(fullgraph=True)
def func(x):
    n = get_magic_num()
    return x + n

print(func(torch.rand(10)))
# No graph break and no error.
```

Note that `torch._dynamo.nonstrict_trace` can also be applied at the call site, inside a `torch.compile`'d region:

```{code-cell}
def get_magic_num():
    # This explicit graph break call is meant to emulate any kind of Dynamo
    # graph break, e.g., the function is implemented in C, or uses some Python
    # language feature Dynamo doesn't yet support.
    torch._dynamo.graph_break()
    return torch.tensor([42])

@torch.compile(fullgraph=True)
def func(x):
    n = torch._dynamo.nonstrict_trace(get_magic_num)()
    return x + n

print(func(torch.rand(10)))
# No graph break and no error.
```