.. _torch.export:

torch.export
=====================

.. warning::
    This feature is a prototype under active development and there WILL BE
    BREAKING CHANGES in the future.


Overview
--------

:func:`torch.export.export` takes an arbitrary Python callable (a
:class:`torch.nn.Module`, a function or a method) and produces a traced graph
representing only the Tensor computation of the function in an Ahead-of-Time
(AOT) fashion, which can subsequently be executed with different inputs or
serialized.

::

    import torch
    from torch.export import export

    class Mod(torch.nn.Module):
        def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            a = torch.sin(x)
            b = torch.cos(y)
            return a + b

    example_args = (torch.randn(10, 10), torch.randn(10, 10))

    exported_program: torch.export.ExportedProgram = export(
        Mod(), args=example_args
    )
    print(exported_program)

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[10, 10]", y: "f32[10, 10]"):
                # code: a = torch.sin(x)
                sin: "f32[10, 10]" = torch.ops.aten.sin.default(x)

                # code: b = torch.cos(y)
                cos: "f32[10, 10]" = torch.ops.aten.cos.default(y)

                # code: return a + b
                add: "f32[10, 10]" = torch.ops.aten.add.Tensor(sin, cos)
                return (add,)

        Graph signature:
            ExportGraphSignature(
                input_specs=[
                    InputSpec(
                        kind=<InputKind.USER_INPUT: 1>,
                        arg=TensorArgument(name='x'),
                        target=None,
                        persistent=None
                    ),
                    InputSpec(
                        kind=<InputKind.USER_INPUT: 1>,
                        arg=TensorArgument(name='y'),
                        target=None,
                        persistent=None
                    )
                ],
                output_specs=[
                    OutputSpec(
                        kind=<OutputKind.USER_OUTPUT: 1>,
                        arg=TensorArgument(name='add'),
                        target=None
                    )
                ]
            )
        Range constraints: {}
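
The overview states that an exported program can be executed on new inputs or
serialized. A minimal sketch of both, with ``Mod`` redefined so the snippet
stands alone; it assumes ``ExportedProgram.module()``, ``torch.export.save`` and
``torch.export.load`` as found in recent releases, and uses an in-memory
``io.BytesIO`` buffer instead of a file:

::

    import io

    import torch
    from torch.export import export

    class Mod(torch.nn.Module):
        def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return torch.sin(x) + torch.cos(y)

    ep = export(Mod(), args=(torch.randn(10, 10), torch.randn(10, 10)))

    # Run the exported program on fresh inputs of the same (static) shape.
    x, y = torch.randn(10, 10), torch.randn(10, 10)
    out = ep.module()(x, y)
    matches_eager = torch.allclose(out, torch.sin(x) + torch.cos(y))

    # Serialize into an in-memory buffer, then load it back.
    buffer = io.BytesIO()
    torch.export.save(ep, buffer)
    buffer.seek(0)
    loaded = torch.export.load(buffer)
    matches_loaded = torch.allclose(loaded.module()(x, y), out)
    print(matches_eager, matches_loaded)

A file path can be passed to ``save`` and ``load`` in place of the buffer.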

``torch.export`` produces a clean intermediate representation (IR) with the
following invariants. More specifications about the IR can be found
:ref:`here <export.ir_spec>`.

* **Soundness**: It is guaranteed to be a sound representation of the original
  program, and maintains the same calling conventions of the original program.

* **Normalized**: There are no Python semantics within the graph. Submodules
  from the original program are inlined to form one fully flattened
  computational graph.

* **Graph properties**: The graph is purely functional, meaning it does not
  contain operations with side effects such as mutations or aliasing. It does
  not mutate any intermediate values, parameters, or buffers.

* **Metadata**: The graph contains metadata captured during tracing, such as a
  stack trace from the user's code.
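
These invariants can be checked directly on the captured graph. A minimal
sketch, assuming the same ``Mod`` as in the overview, that walks
``ep.graph`` (a :class:`torch.fx.Graph`), collects the operator targets of the
``call_function`` nodes, and prints the shape and stack-trace metadata recorded
on each node:

::

    import torch
    from torch.export import export

    class Mod(torch.nn.Module):
        def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            a = torch.sin(x)
            b = torch.cos(y)
            return a + b

    ep = export(Mod(), args=(torch.randn(10, 10), torch.randn(10, 10)))

    op_names = []
    for node in ep.graph.nodes:
        if node.op == "call_function":
            # Each target is an ATen OpOverload, e.g. aten.sin.default.
            op_names.append(str(node.target))
            # "val" holds the traced (fake) tensor; a stack trace is
            # typically recorded as well, pointing back at the user code.
            print(node.name, tuple(node.meta["val"].shape),
                  node.meta.get("stack_trace") is not None)

    print(op_names)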

Under the hood, ``torch.export`` leverages the following technologies:

* **TorchDynamo (torch._dynamo)** is an internal API that uses a CPython feature
  called the Frame Evaluation API to safely trace PyTorch graphs. This
  provides a massively improved graph capturing experience, with far fewer
  rewrites needed in order to fully trace the PyTorch code.

* **AOT Autograd** provides a functionalized PyTorch graph and ensures the graph
  is decomposed/lowered to the ATen operator set.

* **Torch FX (torch.fx)** is the underlying representation of the graph,
  allowing flexible Python-based transformations.


Existing frameworks
^^^^^^^^^^^^^^^^^^^

:func:`torch.compile` also utilizes the same PT2 stack as ``torch.export``, but
is slightly different:

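* **JIT vs. AOT**: :func:`torch.compile` is a JIT compiler and is not intended
  to be used to produce compiled artifacts for deployment, whereas
  ``torch.export`` captures the program ahead of time (AOT) into an artifact
  that can be serialized and deployed separately.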

* **Partial vs. Full Graph Capture**: When :func:`torch.compile` runs into an
  untraceable part of a model, it will "graph break" and fall back to running
  the program in the eager Python runtime. In comparison, ``torch.export`` aims
  to get a full graph representation of a PyTorch model, so it will error out
  when something untraceable is reached. Since ``torch.export`` produces a full
  graph disjoint from any Python features or runtime, this graph can then be
  saved, loaded, and run in different environments and languages.

* **Usability tradeoff**: Since :func:`torch.compile` is able to fall back to the
  Python runtime whenever it reaches something untraceable, it is a lot more
  flexible. ``torch.export`` will instead require users to provide more
  information or rewrite their code to make it traceable.
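
The partial vs. full graph capture difference shows up with a function whose
control flow depends on tensor *values*. A minimal sketch, assuming a toy
function ``f``; ``backend="eager"`` limits :func:`torch.compile` to graph
capture only, and the error from ``torch.export`` is caught generically because
its exact type may vary between releases:

::

    import torch
    from torch.export import export

    def f(x: torch.Tensor) -> torch.Tensor:
        # Data-dependent branch: depends on tensor values, not shapes.
        if x.sum() > 0:
            return x.sin()
        return x.cos()

    class M(torch.nn.Module):
        def forward(self, x):
            return f(x)

    x = torch.randn(4)

    # torch.compile graph-breaks at the branch and falls back to Python.
    compiled = torch.compile(f, backend="eager")
    compile_ok = torch.allclose(compiled(x), f(x))

    # torch.export needs one full graph, so it reports an error instead.
    try:
        export(M(), (x,))
        export_failed = False
    except Exception as e:
        export_failed = True
        print(type(e).__name__)

    print(compile_ok, export_failed)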

Compared to :func:`torch.fx.symbolic_trace`, ``torch.export`` traces using
TorchDynamo which operates at the Python bytecode level, giving it the ability
to trace arbitrary Python constructs not limited by what Python operator
overloading supports. Additionally, ``torch.export`` keeps fine-grained track of
tensor metadata, so that conditionals on things like tensor shapes do not
fail tracing. In general, ``torch.export`` is expected to work on more user
programs, and produce lower-level graphs (at the ``torch.ops.aten`` operator
level). Note that users can still use :func:`torch.fx.symbolic_trace` as a
preprocessing step before ``torch.export``.
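
To illustrate the shape-conditional point, here is a small sketch with a
hypothetical ``ShapeBranch`` module that branches on ``x.shape[0]``.
:func:`torch.fx.symbolic_trace` is expected to raise a ``TraceError`` because
the condition is a symbolic ``Proxy``, while ``torch.export`` sees the concrete
example shape and specializes on the branch that is taken:

::

    import torch
    import torch.fx
    from torch.export import export

    class ShapeBranch(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # A conditional on a tensor *shape*, not on tensor values.
            if x.shape[0] > 2:
                return x + 1
            return x - 1

    # symbolic_trace cannot evaluate a Proxy as a Python bool.
    try:
        torch.fx.symbolic_trace(ShapeBranch())
        fx_failed = False
    except torch.fx.proxy.TraceError:
        fx_failed = True

    # torch.export knows the example shape (4, 3) and takes the first branch.
    x = torch.randn(4, 3)
    ep = export(ShapeBranch(), (x,))
    export_ok = torch.allclose(ep.module()(x), x + 1)
    print(fx_failed, export_ok)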

Compared to :func:`torch.jit.script`, ``torch.export`` does not capture Python
control flow or data structures, but it supports more Python language features
than TorchScript (as it is easier to have comprehensive coverage over Python
bytecodes). The resulting graphs are simpler and only have straight-line control
flow (except for explicit control flow operators).

Compared to :func:`torch.jit.trace`, ``torch.export`` is sound: it is able to
trace code that performs integer computation on sizes and records all of the
side-conditions necessary to show that a particular trace is valid for other
inputs.
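
A sketch of the soundness point, again using a hypothetical ``ShapeBranch``
module: :func:`torch.jit.trace` bakes in the branch taken for the example input
and silently applies it to inputs of any size, whereas the module returned by
``ExportedProgram.module()`` checks its inputs against the recorded shape and
raises an error on a mismatch (the exact exception type may differ between
releases):

::

    import warnings

    import torch
    from torch.export import export

    class ShapeBranch(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            if x.shape[0] > 2:
                return x + 1
            return x - 1

    example = torch.randn(4, 3)
    small = torch.randn(1, 3)  # eager mode would take the ``x - 1`` branch

    # torch.jit.trace records only the branch taken for ``example``.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")  # silence tracer warnings for the demo
        traced = torch.jit.trace(ShapeBranch(), example)
    jit_silently_wrong = torch.allclose(traced(small), small + 1)

    # torch.export specializes to shape (4, 3) and checks it at call time.
    ep = export(ShapeBranch(), (example,))
    try:
        ep.module()(small)
        export_rejected = False
    except Exception as e:
        export_rejected = True
        print(type(e).__name__, e)

    print(jit_silently_wrong, export_rejected)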


Exporting a PyTorch Model
-------------------------

An Example
^^^^^^^^^^

The main entrypoint is through :func:`torch.export.export`, which takes a
callable (:class:`torch.nn.Module`, function, or method) and sample inputs, and
captures the computation graph into an :class:`torch.export.ExportedProgram`. An
example:

::

    import torch
    from torch.export import export

    # Simple module for demonstration
    class M(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(
                in_channels=3, out_channels=16, kernel_size=3, padding=1
            )
            self.relu = torch.nn.ReLU()
            self.maxpool = torch.nn.MaxPool2d(kernel_size=3)

        def forward(self, x: torch.Tensor, *, constant=None) -> torch.Tensor:
            a = self.conv(x)
            a.add_(constant)
            return self.maxpool(self.relu(a))

    example_args = (torch.randn(1, 3, 256, 256),)
    example_kwargs = {"constant": torch.ones(1, 16, 256, 256)}

    exported_program: torch.export.ExportedProgram = export(
        M(), args=example_args, kwargs=example_kwargs
    )
    print(exported_program)

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, p_conv_weight: "f32[16, 3, 3, 3]", p_conv_bias: "f32[16]", x: "f32[1, 3, 256, 256]", constant: "f32[1, 16, 256, 256]"):
                # code: a = self.conv(x)
                conv2d: "f32[1, 16, 256, 256]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias, [1, 1], [1, 1])

                # code: a.add_(constant)
                add_: "f32[1, 16, 256, 256]" = torch.ops.aten.add_.Tensor(conv2d, constant)

                # code: return self.maxpool(self.relu(a))
                relu: "f32[1, 16, 256, 256]" = torch.ops.aten.relu.default(add_)
                max_pool2d: "f32[1, 16, 85, 85]" = torch.ops.aten.max_pool2d.default(relu, [3, 3], [3, 3])
                return (max_pool2d,)

    Graph signature:
        ExportGraphSignature(
            input_specs=[
                InputSpec(
                    kind=<InputKind.PARAMETER: 2>,
                    arg=TensorArgument(name='p_conv_weight'),
                    target='conv.weight',
                    persistent=None
                ),
                InputSpec(
                    kind=<InputKind.PARAMETER: 2>,
                    arg=TensorArgument(name='p_conv_bias'),
                    target='conv.bias',
                    persistent=None
                ),
                InputSpec(
                    kind=<InputKind.USER_INPUT: 1>,
                    arg=TensorArgument(name='x'),
                    target=None,
                    persistent=None
                ),
                InputSpec(
                    kind=<InputKind.USER_INPUT: 1>,
                    arg=TensorArgument(name='constant'),
                    target=None,
                    persistent=None
                )
            ],
            output_specs=[
                OutputSpec(
                    kind=<OutputKind.USER_OUTPUT: 1>,
                    arg=TensorArgument(name='max_pool2d'),
                    target=None
                )
            ]
        )
    Range constraints: {}

Inspecting the ``ExportedProgram``, we can note the following:

* The :class:`torch.fx.Graph` contains the computation graph of the original
  program, along with records of the original code for easy debugging.

* The graph contains only ``torch.ops.aten`` operators found `here <https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml>`__
  and custom operators, and is fully functional, without any inplace operators
  such as ``torch.add_``.

* The parameters (weight and bias to conv) are lifted as inputs to the graph,
  resulting in no ``get_attr`` nodes in the graph, which previously existed in
  the result of :func:`torch.fx.symbolic_trace`.

* The :class:`torch.export.ExportGraphSignature` models the input and output
  signature, along with specifying which inputs are parameters.

* The resulting shape and dtype of the tensor produced by each node in the graph
  are recorded. For example, the ``conv2d`` node will result in a tensor of dtype
  ``torch.float32`` and shape (1, 16, 256, 256).
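
The lifted parameters and per-node metadata described above can also be read
programmatically. A minimal sketch with a smaller, hypothetical conv module (so
it runs quickly), using the ``input_specs`` of
:class:`torch.export.ExportGraphSignature` and the ``"val"`` entry stored in
each node's ``meta`` dictionary:

::

    import torch
    from torch.export import export
    from torch.export.graph_signature import InputKind

    class SmallConv(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.conv(x).relu()

    ep = export(SmallConv(), (torch.randn(1, 3, 8, 8),))

    # Map each lifted placeholder name to the module attribute it stands for.
    lifted = {
        spec.arg.name: spec.target
        for spec in ep.graph_signature.input_specs
        if spec.kind == InputKind.PARAMETER
    }
    print(lifted)

    # Parameters arrive as inputs, so no get_attr nodes should remain.
    has_get_attr = any(node.op == "get_attr" for node in ep.graph.nodes)

    # Every computed node records the dtype and shape of its result.
    for node in ep.graph.nodes:
        if node.op == "call_function":
            val = node.meta["val"]
            print(node.name, val.dtype, tuple(val.shape))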


.. _Non-Strict Export:

Non-Strict Export
^^^^^^^^^^^^^^^^^

In PyTorch 2.3, we introduced a new mode of tracing called **non-strict mode**.
It's still going through hardening, so if you run into any issues, please file
them to GitHub with the "oncall: export" tag.

In *non-strict mode*, we trace through the program using the Python interpreter.
Your code will execute exactly as it would in eager mode; the only difference is
that all Tensor objects will be replaced by ProxyTensors, which will record all
their operations into a graph.

In *strict* mode, which is currently the default, we first trace through the
program using TorchDynamo, a bytecode analysis engine. TorchDynamo does not
actually execute your Python code. Instead, it symbolically analyzes it and
builds a graph based on the results. This analysis allows ``torch.export`` to
provide stronger guarantees about safety, but not all Python code is supported.

An example of a case where one might want to use non-strict mode is if you run
into an unsupported TorchDynamo feature that might not be easily solved, and you
know the Python code is not needed for the computation. For example:

::

    import torch
    from torch.export import export

    class ContextManager():
        def __init__(self):
            self.count = 0
        def __enter__(self):
            self.count += 1
        def __exit__(self, exc_type, exc_value, traceback):
            self.count -= 1

    class M(torch.nn.Module):
        def forward(self, x):
            with ContextManager():
                return x.sin() + x.cos()

    export(M(), (torch.ones(3, 3),), strict=False)  # Non-strict traces successfully
    export(M(), (torch.ones(3, 3),))  # Strict mode fails with torch._dynamo.exc.Unsupported: ContextManager

In this example, the first call using non-strict mode (through the
``strict=False`` flag) traces successfully, whereas the second call using strict
mode (the default) results in a failure because TorchDynamo is unable to support
the context manager. One option is to rewrite the code (see :ref:`Limitations of torch.export <Limitations of
torch.export>`), but seeing as the context manager does not affect the tensor
computations in the model, we can go with the non-strict mode's result.
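
Because non-strict mode really runs the Python code, plain Python side effects
happen at export time but are not recorded in the graph. A minimal sketch,
assuming a hypothetical module-level ``call_log`` list:

::

    import torch
    from torch.export import export

    # Hypothetical Python-side log; it is not a tensor, so it never enters the graph.
    call_log = []

    class M(torch.nn.Module):
        def forward(self, x):
            call_log.append("forward")  # plain Python side effect
            return x.sin() + x.cos()

    ep = export(M(), (torch.ones(3, 3),), strict=False)
    calls_during_export = len(call_log)  # forward ran while tracing

    # Running the exported program replays only the recorded tensor ops,
    # so the log does not grow.
    x = torch.randn(3, 3)
    out = ep.module()(x)
    print(calls_during_export, len(call_log))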


.. _Training Export:

Export for Training and Inference
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In PyTorch 2.5, we introduced a new API called :func:`export_for_training`.
It's still going through hardening, so if you run into any issues, please file
them to GitHub with the "oncall: export" tag.

In this API, we produce the most generic IR that contains all ATen operators
(including both functional and non-functional) which can be used to train in
eager PyTorch Autograd. This API is intended for eager training use cases such as PT2 Quantization
and will soon be the default IR of ``torch.export.export``. To read further about
the motivation behind this change, please refer to
https://dev-discuss.pytorch.org/t/why-pytorch-does-not-need-a-new-standardized-operator-set/2206

When this API is combined with :func:`run_decompositions()`, you should be able to get inference IR with
any desired decomposition behavior.

To show some examples:

::

    import torch

    class ConvBatchnorm(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(1, 3, 1, 1)
            self.bn = torch.nn.BatchNorm2d(3)

        def forward(self, x):
            x = self.conv(x)
            x = self.bn(x)
            return (x,)

    mod = ConvBatchnorm()
    inp = torch.randn(1, 1, 3, 3)

    ep_for_training = torch.export.export_for_training(mod, (inp,))
    print(ep_for_training)

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
                conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias)
                add_: "i64[]" = torch.ops.aten.add_.Tensor(b_bn_num_batches_tracked, 1)
                batch_norm: "f32[1, 3, 3, 3]" = torch.ops.aten.batch_norm.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05, True)
                return (batch_norm,)

    Graph signature:
        ExportGraphSignature(
            input_specs=[
                InputSpec(
                    kind=<InputKind.PARAMETER: 2>,
                    arg=TensorArgument(name='p_conv_weight'),
                    target='conv.weight',
                    persistent=None
                ),
                InputSpec(
                    kind=<InputKind.PARAMETER: 2>,
                    arg=TensorArgument(name='p_conv_bias'),
                    target='conv.bias',
                    persistent=None
                ),
                InputSpec(
                    kind=<InputKind.PARAMETER: 2>,
                    arg=TensorArgument(name='p_bn_weight'),
                    target='bn.weight',
                    persistent=None
                ),
                InputSpec(
                    kind=<InputKind.PARAMETER: 2>,
                    arg=TensorArgument(name='p_bn_bias'),
                    target='bn.bias',
                    persistent=None
                ),
                InputSpec(
                    kind=<InputKind.BUFFER: 3>,
                    arg=TensorArgument(name='b_bn_running_mean'),
                    target='bn.running_mean',
                    persistent=True
                ),
                InputSpec(
                    kind=<InputKind.BUFFER: 3>,
                    arg=TensorArgument(name='b_bn_running_var'),
                    target='bn.running_var',
                    persistent=True
                ),
                InputSpec(
                    kind=<InputKind.BUFFER: 3>,
                    arg=TensorArgument(name='b_bn_num_batches_tracked'),
                    target='bn.num_batches_tracked',
                    persistent=True
                ),
                InputSpec(
                    kind=<InputKind.USER_INPUT: 1>,
                    arg=TensorArgument(name='x'),
                    target=None,
                    persistent=None
                )
            ],
            output_specs=[
                OutputSpec(
                    kind=<OutputKind.USER_OUTPUT: 1>,
                    arg=TensorArgument(name='batch_norm'),
                    target=None
                )
            ]
        )
    Range constraints: {}


From the above output, you can see that :func:`export_for_training` produces essentially the same ExportedProgram
as :func:`export` except for the operators in the graph: ``batch_norm`` is captured in its most general
form. This op is non-functional and will be lowered to different ops when running inference.
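
One way to see the difference is to look for operators whose schema marks an
input as mutated. A minimal sketch, reusing ``ConvBatchnorm``; the helper
``mutable_ops`` is hypothetical and relies on the private ``torch._ops.OpOverload``
class and its ``_schema`` attribute, which may change between releases:

::

    import torch
    from torch.export.graph_signature import OutputKind

    class ConvBatchnorm(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(1, 3, 1, 1)
            self.bn = torch.nn.BatchNorm2d(3)

        def forward(self, x):
            return (self.bn(self.conv(x)),)

    def mutable_ops(ep):
        """Names of graph ops whose schema says they mutate an argument."""
        return [
            str(node.target)
            for node in ep.graph.nodes
            if node.op == "call_function"
            and isinstance(node.target, torch._ops.OpOverload)
            and node.target._schema.is_mutable
        ]

    inp = torch.randn(1, 1, 3, 3)
    ep_for_training = torch.export.export_for_training(ConvBatchnorm(), (inp,))
    ep_for_inference = ep_for_training.run_decompositions()

    print(mutable_ops(ep_for_training))   # e.g. the in-place buffer update
    print(mutable_ops(ep_for_inference))  # expected to be empty

    # In the functional IR, buffer updates become extra graph outputs.
    mutation_kinds = [
        spec.kind
        for spec in ep_for_inference.graph_signature.output_specs
        if spec.kind == OutputKind.BUFFER_MUTATION
    ]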

You can also go from this IR to an inference IR via :func:`run_decompositions` with arbitrary customizations.

::

    # Lower to core aten inference IR, but keep conv2d
    decomp_table = torch.export.default_decompositions()
    del decomp_table[torch.ops.aten.conv2d.default]
    ep_for_inference = ep_for_training.run_decompositions(decomp_table)

    print(ep_for_inference)

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
                conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias)
                add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1)
                _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05)
                getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
                getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
                getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]
                return (getitem_3, getitem_4, add, getitem)

    Graph signature:
        ExportGraphSignature(
            input_specs=[
                InputSpec(
                    kind=<InputKind.PARAMETER: 2>,
                    arg=TensorArgument(name='p_conv_weight'),
                    target='conv.weight',
                    persistent=None
                ),
                InputSpec(
                    kind=<InputKind.PARAMETER: 2>,
                    arg=TensorArgument(name='p_conv_bias'),
                    target='conv.bias',
                    persistent=None
                ),
                InputSpec(
                    kind=<InputKind.PARAMETER: 2>,
                    arg=TensorArgument(name='p_bn_weight'),
                    target='bn.weight',
                    persistent=None
                ),
                InputSpec(
                    kind=<InputKind.PARAMETER: 2>,
                    arg=TensorArgument(name='p_bn_bias'),
                    target='bn.bias',
                    persistent=None
                ),
                InputSpec(
                    kind=<InputKind.BUFFER: 3>,
                    arg=TensorArgument(name='b_bn_running_mean'),
                    target='bn.running_mean',
                    persistent=True
                ),
                InputSpec(
                    kind=<InputKind.BUFFER: 3>,
                    arg=TensorArgument(name='b_bn_running_var'),
                    target='bn.running_var',
                    persistent=True
                ),
                InputSpec(
                    kind=<InputKind.BUFFER: 3>,
                    arg=TensorArgument(name='b_bn_num_batches_tracked'),
                    target='bn.num_batches_tracked',
                    persistent=True
                ),
                InputSpec(
                    kind=<InputKind.USER_INPUT: 1>,
                    arg=TensorArgument(name='x'),
                    target=None,
                    persistent=None
                )
            ],
            output_specs=[
                OutputSpec(
                    kind=<OutputKind.BUFFER_MUTATION: 3>,
                    arg=TensorArgument(name='getitem_3'),
                    target='bn.running_mean'
                ),
                OutputSpec(
                    kind=<OutputKind.BUFFER_MUTATION: 3>,
                    arg=TensorArgument(name='getitem_4'),
                    target='bn.running_var'
                ),
                OutputSpec(
                    kind=<OutputKind.BUFFER_MUTATION: 3>,
                    arg=TensorArgument(name='add'),
                    target='bn.num_batches_tracked'
                ),
                OutputSpec(
                    kind=<OutputKind.USER_OUTPUT: 1>,
                    arg=TensorArgument(name='getitem'),
                    target=None
                )
            ]
        )
    Range constraints: {}

Here you can see that the ``conv2d`` op was kept in the IR while the rest were decomposed. The IR is now a functional IR
containing core ATen operators except for ``conv2d``.
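
To confirm which operators survive a customized ``run_decompositions`` call, one
can collect the graph's call targets. A sketch, reusing ``ConvBatchnorm``
(redefined so the snippet stands alone):

::

    import torch

    class ConvBatchnorm(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(1, 3, 1, 1)
            self.bn = torch.nn.BatchNorm2d(3)

        def forward(self, x):
            return (self.bn(self.conv(x)),)

    ep_for_training = torch.export.export_for_training(
        ConvBatchnorm(), (torch.randn(1, 1, 3, 3),)
    )

    # Default table minus conv2d, as in the example above.
    decomp_table = torch.export.default_decompositions()
    del decomp_table[torch.ops.aten.conv2d.default]
    ep_for_inference = ep_for_training.run_decompositions(decomp_table)

    targets = {
        node.target
        for node in ep_for_inference.graph.nodes
        if node.op == "call_function"
    }
    kept_conv2d = torch.ops.aten.conv2d.default in targets
    has_batch_norm = torch.ops.aten.batch_norm.default in targets
    print(kept_conv2d, has_batch_norm)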

You can customize even further by directly registering your own decomposition behavior for an operator.

::

    # Lower to core aten inference IR, but customize conv2d
    decomp_table = torch.export.default_decompositions()

    # Rewrite conv2d as aten.convolution and scale the result by 2
    def my_awesome_custom_conv2d_function(x, weight, bias, stride=[1, 1], padding=[0, 0], dilation=[1, 1], groups=1):
        return 2 * torch.ops.aten.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0], groups)

    decomp_table[torch.ops.aten.conv2d.default] = my_awesome_custom_conv2d_function
    ep_for_inference = ep_for_training.run_decompositions(decomp_table)

    print(ep_for_inference)

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
                convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(x, p_conv_weight, p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1)
                mul: "f32[1, 3, 3, 3]" = torch.ops.aten.mul.Tensor(convolution, 2)
                add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1)
                _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(mul, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05)
                getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
                getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
                getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]
                return (getitem_3, getitem_4, add, getitem)

    Graph signature:
        ExportGraphSignature(
            input_specs=[
                InputSpec(
                    kind=<InputKind.PARAMETER: 2>,
                    arg=TensorArgument(name='p_conv_weight'),
                    target='conv.weight',
                    persistent=None
                ),
                InputSpec(
                    kind=<InputKind.PARAMETER: 2>,
                    arg=TensorArgument(name='p_conv_bias'),
                    target='conv.bias',
                    persistent=None
                ),
                InputSpec(
                    kind=<InputKind.PARAMETER: 2>,
                    arg=TensorArgument(name='p_bn_weight'),
                    target='bn.weight',
                    persistent=None
                ),
                InputSpec(
                    kind=<InputKind.PARAMETER: 2>,
                    arg=TensorArgument(name='p_bn_bias'),
                    target='bn.bias',
                    persistent=None
                ),
                InputSpec(
                    kind=<InputKind.BUFFER: 3>,
                    arg=TensorArgument(name='b_bn_running_mean'),
                    target='bn.running_mean',
                    persistent=True
                ),
                InputSpec(
                    kind=<InputKind.BUFFER: 3>,
                    arg=TensorArgument(name='b_bn_running_var'),
                    target='bn.running_var',
                    persistent=True
                ),
                InputSpec(
                    kind=<InputKind.BUFFER: 3>,
                    arg=TensorArgument(name='b_bn_num_batches_tracked'),
                    target='bn.num_batches_tracked',
                    persistent=True
                ),
                InputSpec(
                    kind=<InputKind.USER_INPUT: 1>,
                    arg=TensorArgument(name='x'),
                    target=None,
                    persistent=None
                )
            ],
            output_specs=[
                OutputSpec(
                    kind=<OutputKind.BUFFER_MUTATION: 3>,
                    arg=TensorArgument(name='getitem_3'),
                    target='bn.running_mean'
                ),
                OutputSpec(
                    kind=<OutputKind.BUFFER_MUTATION: 3>,
                    arg=TensorArgument(name='getitem_4'),
                    target='bn.running_var'
                ),
                OutputSpec(
                    kind=<OutputKind.BUFFER_MUTATION: 3>,
                    arg=TensorArgument(name='add'),
                    target='bn.num_batches_tracked'
                ),
                OutputSpec(
                    kind=<OutputKind.USER_OUTPUT: 1>,
                    arg=TensorArgument(name='getitem'),
                    target=None
                )
            ]
        )
    Range constraints: {}
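
Because batch norm renormalizes its input, the effect of the custom rule above
is hard to see numerically in that model. A sketch on a hypothetical conv-only
module checks that the registered rule is baked into the graph;
``doubled_conv2d`` mirrors the custom function above but uses tuple defaults:

::

    import torch

    class ConvOnly(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(1, 3, 1, 1)

        def forward(self, x):
            return self.conv(x)

    def doubled_conv2d(x, weight, bias, stride=(1, 1), padding=(0, 0),
                       dilation=(1, 1), groups=1):
        # Rewrite conv2d as aten.convolution and scale the result by 2.
        return 2 * torch.ops.aten.convolution(
            x, weight, bias, list(stride), list(padding), list(dilation),
            False, [0, 0], groups,
        )

    mod = ConvOnly()
    x = torch.randn(1, 1, 3, 3)
    ep = torch.export.export_for_training(mod, (x,))

    decomp_table = torch.export.default_decompositions()
    decomp_table[torch.ops.aten.conv2d.default] = doubled_conv2d
    ep_custom = ep.run_decompositions(decomp_table)

    # The exported graph should now compute exactly twice the eager output.
    matches = torch.allclose(ep_custom.module()(x), 2 * mod(x))
    print(matches)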


Expressing Dynamism
^^^^^^^^^^^^^^^^^^^

By default ``torch.export`` will trace the program assuming all input shapes are
**static**, specializing the exported program to those dimensions. However,
some dimensions, such as a batch dimension, can be dynamic and vary from run to
run. Such dimensions must be created with the :func:`torch.export.Dim` API and
passed to :func:`torch.export.export` through the ``dynamic_shapes`` argument.
An example:

::

    import torch
    from torch.export import Dim, export

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()

            self.branch1 = torch.nn.Sequential(
                torch.nn.Linear(64, 32), torch.nn.ReLU()
            )
            self.branch2 = torch.nn.Sequential(
                torch.nn.Linear(128, 64), torch.nn.ReLU()
            )
            self.buffer = torch.ones(32)

        def forward(self, x1, x2):
            out1 = self.branch1(x1)
            out2 = self.branch2(x2)
            return (out1 + self.buffer, out2)

    example_args = (torch.randn(32, 64), torch.randn(32, 128))

    # Create a dynamic batch size
    batch = Dim("batch")
    # Specify that the first dimension of each input is that batch size
    dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}

    exported_program: torch.export.ExportedProgram = export(
        M(), args=example_args, dynamic_shapes=dynamic_shapes
    )
    print(exported_program)

.. code-block::

    ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_branch1_0_weight: "f32[32, 64]", p_branch1_0_bias: "f32[32]", p_branch2_0_weight: "f32[64, 128]", p_branch2_0_bias: "f32[64]", c_buffer: "f32[32]", x1: "f32[s0, 64]", x2: "f32[s0, 128]"):

            # code: out1 = self.branch1(x1)
            linear: "f32[s0, 32]" = torch.ops.aten.linear.default(x1, p_branch1_0_weight, p_branch1_0_bias)
            relu: "f32[s0, 32]" = torch.ops.aten.relu.default(linear)

            # code: out2 = self.branch2(x2)
            linear_1: "f32[s0, 64]" = torch.ops.aten.linear.default(x2, p_branch2_0_weight, p_branch2_0_bias)
            relu_1: "f32[s0, 64]" = torch.ops.aten.relu.default(linear_1)

            # code: return (out1 + self.buffer, out2)
            add: "f32[s0, 32]" = torch.ops.aten.add.Tensor(relu, c_buffer)
            return (add, relu_1)

    Graph signature:
        ExportGraphSignature(
            input_specs=[
                InputSpec(
                    kind=<InputKind.PARAMETER: 2>,
                    arg=TensorArgument(name='p_branch1_0_weight'),
                    target='branch1.0.weight',
                    persistent=None
                ),
                InputSpec(
                    kind=<InputKind.PARAMETER: 2>,
                    arg=TensorArgument(name='p_branch1_0_bias'),
                    target='branch1.0.bias',
                    persistent=None
                ),
                InputSpec(
                    kind=<InputKind.PARAMETER: 2>,
                    arg=TensorArgument(name='p_branch2_0_weight'),
                    target='branch2.0.weight',
                    persistent=None
                ),
                InputSpec(
                    kind=<InputKind.PARAMETER: 2>,
                    arg=TensorArgument(name='p_branch2_0_bias'),
                    target='branch2.0.bias',
                    persistent=None
                ),
                InputSpec(
                    kind=<InputKind.CONSTANT_TENSOR: 4>,
                    arg=TensorArgument(name='c_buffer'),
                    target='buffer',
                    persistent=True
                ),
                InputSpec(
                    kind=<InputKind.USER_INPUT: 1>,
                    arg=TensorArgument(name='x1'),
                    target=None,
                    persistent=None
                ),
                InputSpec(
                    kind=<InputKind.USER_INPUT: 1>,
                    arg=TensorArgument(name='x2'),
                    target=None,
                    persistent=None
                )
            ],
            output_specs=[
                OutputSpec(
                    kind=<OutputKind.USER_OUTPUT: 1>,
                    arg=TensorArgument(name='add'),
                    target=None
                ),
                OutputSpec(
                    kind=<OutputKind.USER_OUTPUT: 1>,
                    arg=TensorArgument(name='relu_1'),
                    target=None
                )
            ]
        )
    Range constraints: {s0: VR[0, int_oo]}

Some additional things to note:

* Through the :func:`torch.export.Dim` API and the ``dynamic_shapes`` argument, we
  specified the first dimension of each input to be dynamic. In the graph, the
  inputs ``x1`` and ``x2`` have symbolic shapes of (s0, 64) and (s0, 128), instead
  of the (32, 64) and (32, 128) shaped tensors that we passed in as example inputs.
  ``s0`` is a symbol indicating that this dimension can take a range of values.
  Because both inputs were given the same ``batch`` object, they share the single
  symbol ``s0``.

* ``exported_program.range_constraints`` describes the range of each symbol
  appearing in the graph. In this case, ``s0`` has the range [0, int_oo]. For
  technical reasons that are difficult to explain here, dynamic dimensions are
  assumed not to be 0 or 1. This is not a bug, and does not necessarily mean that
  the exported program will not work for dimensions 0 or 1. See
  `The 0/1 Specialization Problem <https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ#heading=h.ez923tomjvyk>`_
  for an in-depth discussion of this topic.
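
The sketch below, a minimal illustration rather than part of the original
example, re-exports the module above and runs the result with a batch size other
than the example's 32. It assumes that the module returned by
``exported_program.module()`` checks each call's inputs against the exported
shapes, as it does in recent PyTorch releases; the exact exception type and
message may vary between versions, so the sketch catches ``Exception``:

::

    import torch
    from torch.export import Dim, export

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.branch1 = torch.nn.Sequential(torch.nn.Linear(64, 32), torch.nn.ReLU())
            self.branch2 = torch.nn.Sequential(torch.nn.Linear(128, 64), torch.nn.ReLU())
            self.buffer = torch.ones(32)

        def forward(self, x1, x2):
            return (self.branch1(x1) + self.buffer, self.branch2(x2))

    batch = Dim("batch")
    exported_program = export(
        M(),
        (torch.randn(32, 64), torch.randn(32, 128)),
        dynamic_shapes={"x1": {0: batch}, "x2": {0: batch}},
    )

    # A batch size of 8 is accepted even though the example inputs used 32.
    out1, out2 = exported_program.module()(torch.randn(8, 64), torch.randn(8, 128))
    print(out1.shape, out2.shape)  # torch.Size([8, 32]) torch.Size([8, 64])

    # x1 and x2 share the symbol s0, so mismatched batch sizes are rejected.
    try:
        exported_program.module()(torch.randn(8, 64), torch.randn(4, 128))
        rejected = False
    except Exception:  # exact exception type may vary between versions
        rejected = True
    print(rejected)  # True

Because the two inputs are tied to one symbol, inconsistent batch sizes are
caught when the program is called rather than producing mismatched outputs.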


We can also specify more expressive relationships between input shapes, such as
where a pair of shapes might differ by one, one shape might be double another,
or a shape might be even. An example:

::

    import torch

    class M(torch.nn.Module):
        def forward(self, x, y):
            return x + y[1:]

    x, y = torch.randn(5), torch.randn(6)
    dimx = torch.export.Dim("dimx", min=3, max=6)
    dimy = dimx + 1

    exported_program = torch.export.export(
        M(), (x, y), dynamic_shapes=({0: dimx}, {0: dimy}),
    )
    print(exported_program)

.. code-block::

    ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[s0]", y: "f32[s0 + 1]"):
            # code: return x + y[1:]
            slice_1: "f32[s0]" = torch.ops.aten.slice.Tensor(y, 0, 1, 9223372036854775807)
            add: "f32[s0]" = torch.ops.aten.add.Tensor(x, slice_1)
            return (add,)

    Graph signature:
        ExportGraphSignature(
            input_specs=[
                InputSpec(
                    kind=<InputKind.USER_INPUT: 1>,
                    arg=TensorArgument(name='x'),
                    target=None,
                    persistent=None
                ),
                InputSpec(
                    kind=<InputKind.USER_INPUT: 1>,
                    arg=TensorArgument(name='y'),
                    target=None,
                    persistent=None
                )
            ],
            output_specs=[
                OutputSpec(
                    kind=<OutputKind.USER_OUTPUT: 1>,
                    arg=TensorArgument(name='add'),
                    target=None
                )
            ]
        )
    Range constraints: {s0: VR[3, 6], s0 + 1: VR[4, 7]}

Some things to note:

* By specifying ``{0: dimx}`` for the first input, we make the first input's
  shape dynamic: it is now ``[s0]``. Specifying ``{0: dimy}`` for the second
  input makes the second input's shape dynamic as well. However, because we
  expressed ``dimy = dimx + 1``, ``y``'s shape does not contain a new symbol;
  instead, it is represented using the same symbol as ``x``, ``s0``. The
  relationship ``dimy = dimx + 1`` appears as ``s0 + 1``.

* In the range constraints, ``s0`` has the range [3, 6], as initially specified,
  and ``s0 + 1`` has the solved range [4, 7].
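
The "double of another" relationship mentioned above can be expressed the same
way, by multiplying a ``Dim`` by an integer. The following is a minimal sketch;
the ``Doubled`` module and the bounds on ``dx`` are hypothetical, and it assumes,
as above, that ``exported_program.module()`` validates inputs against the range
constraints:

::

    import torch
    from torch.export import Dim, export

    class Doubled(torch.nn.Module):
        def forward(self, x, y):
            # y is expected to be exactly twice as long as x
            return torch.cat([x, x]) + y

    dx = Dim("dx", min=2, max=32)
    exported_program = export(
        Doubled(),
        (torch.randn(4), torch.randn(8)),
        dynamic_shapes=({0: dx}, {0: 2 * dx}),
    )
    print(exported_program.range_constraints)

    # Any length within [2, 32] for x works, as long as y is twice as long.
    out = exported_program.module()(torch.randn(5), torch.randn(10))
    print(out.shape)  # torch.Size([10])

    # A length of 40 for x falls outside the declared range of dx.
    try:
        exported_program.module()(torch.randn(40), torch.randn(80))
        out_of_range_rejected = False
    except Exception:  # exact exception type may vary between versions
        out_of_range_rejected = True
    print(out_of_range_rejected)  # True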


Serialization
^^^^^^^^^^^^^

To save and load an ``ExportedProgram``, use the :func:`torch.export.save` and
:func:`torch.export.load` APIs. By convention, an ``ExportedProgram`` is saved
with a ``.pt2`` file extension.

An example:

::

    import torch

    class MyModule(torch.nn.Module):
        def forward(self, x):
            return x + 10

    # The example inputs are passed as a tuple of positional arguments.
    exported_program = torch.export.export(MyModule(), (torch.randn(5),))

    torch.export.save(exported_program, 'exported_program.pt2')
    saved_exported_program = torch.export.load('exported_program.pt2')
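
Besides file paths, :func:`torch.export.save` and :func:`torch.export.load` also
accept file-like objects. A minimal sketch that round-trips the program through
an in-memory ``io.BytesIO`` buffer and checks that the loaded program computes
the same result:

::

    import io

    import torch

    class MyModule(torch.nn.Module):
        def forward(self, x):
            return x + 10

    exported_program = torch.export.export(MyModule(), (torch.randn(5),))

    # Save into an in-memory buffer instead of a file on disk.
    buffer = io.BytesIO()
    torch.export.save(exported_program, buffer)

    # Rewind the buffer before reading the program back.
    buffer.seek(0)
    loaded_program = torch.export.load(buffer)

    x = torch.randn(5)
    print(torch.allclose(loaded_program.module()(x), x + 10))  # True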


Specializations
^^^^^^^^^^^^^^^

A key concept in understanding the behavior of ``torch.export`` is the
difference between *static* and *dynamic* values.

A *dynamic* value is one that can change from run to run. These behave like
normal arguments to a Python function: you can pass different values for an
argument and expect your function to do the right thing. Tensor *data* is
treated as dynamic.

A *static* value is a value that is fixed at export time and cannot change
between executions of the exported program. When the value is encountered during
tracing, the exporter will treat it as a constant and hard-code it into the
graph.

When an operation is performed (e.g. ``x + y``) and all inputs are static, then
the output of the operation will be directly hard-coded into the graph, and the
operation won’t show up (i.e. it will get constant-folded).

When a value has been hard-coded into the graph, we say that the graph has been
*specialized* to that value.
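
A minimal sketch of constant folding, using a hypothetical ``Scale`` module: the
product ``2 * 3`` involves only static Python integers, so it is evaluated during
tracing and only the resulting constant ``6`` should reach the graph. The sketch
inspects the graph's ``call_function`` nodes rather than relying on the exact
printed format:

::

    import torch
    from torch.export import export

    class Scale(torch.nn.Module):
        def forward(self, x):
            factor = 2 * 3  # both operands are static Python ints
            return x * factor

    exported_program = export(Scale(), (torch.randn(4),))

    # Collect the operator calls recorded in the graph.
    calls = [n for n in exported_program.graph.nodes if n.op == "call_function"]
    for node in calls:
        print(node.target, node.args)

Only a single multiplication by the folded constant is expected; the ``2 * 3``
itself leaves no trace in the graph.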

The following values are static:

Input Tensor Shapes
~~~~~~~~~~~~~~~~~~~

By default, ``torch.export`` will trace the program specializing on the input
tensors' shapes, unless a dimension is specified as dynamic via the
``dynamic_shapes`` argument to ``torch.export``. This means that if there exists
shape-dependent control flow, ``torch.export`` will specialize on the branch
that is being taken with the given sample inputs. For example:

::

    import torch
    from torch.export import export

    class Mod(torch.nn.Module):
        def forward(self, x):
            if x.shape[0] > 5:
                return x + 1
            else:
                return x - 1

    example_inputs = (torch.rand(10, 2),)
    exported_program = export(Mod(), example_inputs)
    print(exported_program)

.. code-block::

    ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[10, 2]"):
            # code: return x + 1
            add: "f32[10, 2]" = torch.ops.aten.add.Tensor(x, 1)
            return (add,)

The conditional (``x.shape[0] > 5``) does not appear in the ``ExportedProgram``
because the example inputs have the static shape (10, 2). Since ``torch.export``
specializes on the inputs' static shapes, the else branch (``x - 1``) will never
be reached. To preserve the shape-dependent branching in the traced graph, the
dimension of the input tensor (``x.shape[0]``) must be marked dynamic with
:func:`torch.export.Dim`, and the source code must be
:ref:`rewritten <Data/Shape-Dependent Control Flow>`.

Note that tensors that are part of the module state (e.g. parameters and
buffers) always have static shapes.
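
A quick way to observe this specialization, sketched under the assumption that
``exported_program.module()`` checks input shapes on each call (as in recent
PyTorch releases): calling the exported ``Mod`` with a (3, 2) tensor fails
instead of silently running the ``x + 1`` branch:

::

    import torch
    from torch.export import export

    class Mod(torch.nn.Module):
        def forward(self, x):
            if x.shape[0] > 5:
                return x + 1
            else:
                return x - 1

    exported_program = export(Mod(), (torch.rand(10, 2),))
    m = exported_program.module()

    # The traced shape (10, 2) works and takes the specialized x + 1 path.
    same_shape = m(torch.zeros(10, 2))
    print(same_shape[0, 0].item())  # 1.0

    # A (3, 2) input is rejected rather than routed to the x + 1 path.
    try:
        m(torch.zeros(3, 2))
        rejected = False
    except Exception:  # exact exception type may vary between versions
        rejected = True
    print(rejected)  # True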

Python Primitives
~~~~~~~~~~~~~~~~~

``torch.export`` also specializes on Python primitives, such as ``int``,
``float``, ``bool``, and ``str``. However, some of them have dynamic variants,
such as ``SymInt``, ``SymFloat``, and ``SymBool``.

For example:

::

    import torch
    from torch.export import export

    class Mod(torch.nn.Module):
        def forward(self, x: torch.Tensor, const: int, times: int):
            for i in range(times):
                x = x + const
            return x

    example_inputs = (torch.rand(2, 2), 1, 3)
    exported_program = export(Mod(), example_inputs)
    print(exported_program)

.. code-block::

    ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[2, 2]", const, times):
            # code: x = x + const
            add: "f32[2, 2]" = torch.ops.aten.add.Tensor(x, 1)
            add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 1)
            add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 1)
            return (add_2,)

Because integers are specialized, the ``torch.ops.aten.add.Tensor`` operations
are all computed with the hard-coded constant ``1``, rather than ``const``. If a
user passes a different value for ``const`` at runtime (for example, 2) than the
one used at export time (1), this will result in an error. Additionally, the
``times`` value that bounds the ``for`` loop is also "inlined" in the graph as the
3 repeated ``torch.ops.aten.add.Tensor`` calls, and the input ``times`` is never
used.
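
If a value needs to change between runs, one option is to pass it as a tensor,
since tensor *data* is dynamic. The sketch below, using a hypothetical
``AddConst`` module, contrasts the two; it assumes that
``exported_program.module()`` rejects an ``int`` input that differs from the
specialized value, which matches the error described above, though the exception
type may vary by version:

::

    import torch
    from torch.export import export

    class AddConst(torch.nn.Module):
        def forward(self, x, const):
            return x + const

    # const as a Python int: the graph is specialized on the value 1.
    ep_int = export(AddConst(), (torch.zeros(2, 2), 1))
    try:
        ep_int.module()(torch.zeros(2, 2), 2)
        int_rejected = False
    except Exception:  # exact exception type may vary between versions
        int_rejected = True

    # const as a 0-dimensional tensor: its value is tensor data, hence dynamic.
    ep_tensor = export(AddConst(), (torch.zeros(2, 2), torch.tensor(1.0)))
    out = ep_tensor.module()(torch.zeros(2, 2), torch.tensor(2.0))
    print(int_rejected, out)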

Python Containers
~~~~~~~~~~~~~~~~~

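Python containers (``List``, ``Dict``, ``NamedTuple``, etc.) are considered to
have static structure: their length and keys are fixed at export time, although
the tensors they contain are still treated as ordinary tensor inputs.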
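
A minimal sketch with a hypothetical ``DictMod`` that takes a ``dict`` input:
each tensor in the container becomes a separate graph input, and the dict's keys
are fixed at export time. It assumes, as in recent PyTorch releases, that
``exported_program.module()`` compares the structure of its inputs against the
exported one:

::

    import torch
    from torch.export import export

    class DictMod(torch.nn.Module):
        def forward(self, inputs):
            return inputs["a"] + inputs["b"]

    exported_program = export(DictMod(), ({"a": torch.randn(3), "b": torch.randn(3)},))

    # Each tensor leaf of the container becomes its own placeholder.
    placeholders = [n for n in exported_program.graph.nodes if n.op == "placeholder"]
    print(len(placeholders))  # 2

    # Adding a key changes the container's structure, which is rejected.
    try:
        exported_program.module()(
            {"a": torch.randn(3), "b": torch.randn(3), "c": torch.randn(3)}
        )
        structure_rejected = False
    except Exception:  # exact exception type may vary between versions
        structure_rejected = True
    print(structure_rejected)  # True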


.. _Limitations of torch.export:

Limitations of torch.export
---------------------------

Graph Breaks
^^^^^^^^^^^^

As ``torch.export`` is a one-shot process for capturing a computation graph from
a PyTorch program, it might ultimately run into untraceable parts of programs as
it is nearly impossible to support tracing all PyTorch and Python features. In
the case of ``torch.compile``, an unsupported operation will cause a "graph
break" and the unsupported operation will be run with default Python evaluation.
In contrast, ``torch.export`` will require users to provide additional
information or rewrite parts of their code to make it traceable. As the
tracing is based on TorchDynamo, which evaluates at the Python
bytecode level, there will be significantly fewer rewrites required compared to
previous tracing frameworks.

When a graph break is encountered, :ref:`ExportDB <torch.export_db>` is a great
resource for learning about the kinds of programs that are supported and
unsupported, along with ways to rewrite programs to make them traceable.

One way to avoid dealing with these graph breaks is to use
:ref:`non-strict export <Non-Strict Export>`.
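
Non-strict mode is selected with the ``strict`` argument of
:func:`torch.export.export`. A minimal sketch, using a hypothetical ``Mod`` that
both modes can trace, showing that the two exported programs agree:

::

    import torch
    from torch.export import export

    class Mod(torch.nn.Module):
        def forward(self, x):
            return x.relu() + 1

    example_inputs = (torch.randn(4),)

    # strict=True traces with TorchDynamo at the Python bytecode level.
    ep_strict = export(Mod(), example_inputs, strict=True)
    # strict=False runs the module with the regular Python interpreter while
    # tracing the PyTorch operations it calls.
    ep_non_strict = export(Mod(), example_inputs, strict=False)

    x = torch.randn(4)
    print(torch.allclose(ep_strict.module()(x), ep_non_strict.module()(x)))  # True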

.. _Data/Shape-Dependent Control Flow:

Data/Shape-Dependent Control Flow
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Graph breaks can also be encountered on data-dependent or shape-dependent control
flow (e.g. ``if x.shape[0] > 2``) when shapes are not being specialized, because a
tracing compiler cannot handle such branches without generating code for a
combinatorially exploding number of paths. In such cases, users will need to
rewrite their code using special control flow operators. Currently, we support
:ref:`torch.cond <cond>` to express if-else like control flow (more coming soon!).
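
As a sketch of such a rewrite, the shape-dependent ``Mod`` from the
specialization section above can be expressed with ``torch.cond`` and a dynamic
first dimension, so that both branches are kept in the exported program. The
name ``n`` and its lower bound are illustrative choices:

::

    import torch
    from torch.export import Dim, export

    class Mod(torch.nn.Module):
        def forward(self, x):
            def true_fn(x):
                return x + 1

            def false_fn(x):
                return x - 1

            # Both branches are traced; the predicate is evaluated at runtime.
            return torch.cond(x.shape[0] > 5, true_fn, false_fn, (x,))

    # Mark the first dimension dynamic so that x.shape[0] > 5 stays symbolic.
    n = Dim("n", min=2)
    exported_program = export(Mod(), (torch.zeros(10, 2),), dynamic_shapes={"x": {0: n}})
    m = exported_program.module()

    big = m(torch.zeros(10, 2))   # takes the x + 1 branch
    small = m(torch.zeros(3, 2))  # takes the x - 1 branch
    print(big[0, 0].item(), small[0, 0].item())  # 1.0 -1.0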

Missing Fake/Meta/Abstract Kernels for Operators
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

When tracing, a FakeTensor kernel (aka meta kernel, abstract impl) is
required for all operators. This is used to reason about the input/output shapes
for this operator.

Please see :func:`torch.library.register_fake` for more details.

In the unfortunate case where your model uses an ATen operator that does not
have a FakeTensor kernel implementation yet, please file an issue.
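
For a custom operator whose implementation FakeTensors cannot trace through, you
can register a fake kernel yourself. A minimal sketch, assuming a recent PyTorch
release with :func:`torch.library.custom_op`; the operator name ``mylib::py_sin``
is hypothetical. The fake kernel only has to produce an output with the right
metadata (shape, dtype, device):

::

    import math

    import torch
    from torch.export import export

    # Hypothetical custom operator whose body leaves PyTorch entirely:
    # it converts the tensor to a Python list and back.
    @torch.library.custom_op("mylib::py_sin", mutates_args=())
    def py_sin(x: torch.Tensor) -> torch.Tensor:
        return torch.tensor([math.sin(v) for v in x.tolist()], dtype=x.dtype)

    # Fake kernel: describes the output's shape/dtype/device without computing it.
    @torch.library.register_fake("mylib::py_sin")
    def _(x):
        return torch.empty_like(x)

    class Mod(torch.nn.Module):
        def forward(self, x):
            return py_sin(x) + 1

    exported_program = export(Mod(), (torch.randn(3),))
    print(exported_program.graph)

    x = torch.randn(3)
    print(torch.allclose(exported_program.module()(x), torch.sin(x) + 1))  # True

During export only the fake kernel runs, so the operator shows up in the graph
as a single opaque call.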


Read More
---------

.. toctree::
   :caption: Additional Links for Export Users
   :maxdepth: 1

   export.ir_spec
   torch.compiler_transformations
   torch.compiler_ir
   generated/exportdb/index
   cond

.. toctree::
   :caption: Deep Dive for PyTorch Developers
   :maxdepth: 1

   torch.compiler_dynamo_overview
   torch.compiler_dynamo_deepdive
   torch.compiler_dynamic_shapes
   torch.compiler_fake_tensor


API Reference
-------------

.. automodule:: torch.export
.. autofunction:: export
.. autofunction:: save
.. autofunction:: load
.. autofunction:: register_dataclass
.. autofunction:: torch.export.dynamic_shapes.Dim
.. autofunction:: torch.export.exported_program.default_decompositions
.. autofunction:: dims
.. autoclass:: torch.export.dynamic_shapes.ShapesCollection

    .. automethod:: dynamic_shapes

.. autofunction:: torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes
.. autoclass:: Constraint
.. autoclass:: ExportedProgram

    .. automethod:: module
    .. automethod:: buffers
    .. automethod:: named_buffers
    .. automethod:: parameters
    .. automethod:: named_parameters
    .. automethod:: run_decompositions

.. autoclass:: ExportBackwardSignature
.. autoclass:: ExportGraphSignature
.. autoclass:: ModuleCallSignature
.. autoclass:: ModuleCallEntry


.. automodule:: torch.export.decomp_utils
.. autoclass:: CustomDecompTable

    .. automethod:: copy
    .. automethod:: items
    .. automethod:: keys
    .. automethod:: materialize
    .. automethod:: pop
    .. automethod:: update

.. automodule:: torch.export.exported_program
.. automodule:: torch.export.graph_signature
.. autoclass:: InputKind
.. autoclass:: InputSpec
.. autoclass:: OutputKind
.. autoclass:: OutputSpec
.. autoclass:: SymIntArgument
.. autoclass:: SymBoolArgument
.. autoclass:: SymFloatArgument
.. autoclass:: ExportGraphSignature

    .. automethod:: replace_all_uses
    .. automethod:: get_replace_hook

.. autoclass:: torch.export.graph_signature.CustomObjArgument

.. py:module:: torch.export.dynamic_shapes

.. automodule:: torch.export.unflatten
    :members:

.. automodule:: torch.export.custom_obj

.. automodule:: torch.export.experimental
.. automodule:: torch.export.passes
.. autofunction:: torch.export.passes.move_to_device_pass