File: torch.compiler_custom_backends.rst

Custom Backends
===============

Overview
--------

``torch.compile`` provides a straightforward way for users to define
custom backends.

A backend function has the contract
``(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]) -> Callable``.

Backend functions are called by TorchDynamo, the graph tracing component of ``torch.compile``,
after it traces an FX graph, and they are
expected to return a compiled function that is equivalent to the traced FX graph.
The returned callable should have the same contract as the ``forward`` function of the original ``torch.fx.GraphModule``
passed into the backend:
``(*args: torch.Tensor) -> List[torch.Tensor]``.

In order for TorchDynamo to call your backend, pass your backend function as the ``backend`` kwarg in
``torch.compile``. For example,

.. code-block:: python

    import torch

    def my_custom_backend(gm, example_inputs):
        return gm.forward

    def f(...):
        ...

    f_opt = torch.compile(f, backend=my_custom_backend)

    @torch.compile(backend=my_custom_backend)
    def g(...):
        ...

See below for more examples.

Registering Custom Backends
---------------------------

You can register your backend using the ``register_backend`` decorator, for example,

.. code-block:: python

    from torch._dynamo import register_backend

    @register_backend
    def my_compiler(gm, example_inputs):
        ...

In addition to the ``register_backend`` decorator, if your backend lives in another Python package, you can also register
it through that package's entry points, which give one package a way to register a plugin for another.

.. hint::

    You can learn more about ``entry_points`` in the
    `python packaging documentation <https://setuptools.pypa.io/en/latest/userguide/entry_point.html>`__.

To register your backend through ``entry_points``, add your backend function to the ``torch_dynamo_backends`` entry point
group in your package's ``setup.py`` file:

.. code-block:: python

    ...
    setup(
        ...
        entry_points={
            'torch_dynamo_backends': [
                'my_compiler = your_module.submodule:my_compiler',
            ]
        },
        ...
    )

Replace ``my_compiler`` before the ``=`` with your backend's name, and replace the part after the ``=`` with
the module and function name of your backend function.
The entry point is added to your Python environment when the package is installed.
When you call ``torch.compile(model, backend="my_compiler")``, PyTorch first searches for a backend named ``my_compiler``
registered with ``register_backend``. If it is not found, it continues searching all backends registered
via ``entry_points``.
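
Once the package is installed, you can check that the entry point is visible (a minimal
sketch using the standard library's ``importlib.metadata``; the ``group`` keyword argument
requires Python 3.10+):

.. code-block:: python

    from importlib.metadata import entry_points

    # List every backend exposed through the torch_dynamo_backends group.
    for ep in entry_points(group="torch_dynamo_backends"):
        print(ep.name, "->", ep.value)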

Registration serves two purposes:

* You can pass a string containing your backend function's name to ``torch.compile`` instead of the function itself,
  for example, ``torch.compile(model, backend="my_compiler")``, as shown in the sketch after this list.
* It is required for use with the :ref:`minifier <torch.compiler_troubleshooting_old>`. Any generated
  code from the minifier must call your code that registers your backend function, typically through an ``import`` statement.
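
For example, the following sketch registers a trivial backend and then selects it by name
(``my_compiler`` here is just the illustrative backend from earlier):

.. code-block:: python

    import torch
    from torch._dynamo import register_backend

    @register_backend
    def my_compiler(gm, example_inputs):
        return gm.forward  # run the traced graph as-is

    @torch.compile(backend="my_compiler")  # looked up by its registered name
    def fn(x):
        return x.sin() + x.cos()

    fn(torch.randn(4))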

Custom Backends after AOTAutograd
---------------------------------

It is possible to define custom backends that are called by AOTAutograd rather than TorchDynamo.
This is useful for two main reasons:

* Users can define backends that support model training, as AOTAutograd can generate the backward graph for compilation.
* AOTAutograd produces FX graphs consisting of `core Aten ops <https://pytorch.org/docs/main/torch.compiler_ir.html#core-aten-ir>`__. As a result,
  custom backends only need to support the core Aten opset, which is a significantly smaller opset than the entire torch/Aten opset.

Wrap your backend with
``torch._dynamo.backends.common.aot_autograd`` and use ``torch.compile`` with the ``backend`` kwarg as before.
Backend functions wrapped by ``aot_autograd`` should have the same contract as before.

Backend functions are passed to ``aot_autograd`` through the ``fw_compiler`` (forward compiler)
or ``bw_compiler`` (backward compiler) kwargs. If ``bw_compiler`` is not specified, the backward compile function
defaults to the forward compile function.

One caveat is that AOTAutograd requires compiled functions returned by backends to be "boxed". This can be done by wrapping
the compiled function with ``functorch.compile.make_boxed_func``.
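
Here, "boxed" means that the compiled function receives its inputs as a single list rather
than as positional arguments, which allows AOTAutograd to free inputs as soon as they are no
longer needed. Conceptually, ``make_boxed_func`` is roughly the following wrapper (a sketch
for intuition, not the exact library source):

.. code-block:: python

    def make_boxed_func_sketch(f):
        # Accept a single list of tensors and unpack it for the wrapped function.
        def g(args):
            return f(*args)

        # Flag that AOTAutograd checks to detect the boxed calling convention.
        g._boxed_call = True
        return g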

For example,

.. code-block:: python

    from torch._dynamo.backends.common import aot_autograd
    from functorch.compile import make_boxed_func

    def my_compiler(gm, example_inputs):
        return make_boxed_func(gm.forward)

    my_backend = aot_autograd(fw_compiler=my_compiler)  # bw_compiler=my_compiler

    model_opt = torch.compile(model, backend=my_backend)
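
Because the wrapped backend is invoked for both the forward and the backward graph, it can be
used for training. A minimal sketch, reusing the ``my_backend`` defined above:

.. code-block:: python

    import torch

    model = torch.nn.Linear(10, 10)
    model_opt = torch.compile(model, backend=my_backend)

    out = model_opt(torch.randn(4, 10))
    # The backward graph is also compiled through my_compiler,
    # since bw_compiler defaults to fw_compiler.
    out.sum().backward()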

Examples
--------

Debugging Backend
^^^^^^^^^^^^^^^^^

If you want to better understand what is going on during a
compilation, you can create a custom compiler, referred to as a
backend in this section, that pretty prints the FX
``GraphModule`` extracted from Dynamo's bytecode analysis
and returns a ``forward()`` callable.

For example:

.. code-block:: python

    from typing import List

    import torch

    def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        print("my_compiler() called with FX graph:")
        gm.graph.print_tabular()
        return gm.forward  # return a python callable

    @torch.compile(backend=my_compiler)
    def fn(x, y):
        a = torch.cos(x)
        b = torch.sin(y)
        return a + b

    fn(torch.randn(10), torch.randn(10))

Running the above example produces the following output:

::

    my_compiler() called with FX graph:
    opcode         name    target                                                  args        kwargs
    -------------  ------  ------------------------------------------------------  ----------  --------
    placeholder    x       x                                                       ()          {}
    placeholder    y       y                                                       ()          {}
    call_function  cos     <built-in method cos of type object at 0x7f1a894649a8>  (x,)        {}
    call_function  sin     <built-in method sin of type object at 0x7f1a894649a8>  (y,)        {}
    call_function  add     <built-in function add>                                 (cos, sin)  {}
    output         output  output                                                  ((add,),)   {}

This works for ``torch.nn.Module`` as well, as shown below:

.. code-block:: python

    from typing import List

    import torch

    def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        print("my_compiler() called with FX graph:")
        gm.graph.print_tabular()
        return gm.forward  # return a python callable

    class MockModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            return self.relu(torch.cos(x))

    mod = MockModule()
    optimized_mod = torch.compile(mod, backend=my_compiler)
    optimized_mod(torch.randn(10))

Let’s take a look at one more example with control flow:

.. code-block:: python

    from typing import List

    import torch

    def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        print("my_compiler() called with FX graph:")
        gm.graph.print_tabular()
        return gm.forward  # return a python callable

    @torch.compile(backend=my_compiler)
    def toy_example(a, b):
        x = a / (torch.abs(a) + 1)
        if b.sum() < 0:
            b = b * -1
        return x * b

    for _ in range(100):
        toy_example(torch.randn(10), torch.randn(10))

Running this example produces the following output:

::

    my_compiler() called with FX graph:
    opcode         name     target                                                  args              kwargs
    -------------  -------  ------------------------------------------------------  ----------------  --------
    placeholder    a        a                                                       ()                {}
    placeholder    b        b                                                       ()                {}
    call_function  abs_1    <built-in method abs of type object at 0x7f8d259298a0>  (a,)              {}
    call_function  add      <built-in function add>                                 (abs_1, 1)        {}
    call_function  truediv  <built-in function truediv>                             (a, add)          {}
    call_method    sum_1    sum                                                     (b,)              {}
    call_function  lt       <built-in function lt>                                  (sum_1, 0)        {}
    output         output   output                                                  ((truediv, lt),)  {}

    my_compiler() called with FX graph:
    opcode         name    target                   args         kwargs
    -------------  ------  -----------------------  -----------  --------
    placeholder    b       b                        ()           {}
    placeholder    x       x                        ()           {}
    call_function  mul     <built-in function mul>  (b, -1)      {}
    call_function  mul_1   <built-in function mul>  (x, mul)     {}
    output         output  output                   ((mul_1,),)  {}

    my_compiler() called with FX graph:
    opcode         name    target                   args       kwargs
    -------------  ------  -----------------------  ---------  --------
    placeholder    b       b                        ()         {}
    placeholder    x       x                        ()         {}
    call_function  mul     <built-in function mul>  (x, b)     {}
    output         output  output                   ((mul,),)  {}

The order of the last two graphs is nondeterministic, depending
on which branch is encountered first by the just-in-time compiler.
Dynamo breaks the graph at the data-dependent ``if b.sum() < 0`` check:
the code before the branch is compiled as one graph, and each branch is
compiled as its own graph the first time it is taken.
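
If you want to see why Dynamo split the function, ``torch._dynamo.explain`` reports the
captured graphs and the break reasons (a minimal sketch; the exact report format varies
between PyTorch versions):

.. code-block:: python

    import torch

    def toy_example_plain(a, b):
        x = a / (torch.abs(a) + 1)
        if b.sum() < 0:
            b = b * -1
        return x * b

    # The printed report includes the graph count and each graph break reason.
    explanation = torch._dynamo.explain(toy_example_plain)(torch.randn(10), torch.randn(10))
    print(explanation)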

Speedy Backend
^^^^^^^^^^^^^^

Integrating a custom backend that offers better performance is also
easy. As a real example, we'll integrate
`optimize_for_inference <https://pytorch.org/docs/stable/generated/torch.jit.optimize_for_inference.html>`__:

.. code-block:: python

    from typing import List

    import torch

    def optimize_for_inference_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        # Script the traced graph, then apply TorchScript's inference-only optimizations.
        scripted = torch.jit.script(gm)
        return torch.jit.optimize_for_inference(scripted)

And then you should be able to optimize any existing code with:

.. code-block:: python

    @torch.compile(backend=optimize_for_inference_compiler)
    def code_to_accelerate():
        ...
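
For instance, a minimal sketch (this assumes the traced graph is scriptable, which holds for
simple tensor programs but not necessarily for every model):

.. code-block:: python

    import torch

    @torch.compile(backend=optimize_for_inference_compiler)
    def fn(x):
        return torch.sin(x) + torch.cos(x)

    with torch.no_grad():
        fn(torch.randn(100))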

Composable Backends
^^^^^^^^^^^^^^^^^^^

TorchDynamo includes many backends, which can be listed with
``torch._dynamo.list_backends()``. You can combine these backends
with the following code:

.. code-block:: python

    from typing import List

    import torch
    from torch._dynamo import lookup_backend

    def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        try:
            trt_compiled = lookup_backend("tensorrt")(gm, example_inputs)
            if trt_compiled is not None:
                return trt_compiled
        except Exception:
            pass
        # first backend failed, try something else...
        try:
            inductor_compiled = lookup_backend("inductor")(gm, example_inputs)
            if inductor_compiled is not None:
                return inductor_compiled
        except Exception:
            pass
        # fall back to eager execution of the traced graph
        return gm.forward
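
A usage sketch: if ``torch-tensorrt`` is not installed, the first lookup raises, the compiler
falls through to inductor, and finally to eager execution of the traced graph:

.. code-block:: python

    @torch.compile(backend=my_compiler)
    def fn(x):
        return torch.relu(x) + 1

    fn(torch.randn(8))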