torch.onnx
==========

.. contents:: :local:

.. automodule:: torch.onnx

`Open Neural Network eXchange (ONNX) <https://onnx.ai/>`_ is an open standard
format for representing machine learning models. The torch.onnx module can export
PyTorch models to ONNX. The model can then be consumed by any of the many
`runtimes that support ONNX <https://onnx.ai/supported-tools.html#deployModel>`_.

Example: AlexNet from PyTorch to ONNX
-------------------------------------

Here is a simple script which exports a pretrained AlexNet to an ONNX file named ``alexnet.onnx``.
The call to ``torch.onnx.export`` runs the model once to trace its execution and then exports the
traced model to the specified file::

    import torch
    import torchvision

    dummy_input = torch.randn(10, 3, 224, 224, device="cuda")
    model = torchvision.models.alexnet(pretrained=True).cuda()

    # Providing input and output names sets the display names for values
    # within the model's graph. Setting these does not change the semantics
    # of the graph; it is only for readability.
    #
    # The inputs to the network consist of the flat list of inputs (i.e.
    # the values you would pass to the forward() method) followed by the
    # flat list of parameters. You can partially specify names, i.e. provide
    # a list here shorter than the number of inputs to the model, and we will
    # only set that subset of names, starting from the beginning.
    input_names = ["actual_input_1"] + ["learned_%d" % i for i in range(16)]
    output_names = ["output1"]

    torch.onnx.export(
        model,
        dummy_input,
        "alexnet.onnx",
        verbose=True,
        input_names=input_names,
        output_names=output_names,
    )

The resulting ``alexnet.onnx`` file contains a binary `protocol buffer <https://developers.google.com/protocol-buffers/>`_
holding both the network structure and the parameters of the model you exported
(in this case, AlexNet). The argument ``verbose=True`` causes the
exporter to print out a human-readable representation of the model::

    # These are the inputs and parameters to the network, which have taken on
    # the names we specified earlier.
    graph(%actual_input_1 : Float(10, 3, 224, 224)
          %learned_0 : Float(64, 3, 11, 11)
          %learned_1 : Float(64)
          %learned_2 : Float(192, 64, 5, 5)
          %learned_3 : Float(192)
          # ---- omitted for brevity ----
          %learned_14 : Float(1000, 4096)
          %learned_15 : Float(1000)) {
      # Every statement consists of some output tensors (and their types),
      # the operator to be run (with its attributes, e.g., kernels, strides,
      # etc.), its input tensors (%actual_input_1, %learned_0, %learned_1)
      %17 : Float(10, 64, 55, 55) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[11, 11], pads=[2, 2, 2, 2], strides=[4, 4]](%actual_input_1, %learned_0, %learned_1), scope: AlexNet/Sequential[features]/Conv2d[0]
      %18 : Float(10, 64, 55, 55) = onnx::Relu(%17), scope: AlexNet/Sequential[features]/ReLU[1]
      %19 : Float(10, 64, 27, 27) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%18), scope: AlexNet/Sequential[features]/MaxPool2d[2]
      # ---- omitted for brevity ----
      %29 : Float(10, 256, 6, 6) = onnx::MaxPool[kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[2, 2]](%28), scope: AlexNet/Sequential[features]/MaxPool2d[12]
      # Dynamic means that the shape is not known. This may be because of a
      # limitation of our implementation (which we would like to fix in a
      # future release) or shapes which are truly dynamic.
      %30 : Dynamic = onnx::Shape(%29), scope: AlexNet
      %31 : Dynamic = onnx::Slice[axes=[0], ends=[1], starts=[0]](%30), scope: AlexNet
      %32 : Long() = onnx::Squeeze[axes=[0]](%31), scope: AlexNet
      %33 : Long() = onnx::Constant[value={9216}](), scope: AlexNet
      # ---- omitted for brevity ----
      %output1 : Float(10, 1000) = onnx::Gemm[alpha=1, beta=1, broadcast=1, transB=1](%45, %learned_14, %learned_15), scope: AlexNet/Sequential[classifier]/Linear[6]
      return (%output1);
    }

You can also verify the exported model using the `ONNX <https://github.com/onnx/onnx/>`_ library,
which you can install using ``pip``::

    pip install onnx

Then, you can run::

    import onnx

    # Load the ONNX model
    model = onnx.load("alexnet.onnx")

    # Check that the model is well formed
    onnx.checker.check_model(model)

    # Print a human readable representation of the graph
    print(onnx.helper.printable_graph(model.graph))

You can also run the exported model with one of the many
`runtimes that support ONNX <https://onnx.ai/supported-tools.html#deployModel>`_.
For example, after installing `ONNX Runtime <https://www.onnxruntime.ai>`_, you can
load and run the model::

    import numpy as np
    import onnxruntime as ort

    ort_session = ort.InferenceSession("alexnet.onnx")

    outputs = ort_session.run(
        None,
        {"actual_input_1": np.random.randn(10, 3, 224, 224).astype(np.float32)},
    )
    print(outputs[0])

Here is a more involved `tutorial on exporting a model and running it with ONNX Runtime <https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html>`_.

.. _tracing-vs-scripting:

Tracing vs Scripting
--------------------

Internally, :func:`torch.onnx.export()` requires a :class:`torch.jit.ScriptModule` rather than
a :class:`torch.nn.Module`. If the passed-in model is not already a ``ScriptModule``,
``export()`` will use *tracing* to convert it to one:

.. TODO(justinchuby): Add a word on recommending tracing over scripting for most use cases.

* **Tracing**: If ``torch.onnx.export()`` is called with a Module that is not already a
  ``ScriptModule``, it first does the equivalent of :func:`torch.jit.trace`, which executes the model
  once with the given ``args`` and records all operations that happen during that execution. This
  means that if your model is dynamic, e.g., changes behavior depending on input data, the exported
  model will *not* capture this dynamic behavior.
  We recommend examining the exported model and making sure the operators look
  reasonable. Tracing will unroll loops and if statements, exporting a static graph that is exactly
  the same as the traced run. If you want to export your model with dynamic control flow, you will
  need to use *scripting*.

* **Scripting**: Compiling a model via scripting preserves dynamic control flow and is valid for inputs
  of different sizes. To use scripting (see the sketch after this list):

  * Use :func:`torch.jit.script` to produce a ``ScriptModule``.
  * Call ``torch.onnx.export()`` with the ``ScriptModule`` as the model. The ``args`` are still required,
    but they will be used internally only to produce example outputs, so that the types and shapes of the
    outputs can be captured. No tracing will be performed.
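
A minimal sketch of this workflow (the module and shapes are illustrative)::

    import torch

    class DynamicModel(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Data-dependent control flow: tracing would freeze one branch,
            # while scripting preserves both.
            if x.sum() > 0:
                return x * 2
            return x - 1

    scripted = torch.jit.script(DynamicModel())
    # args are used only to produce example outputs; no tracing is performed.
    torch.onnx.export(scripted, (torch.randn(2, 3),), "dynamic.onnx")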

See `Introduction to TorchScript <https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html>`_
and `TorchScript <jit.html>`_ for more details, including how to compose tracing and scripting to suit the
particular requirements of different models.


Avoiding Pitfalls
-----------------

Avoid NumPy and built-in Python types
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

PyTorch models can be written using NumPy or Python types and functions, but
during :ref:`tracing<tracing-vs-scripting>`, any variables of NumPy or Python
types (rather than torch.Tensor) are converted to constants, which will produce
the wrong result if those values should change depending on the inputs.

For example, rather than using numpy functions on numpy.ndarrays: ::

    # Bad! Will be replaced with constants during tracing.
    x, y = np.random.rand(1, 2), np.random.rand(1, 2)
    np.concatenate((x, y), axis=1)

Use torch operators on torch.Tensors: ::

    # Good! Tensor operations will be captured during tracing.
    x, y = torch.randn(1, 2), torch.randn(1, 2)
    torch.cat((x, y), dim=1)


And rather than use :func:`torch.Tensor.item` (which converts a Tensor to a Python
built-in number): ::

    # Bad! y.item() will be replaced with a constant during tracing.
    def forward(self, x, y):
        return x.reshape(y.item(), -1)

Use torch's support for implicit casting of single-element tensors: ::

    # Good! y will be preserved as a variable during tracing.
    def forward(self, x, y):
        return x.reshape(y, -1)

Avoid Tensor.data
^^^^^^^^^^^^^^^^^

Using the ``Tensor.data`` field can produce an incorrect trace and therefore an incorrect ONNX graph.
Use :func:`torch.Tensor.detach` instead. (Work is ongoing to
`remove Tensor.data entirely <https://github.com/pytorch/pytorch/issues/30987>`_).
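
For example, rather than::

    # Bad! Accessing Tensor.data may record an incorrect trace.
    y = x.data

use::

    # Good! detach() is captured correctly during tracing.
    y = x.detach()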

Avoid in-place operations when using tensor.shape in tracing mode
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In tracing mode, shapes obtained from ``tensor.shape`` are traced as tensors,
and share the same memory. This might cause a mismatch in the final output values.
As a workaround, avoid the use of in-place operations in these scenarios.
For example, in the model::

    class Model(torch.nn.Module):
        def forward(self, states):
            batch_size, seq_length = states.shape[:2]
            real_seq_length = seq_length
            real_seq_length += 2
            return real_seq_length + seq_length

``real_seq_length`` and ``seq_length`` share the same memory in tracing mode.
This can be avoided by rewriting the in-place operation::

    real_seq_length = real_seq_length + 2
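
Putting it together, a fixed version of the model above could read::

    class Model(torch.nn.Module):
        def forward(self, states):
            batch_size, seq_length = states.shape[:2]
            # Out-of-place addition creates a new traced value instead of
            # mutating the value that seq_length also refers to.
            real_seq_length = seq_length + 2
            return real_seq_length + seq_length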

Limitations
-----------

Types
^^^^^

* Only :class:`torch.Tensor` objects, numeric types that can be trivially converted to torch.Tensors
  (e.g. float, int), and tuples and lists of those types are supported as model inputs or outputs.
  Dict and str inputs and outputs are accepted in :ref:`tracing<tracing-vs-scripting>` mode, but:

  * Any computation that depends on the value of a dict or a str input **will be replaced with the
    constant value** seen during the one traced execution.
  * Any output that is a dict will be silently replaced with a **flattened sequence of its values
    (keys will be removed)**. E.g. ``{"foo": 1, "bar": 2}`` becomes ``(1, 2)``. See the sketch
    after this list.
  * Any output that is a str will be silently removed.

* Certain operations involving tuples and lists are not supported in
  :ref:`scripting<tracing-vs-scripting>` mode due to limited support in ONNX for nested sequences.
  In particular, appending a tuple to a list is not supported. In tracing mode, nested sequences
  are flattened automatically during tracing.
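
As a sketch of the dict-flattening behavior (the module and file name are illustrative)::

    class DictModel(torch.nn.Module):
        def forward(self, x):
            return {"foo": x + 1, "bar": x * 2}

    # In tracing mode, the dict output is flattened to a tuple of its values;
    # the keys "foo" and "bar" do not survive in the exported graph.
    torch.onnx.export(DictModel(), (torch.randn(2),), "dict_model.onnx")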

Differences in Operator Implementations
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Due to differences in implementations of operators, running the exported model on different runtimes
may produce different results from each other or from PyTorch. Normally these differences are
numerically small, so this should only be a concern if your application is sensitive to these
small differences.

.. _tensor-indexing:

Unsupported Tensor Indexing Patterns
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Tensor indexing patterns that cannot be exported are listed below.
If you are experiencing issues exporting a model that does not include any of
the unsupported patterns below, please double check that you are exporting with
the latest ``opset_version``.

Reads / Gets
~~~~~~~~~~~~

When indexing into a tensor for reading, the following patterns are not supported: ::

  # Tensor indices that include negative values.
  data[torch.tensor([[1, 2], [2, -3]]), torch.tensor([-2, 3])]
  # Workarounds: use positive index values.

Writes / Sets
~~~~~~~~~~~~~

When indexing into a Tensor for writing, the following patterns are not supported: ::

  # Multiple tensor indices if any has rank >= 2
  data[torch.tensor([[1, 2], [2, 3]]), torch.tensor([2, 3])] = new_data
  # Workarounds: use single tensor index with rank >= 2,
  #              or multiple consecutive tensor indices with rank == 1.

  # Multiple tensor indices that are not consecutive
  data[torch.tensor([2, 3]), :, torch.tensor([1, 2])] = new_data
  # Workarounds: transpose `data` such that tensor indices are consecutive.

  # Tensor indices that include negative values.
  data[torch.tensor([1, -2]), torch.tensor([-2, 3])] = new_data
  # Workarounds: use positive index values.

  # Implicit broadcasting required for new_data.
  data[torch.tensor([[0, 2], [1, 1]]), 1:3] = new_data
  # Workarounds: expand new_data explicitly.
  # Example:
  #   data shape: [3, 4, 5]
  #   new_data shape: [5]
  #   expected new_data shape after broadcasting: [2, 2, 2, 5]
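
A sketch of the explicit-expansion workaround, using the illustrative shapes from
the comment above::

  data = torch.zeros(3, 4, 5)
  new_data = torch.arange(5.0)
  index = torch.tensor([[0, 2], [1, 1]])
  # data[index, 1:3] selects a region of shape [2, 2, 2, 5], so expand
  # new_data to that shape rather than relying on implicit broadcasting.
  data[index, 1:3] = new_data.expand(2, 2, 2, 5)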

Adding support for operators
----------------------------

When exporting a model that includes unsupported operators, you'll see an error message like:

.. code-block:: text

    RuntimeError: ONNX export failed: Couldn't export operator foo

When that happens, there are a few things you can do:

#. Change the model to not use that operator.
#. Create a symbolic function to convert the operator and register it as a custom symbolic function.
#. Contribute to PyTorch to add the same symbolic function to :mod:`torch.onnx` itself.

If you decide to implement a symbolic function (we hope you will contribute it back to PyTorch!), here is how you can get started:

ONNX exporter internals
^^^^^^^^^^^^^^^^^^^^^^^

A "symbolic function" is a function that decomposes a PyTorch operator into a
composition of a series of ONNX operators.

During export, each node (which contains a PyTorch operator) in the TorchScript
graph is visited by the exporter in topological order.
Upon visiting a node, the exporter looks for a registered symbolic function for
that operator. Symbolic functions are implemented in Python. A symbolic function for
an op named ``foo`` would look something like::

    from typing import List, Union

    import torch


    def foo(
        g,
        input_0: torch._C.Value,
        input_1: torch._C.Value,
    ) -> Union[None, torch._C.Value, List[torch._C.Value]]:
        """
        Adds the ONNX operations representing this PyTorch function by updating the
        graph g with ``g.op()`` calls.

        Args:
            g (Graph): graph to write the ONNX representation into.
            input_0 (Value): value representing the variables which contain
                the first input for this operator.
            input_1 (Value): value representing the variables which contain
                the second input for this operator.

        Returns:
            A Value or List of Values specifying the ONNX nodes that compute something
            equivalent to the original PyTorch operator with the given inputs.

            None if it cannot be converted to ONNX.
        """
        ...

The ``torch._C`` types are Python wrappers around the types defined in C++ in
`ir.h <https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/ir/ir.h>`_.

The process for adding a symbolic function depends on the type of operator.

.. _adding-support-aten:

ATen operators
^^^^^^^^^^^^^^

`ATen <https://pytorch.org/cppdocs/#aten>`_ is PyTorch's built-in tensor library.
If the operator is an ATen operator (shows up in the TorchScript graph with the prefix
``aten::``), make sure it is not supported already.

List of supported operators
~~~~~~~~~~~~~~~~~~~~~~~~~~~

Visit the auto-generated :doc:`list of supported TorchScript operators <../onnx_supported_aten_ops>`
for details on which operators are supported in each ``opset_version``.

Adding support for an aten or quantized operator
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

If the operator is not in the list above:

* Define the symbolic function in ``torch/onnx/symbolic_opset<version>.py``, for example
  `torch/onnx/symbolic_opset9.py <https://github.com/pytorch/pytorch/blob/master/torch/onnx/symbolic_opset9.py>`_.
  Make sure the function has the same name as the ATen function, which may be declared in
  ``torch/_C/_VariableFunctions.pyi`` or ``torch/nn/functional.pyi`` (these files are generated at
  build time, so will not appear in your checkout until you build PyTorch).
* By default, the first arg is the ONNX graph.
  Other arg names must EXACTLY match the names in the ``.pyi`` file,
  because dispatch is done with keyword arguments.
* In the symbolic function, if the operator is in the
  `ONNX standard operator set <https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_,
  we only need to create a node to represent the ONNX operator in the graph.
  If not, we can compose several standard operators that have the
  equivalent semantics to the ATen operator.

Here is an example of handling a missing symbolic function for the ``ELU`` operator.

If we run the following code::

    print(
        torch.jit.trace(
            torch.nn.ELU(), # module
            torch.ones(1)   # example input
        ).graph
    )

We see something like::

  graph(%self : __torch__.torch.nn.modules.activation.___torch_mangle_0.ELU,
        %input : Float(1, strides=[1], requires_grad=0, device=cpu)):
    %4 : float = prim::Constant[value=1.]()
    %5 : int = prim::Constant[value=1]()
    %6 : int = prim::Constant[value=1]()
    %7 : Float(1, strides=[1], requires_grad=0, device=cpu) = aten::elu(%input, %4, %5, %6)
    return (%7)

Since we see ``aten::elu`` in the graph, we know this is an ATen operator.

We check the `ONNX operator list <https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_,
and confirm that ``Elu`` is standardized in ONNX.

We find a signature for ``elu`` in ``torch/nn/functional.pyi``::

    def elu(input: Tensor, alpha: float = ..., inplace: bool = ...) -> Tensor: ...

We add the following lines to ``symbolic_opset9.py``::

    def elu(g, input: torch.Value, alpha: torch.Value, inplace: bool = False):
        return g.op("Elu", input, alpha_f=alpha)

Now PyTorch is able to export models containing the ``aten::elu`` operator!

See the ``torch/onnx/symbolic_opset*.py`` files for more examples.


torch.autograd.Functions
^^^^^^^^^^^^^^^^^^^^^^^^

If the operator is a sub-class of :class:`torch.autograd.Function`, there are three ways
to export it.

Static Symbolic Method
~~~~~~~~~~~~~~~~~~~~~~

You can add a static method named ``symbolic`` to your function class. It should return
ONNX operators that represent the function's behavior in ONNX. For example::

    class MyRelu(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input: torch.Tensor) -> torch.Tensor:
            ctx.save_for_backward(input)
            return input.clamp(min=0)

        @staticmethod
        def symbolic(g: torch.Graph, input: torch.Value) -> torch.Value:
            return g.op("Clip", input, g.op("Constant", value_t=torch.tensor(0, dtype=torch.float)))

.. FIXME(justinchuby): PythonOps are too complicated and the example below
..  uses private methods we do not expose. We are looking to
..  improve the experience. Since SymbolicContext is deprecated, we think
..  defining a symbolic staticmethod is a better way to go for now.

.. PythonOp Symbolic
.. ~~~~~~~~~~~~~~~~~

.. Alternatively, you can register a custom symbolic function.
.. This gives the symbolic function access to more info through the
.. ``torch.onnx.SymbolicContext`` object, which gets passed in as the first
.. argument (before the ``Graph`` object).

.. All autograd ``Function``\ s appear in the TorchScript graph as ``prim::PythonOp`` nodes.
.. In order to differentiate between different ``Function`` subclasses, the
.. symbolic function should use the ``name`` kwarg which gets set to the name of the class.

.. Custom symbolic functions should add type and shape information by calling ``setType(...)``
.. on Value objects before returning them (implemented in C++ by
.. ``torch::jit::Value::setType``). This is not required, but it can help the exporter's
.. shape and type inference for down-stream nodes. For a non-trivial example of ``setType``, see
.. ``test_aten_embedding_2`` in
.. `test_operators.py <https://github.com/pytorch/pytorch/blob/master/test/onnx/test_operators.py>`_.

.. The example below shows how you can access ``requires_grad`` via the ``Node`` object:

..     class MyClip(torch.autograd.Function):
..         @staticmethod
..         def forward(ctx, input, min):
..             ctx.save_for_backward(input)
..             return input.clamp(min=min)

..     class MyRelu(torch.autograd.Function):
..         @staticmethod
..         def forward(ctx, input):
..             ctx.save_for_backward(input)
..             return input.clamp(min=0)

..     def symbolic_python_op(g: "GraphContext", *args, **kwargs):
..         n = ctx.cur_node
..         print("original node: ", n)
..         for i, out in enumerate(n.outputs()):
..             print("original output {}: {}, requires grad: {}".format(i, out, out.requiresGrad()))
..         import torch.onnx.symbolic_helper as sym_helper
..         for i, arg in enumerate(args):
..             requires_grad = arg.requiresGrad() if sym_helper._is_value(arg) else False
..             print("arg {}: {}, requires grad: {}".format(i, arg, requires_grad))

..         name = kwargs["name"]
..         ret = None
..         if name == "MyClip":
..             ret = g.op("Clip", args[0], args[1])
..         elif name == "MyRelu":
..             ret = g.op("Relu", args[0])
..         else:
..             # Logs a warning and returns None
..             return _unimplemented("prim::PythonOp", "unknown node kind: " + name)
..         # Copy type and shape from original node.
..         ret.setType(n.type())
..         return ret

..     from torch.onnx import register_custom_op_symbolic
..     register_custom_op_symbolic("prim::PythonOp", symbolic_python_op, 1)

Inline Autograd Function
~~~~~~~~~~~~~~~~~~~~~~~~
If a subclass of :class:`torch.autograd.Function` provides no static ``symbolic`` method,
and no custom symbolic function is registered for ``prim::PythonOp``,
:func:`torch.onnx.export` tries to inline the graph that corresponds to that :class:`torch.autograd.Function`,
breaking the function down into the individual operators that were used within it.
The export succeeds as long as these individual operators are supported. For example::

    class MyLogExp(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input: torch.Tensor) -> torch.Tensor:
            ctx.save_for_backward(input)
            h = input.exp()
            return h.log().log()

Because no static ``symbolic`` method is present for this ``Function``, it is inlined and exported as follows::

    graph(%input : Float(1, strides=[1], requires_grad=0, device=cpu)):
        %1 : float = onnx::Exp[](%input)
        %2 : float = onnx::Log[](%1)
        %3 : float = onnx::Log[](%2)
        return (%3)

If you need to avoid inlining of :class:`torch.autograd.Function`, you should export models with
``operator_export_type`` set to ``ONNX_FALLTHROUGH`` or ``ONNX_ATEN_FALLBACK``.
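
For example (a minimal sketch; ``model`` and ``args`` are assumed to be defined)::

    torch.onnx.export(
        model,
        args,
        "model.onnx",
        operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH,
    )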

Custom operators
^^^^^^^^^^^^^^^^

If a model uses a custom operator implemented in C++ as described in
`Extending TorchScript with Custom C++ Operators <https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html>`_,
you can export it by following this example::

    from torch.onnx import symbolic_helper


    # Define custom symbolic function
    @symbolic_helper.parse_args("v", "v", "f", "i")
    def symbolic_foo_forward(g, input1, input2, attr1, attr2):
        return g.op("custom_domain::Foo", input1, input2, attr1_f=attr1, attr2_i=attr2)


    # Register custom symbolic function
    torch.onnx.register_custom_op_symbolic("custom_ops::foo_forward", symbolic_foo_forward, 9)


    class FooModel(torch.nn.Module):
        def __init__(self, attr1, attr2):
            super().__init__()
            self.attr1 = attr1
            self.attr2 = attr2

        def forward(self, input1, input2):
            # Calling custom op
            return torch.ops.custom_ops.foo_forward(input1, input2, self.attr1, self.attr2)


    model = FooModel(attr1, attr2)
    torch.onnx.export(
        model,
        (example_input1, example_input2),
        "model.onnx",
        # only needed if you want to specify an opset version > 1.
        custom_opsets={"custom_domain": 2}
    )

In the symbolic function, you can map the custom op to a single standard ONNX op, to a
combination of standard ONNX ops, or, as in the example above, to a custom ONNX operator.

The example exports ``Foo`` as a custom operator in the "custom_domain" opset.
When exporting a custom operator, you can specify the custom domain version using the
``custom_opsets`` dictionary at export time. If not specified, the custom opset version defaults to 1.

The runtime that consumes the model needs to support the custom op. See
`Caffe2 custom ops <https://caffe2.ai/docs/custom-operators.html>`_,
`ONNX Runtime custom ops <https://onnxruntime.ai/docs/reference/operators/add-custom-op.html>`_,
or your runtime of choice's documentation.


Discovering all unconvertible ATen ops at once
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

When export fails due to an unconvertible ATen op, there may in fact be more
than one such op, but the error message mentions only the first. To discover
all of the unconvertible ops in one go, you can::

    # prepare model, args, opset_version
    ...

    torch_script_graph, unconvertible_ops = torch.onnx.utils.unconvertible_ops(
        model, args, opset_version=opset_version
    )

    print(set(unconvertible_ops))

The set is approximate: some ops may be removed during the conversion
process and will not need to be converted, while some other ops may have partial support
that will fail conversion with particular inputs. Still, this should give you a
general idea of which ops are not supported. Please feel free to open GitHub issues
for op support requests.

Frequently Asked Questions
--------------------------
Q: I have exported my LSTM model, but its input size seems to be fixed?

  The tracer records the shapes of the example inputs. If the model should accept
  inputs of dynamic shapes, set ``dynamic_axes`` when calling :func:`torch.onnx.export`.
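
  For example (the input/output names and axis names are illustrative)::

      torch.onnx.export(
          model,
          (dummy_input,),
          "lstm.onnx",
          input_names=["input"],
          output_names=["output"],
          # Mark the batch and sequence dimensions as dynamic.
          dynamic_axes={
              "input": {0: "batch", 1: "sequence"},
              "output": {0: "batch"},
          },
      )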

Q: How do I export models containing loops?

  See `Tracing vs Scripting`_.

Q: How do I export models with primitive type inputs (e.g. int, float)?

  Support for primitive numeric type inputs was added in PyTorch 1.9.
  However, the exporter does not support models with str inputs.

Q: Does ONNX support implicit scalar datatype casting?

  The ONNX standard does not, but the exporter will try to handle that part.
  Scalars are exported as constant tensors, and
  the exporter will figure out the right data type for them. In rare cases when it is unable
  to do so, you will need to specify the datatype manually, e.g. with ``dtype=torch.float32``.
  If you see any errors, please `create a GitHub issue <https://github.com/pytorch/pytorch/issues>`_.

Q: Are lists of Tensors exportable to ONNX?

  Yes, for ``opset_version`` >= 11, since ONNX introduced the Sequence type in opset 11.


Contributing / developing
-------------------------
`Developer docs <https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter>`_.

Functions
---------
.. autofunction:: export
.. autofunction:: export_to_pretty_string
.. autofunction:: register_custom_op_symbolic
.. autofunction:: unregister_custom_op_symbolic
.. autofunction:: select_model_mode_for_export
.. autofunction:: is_in_onnx_export
.. autofunction:: enable_log
.. autofunction:: disable_log

Classes
-------

.. autosummary::
    :toctree: generated
    :nosignatures:
    :template: classtemplate.rst

    JitScalarType