File: test_unflatten.py

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (976 lines) | stat: -rw-r--r-- 33,958 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
# Owner(s): ["oncall: export"]
# flake8: noqa
import copy
import dataclasses
import unittest
from contextlib import contextmanager
from dataclasses import dataclass
from re import escape
from typing import Any, List

import torch
import torch._dynamo as torchdynamo
from functorch.experimental.control_flow import cond, map
from torch import Tensor
from torch._export.utils import (
    get_buffer,
    get_param,
    is_buffer,
    is_param,
    register_dataclass_as_pytree_node,
)
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
from torch.export import Constraint, Dim, export, FlatArgsAdapter, unflatten
from torch.export._trace import DEFAULT_EXPORT_DYNAMO_CONFIG
from torch.export.unflatten import _disable_interpreter
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck
from torch.testing._internal.common_utils import (
    find_library_location,
    IS_FBCODE,
    IS_MACOS,
    IS_SANDCASTLE,
    IS_WINDOWS,
    run_tests,
    skipIfTorchDynamo,
    TestCase,
)
from torch.testing._internal.torchbind_impls import init_torchbind_implementations
from torch.utils._pytree import (
    LeafSpec,
    tree_flatten,
    tree_unflatten,
    TreeSpec,
    treespec_dumps,
    treespec_loads,
)


@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
class TestUnflatten(TestCase):
    def compare_outputs(self, eager, unflattened, args):
        orig_output = eager(*args)
        unflattened_output = unflattened(*args)
        self.assertTrue(torch.allclose(orig_output, unflattened_output))

    def test_unflatten_nested(self):
        class NestedChild(torch.nn.Module):
            def forward(self, x):
                return x / x

        class Child1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nested = NestedChild()
                self.register_parameter(
                    "child1param", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = self.nested(x)
                return x + self.child1param

        class Child2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))

            def forward(self, x):
                return x - self.child2buffer

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Child1()
                self.bar = Child2()
                self.register_parameter(
                    "rootparam", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = x * self.rootparam
                x = self.foo(x)
                x = self.bar(x)
                return x

        orig_eager = MyModule()
        export_module = export(orig_eager, (torch.rand(2, 3),), {})
        unflattened = unflatten(export_module)

        inputs = (torch.rand(2, 3),)

        # Compare the root modules and all submodules
        self.compare_outputs(orig_eager, unflattened, inputs)
        self.compare_outputs(orig_eager.foo, unflattened.foo, inputs)
        self.compare_outputs(orig_eager.bar, unflattened.bar, inputs)
        self.compare_outputs(orig_eager.foo.nested, unflattened.foo.nested, inputs)

        # Check state dicts are equal
        orig_state_dict = orig_eager.state_dict()
        exported_state_dict = unflattened.state_dict()
        for name, value in orig_state_dict.items():
            self.assertTrue(torch.allclose(value, exported_state_dict[name]))

    def test_unflatten_buffer_mutation(self):
        class Child(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))

            def forward(self, x):
                self.child2buffer.add_(x)
                return x - self.child2buffer

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Child()
                self.register_parameter(
                    "rootparam", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = self.foo(x)
                return x * self.rootparam

        eager_module = MyModule()
        export_module = export(eager_module, (torch.rand(2, 3),), {})
        unflattened_module = unflatten(export_module)

        # Buffer should look the same before and after one run
        eager_buffer = eager_module.foo.child2buffer
        unflattened_buffer = unflattened_module.foo.child2buffer
        self.assertTrue(torch.allclose(eager_buffer, unflattened_buffer))

        inputs = (torch.rand(2, 3),)
        eager_module(*inputs)
        unflattened_module(*inputs)
        self.assertTrue(torch.allclose(eager_buffer, unflattened_buffer))

    def test_unflatten_nested_access(self):
        class Child(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))

            def forward(self, x):
                return x - self.child2buffer

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Child()
                self.register_parameter(
                    "rootparam", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = x + self.foo.child2buffer
                x = self.foo(x)
                return x

        eager_module = MyModule()
        export_module = export(eager_module, (torch.rand(2, 3),), {})
        unflattened_module = unflatten(export_module)

        inputs = (torch.rand(2, 3),)
        self.compare_outputs(eager_module, unflattened_module, inputs)

    def test_unflatten_shared_submodule(self):
        class Shared(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                layernorm = torch.nn.LayerNorm(10)
                self.sub_net = torch.nn.Sequential(
                    layernorm,
                    torch.nn.ReLU(),
                    layernorm,
                    torch.nn.ReLU(),
                )

            def forward(self, x):
                return self.sub_net(x)

        eager_module = Shared()
        inps = (torch.rand(10),)
        export_module = export(eager_module, inps, {})
        unflattened_module = unflatten(export_module)
        self.compare_outputs(eager_module, unflattened_module, inps)
        self.assertTrue(hasattr(unflattened_module, "sub_net"))
        for i in range(len(eager_module.sub_net)):
            self.assertTrue(hasattr(unflattened_module.sub_net, str(i)))
        self.assertEqual(
            id(getattr(unflattened_module.sub_net, "0")),
            id(getattr(unflattened_module.sub_net, "2")),
        )

    @unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
    @skipIfTorchDynamo("Non strict mode is not meant to run with dynamo")
    def test_unflatten_preserve_signature(self):
        class NestedChild(torch.nn.Module):
            def forward(self, zx, y):
                return {"x": y["key"] + zx[1], "w": y["key"] * zx[1]}

        class Child1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nested = NestedChild()

            def forward(self, x, y):
                z = torch.ones_like(x)
                xw = self.nested((z, x), y={"key": y})
                return xw["w"] + z - xw["x"]

        class Child2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return x - 1

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Child1()
                self.bar = Child2()

            def forward(self, x, y):
                x = self.foo(x, y)
                x = self.bar(x)
                return x

        orig_eager = MyModule()
        inps = torch.rand(2, 3), torch.rand(2, 3)
        for strict in [True, False]:
            export_module = export(
                orig_eager,
                inps,
                {},
                preserve_module_call_signature=("foo.nested",),
                strict=strict,
            )
            unflattened = unflatten(export_module)
            self.compare_outputs(export_module.module(), unflattened, inps)
            unflattened.foo.nested = NestedChild()
            self.compare_outputs(export_module.module(), unflattened, inps)

            # Test tree spec mismatched input
            orig_outs = export_module.module()(*inps)
            new_inps = *inps, torch.rand(2, 3)
            with self.assertRaisesRegex(
                TypeError,
                "There is no flat args adapter sepcified. Are you sure you are calling this with the right arguments?",
            ):
                unflattened(new_inps)

            # With flat args adapter
            class KeepTwoFlatArgsAdapter(FlatArgsAdapter):
                def adapt(
                    self,
                    target_spec: TreeSpec,
                    input_spec: TreeSpec,
                    input_args: List[Any],
                ) -> List[Any]:
                    while len(input_args) > 2:
                        input_args.pop(-1)
                    return input_args

            unflattened = unflatten(export_module, KeepTwoFlatArgsAdapter())
            new_outs = unflattened(*new_inps)
            self.assertTrue(torch.allclose(orig_outs, new_outs))

    def test_unflatten_param_list_dict(self):
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param_list = torch.nn.ParameterList()
                self.param_dict = torch.nn.ParameterDict()
                for i in range(2):
                    self.param_list.append(torch.nn.Parameter(torch.randn((2, 3))))
                    self.param_dict[f"key_{i}"] = torch.nn.Parameter(
                        torch.randn((2, 3))
                    )

            def forward(self, x):
                for i in range(2):
                    x = x + self.param_list[i]
                    x = x + self.param_dict[f"key_{i}"]
                return x

        export_module = torch.export.export(Mod(), (torch.randn((2, 3)),))
        unflattened = unflatten(export_module)

        self.compare_outputs(
            export_module.module(), unflattened, (torch.randn((2, 3)),)
        )

    @unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
    def test_unflatten_preserve_with_unused_input(self):
        class M1(torch.nn.Module):
            def forward(self, x, a, b):
                return x + a, b

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.m1 = M1()

            def forward(self, x, y):
                a, b = torch.topk(y, 2)
                return self.m1(x, a, b)[0]

        ep = torch.export.export(
            M(),
            (torch.randn(2), torch.randn(5)),
            preserve_module_call_signature=("m1",),
            strict=False,
        )
        ep.graph.eliminate_dead_code()
        unflattened = unflatten(ep)
        self.compare_outputs(ep.module(), unflattened, (torch.randn(2), torch.randn(5)))

    def test_unflatten_wrong_input(self):
        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param_list = torch.nn.ParameterList()
                self.param_dict = torch.nn.ParameterDict()
                for i in range(2):
                    self.param_list.append(torch.nn.Parameter(torch.randn((2, 3))))
                    self.param_dict[f"key_{i}"] = torch.nn.Parameter(
                        torch.randn((2, 3))
                    )

            def forward(self, x):
                a = x.sum()
                for i in range(2):
                    a = a + self.param_list[i].sum()
                    a = a + self.param_dict[f"key_{i}"].sum()
                return a

        export_module = torch.export.export(Mod(), (torch.randn((2, 3)),))
        with self.assertRaisesRegex(
            RuntimeError,
            escape("Expected input at *args[0].shape[0] to be equal to 2, but got 6"),
        ):
            export_module.module()(torch.randn(6, 6))

        unflattened = unflatten(export_module)
        with self.assertRaisesRegex(
            RuntimeError,
            escape("Expected input at *args[0].shape[0] to be equal to 2, but got 6"),
        ):
            unflattened(torch.randn(6, 6))

    @unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
    def test_unflatten_with_inplace_compile(self):
        """Compile parts of an unflattened module in place and compare to eager.

        First only the ``foo`` submodule is compiled (with ``fullgraph=True``
        and a graph-recording backend), then the whole unflattened module is
        compiled. The output is compared against the eager module after each
        step.
        """

        class NestedChild(torch.nn.Module):
            def forward(self, x):
                return x / x

        class Child1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nested = NestedChild()
                self.register_parameter(
                    "child1param", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = self.nested(x)
                return x + self.child1param

        class Child2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))

            def forward(self, x):
                return x - self.child2buffer

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Child1()
                self.bar = Child2()
                self.register_parameter(
                    "rootparam", torch.nn.Parameter(torch.ones(2, 3))
                )

            def forward(self, x):
                x = x * self.rootparam
                x = self.foo(x)
                x = self.bar(x)
                return x

        orig_eager = MyModule()
        export_module = torch.export.export(orig_eager, (torch.rand(2, 3),), {})
        unflattened = unflatten(export_module)

        # in-place compilation should work. Pass fullgraph to ensure no graph breaks.
        from torch._dynamo.backends.debugging import ExplainWithBackend

        # The backend object exposes the graphs it was handed via ``eb.graphs``.
        eb = ExplainWithBackend("inductor")
        unflattened.foo.compile(backend=eb, fullgraph=True)
        inputs = (torch.randn(2, 3),)
        self.compare_outputs(orig_eager, unflattened, inputs)
        # Exactly one graph is expected to have been recorded for ``foo``.
        self.assertEqual(len(eb.graphs), 1)

        # Compiling the root module afterwards must still give matching output.
        unflattened.compile()
        self.compare_outputs(orig_eager, unflattened, inputs)

    def test_fx_trace(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                x = x[0] + x[1]
                x = x + y["foo"]
                return x

        orig_eager = MyModule()
        inputs = ((torch.rand(2, 3), torch.rand(2, 3)), {"foo": torch.rand(2, 3)})
        export_module = export(orig_eager, inputs, {})

        unflattened = unflatten(export_module)
        torch.fx.symbolic_trace(
            unflattened, concrete_args=(torch.fx.PH, torch.fx.PH, torch.fx.PH)
        )

    def test_double_nested_submodule(self):
        class SubSubMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return x * x

        class SubMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.subsubmod = SubSubMod()

            def forward(self, x):
                return x - x

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.submod = SubMod()

            def forward(self, x):
                return x + self.submod.subsubmod(x)

        orig_eager = MyModule()
        export_module = torch.export.export(orig_eager, (torch.rand(2, 3),), {})
        unflattened = unflatten(export_module)

        inputs = (torch.rand(2, 3),)
        self.compare_outputs(orig_eager, unflattened, inputs)

    def test_unflatten_container_type(self):
        class Leaf(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, x):
                return self.linear(x)

        class Bar(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.leaf = Leaf()
                self.buffer = torch.nn.Buffer(torch.randn(4, 4))

            def forward(self, x, z):
                return self.buffer.sum() + self.leaf(x).sum() + z[0].sum() + z[1].sum()

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bar = Bar()

            def forward(self, x, z):
                y = self.bar.buffer + x + z[0] + z[1]
                return self.bar(x, z) + y.sum()

        inp = (torch.randn(4, 4), [torch.randn(4, 4), torch.randn(4, 4)])
        mod = Foo()
        ep_strict = torch.export.export(mod, inp)
        ep_non_strict = torch.export.export(mod, inp, strict=False)

        gm_unflat_non_strict = unflatten(ep_non_strict)
        ep = torch.export.export(gm_unflat_non_strict, inp, strict=False)
        self.assertTrue(torch.allclose(ep.module()(*inp), mod(*inp)))

    def test_unflattened_module_nodes_has_meta_val(self):
        class SubMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return x + x, x * x

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.submod = SubMod()

            def forward(self, x):
                return x + sum(self.submod(x))

        orig_eager = MyModule()
        export_module = torch.export.export(orig_eager, (torch.rand(2, 3),), {})
        unflattened = unflatten(export_module)

        inputs = (torch.rand(2, 3),)
        self.compare_outputs(orig_eager, unflattened, inputs)

        def check_meta(gm):
            for n in gm.graph.nodes:
                if n.op == "output":
                    continue
                self.assertTrue(n.meta.get("val") is not None)

        for m in unflattened.modules():
            check_meta(m)

    def test_unflatten_requires_grad_param(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.p = torch.nn.Parameter(torch.ones(3, 3), requires_grad=False)

            def forward(self, x):
                return self.p + x

        with torch.device("meta"):
            mod = M()

        inputs = (torch.randn(3, 3, device="meta"),)
        ep = export(mod, inputs)
        unflattened = unflatten(ep)
        self.assertTrue(unflattened.state_dict()["p"].requires_grad is False)
        self.assertTrue(unflattened.p.requires_grad is False)

    def test_placeholder_and_get_attr_ordering_after_unflattened(self):
        class TransposeModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 1, 3, stride=2)

            def forward(self, x):
                x = self.conv(x)
                return x.transpose(0, 1)

        x = torch.randn(32, 3, 64, 64)
        exported_program = export(TransposeModule(), args=(x,))
        unflattened_module = unflatten(exported_program)

        # Check the inputs of the created call_module node are in order
        call_module_input_order = []
        for node in unflattened_module.graph.nodes:
            if node.op == "call_module":
                transpose_module = unflattened_module.get_submodule(node.target)
                for sub_node in transpose_module.graph.nodes:
                    if sub_node.op == "placeholder" or sub_node.op == "get_attr":
                        call_module_input_order.append(sub_node.op)
        self.assertEqual(
            call_module_input_order, ["placeholder", "get_attr", "get_attr"]
        )

    def test_unflatten_constant_tensor(self):
        class SubMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.initializer = 0.1

            def forward(self, x):
                return x + torch.tensor(self.initializer)

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.submod = SubMod()

            def forward(self, x):
                return x + self.submod(x)

        export_module = torch.export.export(Mod(), (torch.randn((2, 3)),))
        unflattened = unflatten(export_module)

        self.compare_outputs(
            export_module.module(), unflattened, (torch.randn((2, 3)),)
        )

    @skipIfTorchDynamo("custom objects not supported in dynamo yet")
    def test_unflatten_constant_obj(self):
        """Unflatten a model whose submodule holds a torchbind custom object.

        A fake class is registered for ``_TorchScriptTesting::_Foo``, the model
        is exported in non-strict mode with torchbind tracing enabled, and the
        unflattened module is compared against the exported module.
        """
        init_torchbind_implementations()

        # Fake counterpart registered under the name of the torchbind class.
        @torch._library.register_fake_class("_TorchScriptTesting::_Foo")
        class FakeFoo:
            def __init__(self, x: int, y: int):
                self.x = x
                self.y = y

            @classmethod
            def __obj_unflatten__(cls, flat_ctx):
                # flat_ctx is turned into keyword arguments for the constructor.
                return cls(**dict(flat_ctx))

            def add_tensor(self, z):
                return (self.x + self.y) * z

        class SubMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                return x + self.attr.add_tensor(x)

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.submod = SubMod()

            def forward(self, x):
                return x + self.submod(x)

        # Only the export call runs inside the torchbind tracing context.
        with enable_torchbind_tracing():
            export_module = torch.export.export(
                Mod(), (torch.randn((2, 3)),), strict=False
            )
        unflattened = unflatten(export_module)

        self.compare_outputs(
            export_module.module(), unflattened, (torch.randn((2, 3)),)
        )

    # skip connection is not supported yet
    @unittest.expectedFailure
    def test_unflatten_skipped_call_module(self):
        """Export and unflatten a model where a nested module calls a module
        owned by the root (``C`` calls ``a.d``).

        Marked as an expected failure.
        """

        class C(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                # ``a`` is the A() instance bound later in this test method;
                # it is resolved through the enclosing scope at call time.
                return a.d(x.cos())

        class B(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.c = C()

            def forward(self, x):
                return self.c(x) + x

        class D(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x.sin()

        class A(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.b = B()
                self.d = D()

            def forward(self, x):
                return self.b(x)

        a = A()

        # The call chain looks like this:
        # A -> B -> C -> A.d
        ep = torch.export.export(a, (torch.randn(3),), strict=False)
        # The unflattened result is not inspected further.
        unflattened = unflatten(ep)

    def test_nested_leaf_non_strict(self):
        class Leaf(torch.nn.Module):
            def forward(self, x):
                return x + 1

        class Nested(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.leaf = Leaf()

            def forward(self, x):
                return self.leaf(x) + 2

        class TopLevel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nested = Nested()

            def forward(self, x):
                return self.nested(x) + 3

        ep = torch.export.export(
            TopLevel(),
            (torch.randn(3),),
            strict=False,
            preserve_module_call_signature=("nested",),
        )

        torch.export.unflatten(ep)

    def test_unflatten_submodule_ordering(self):
        class Module2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer = torch.nn.Buffer(torch.rand(3, 4))
                self.register_parameter("param", torch.nn.Parameter(torch.rand(3, 4)))

            def forward(self, x):
                return x + self.buffer + self.param

        class Module1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buffer = torch.nn.Buffer(torch.rand(3, 4))
                self.register_parameter("param", torch.nn.Parameter(torch.rand(3, 4)))

            def forward(self, x):
                return x + self.buffer + self.param

        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mod2 = Module2()
                self.mod3 = self.mod2
                self.mod1 = Module1()

            def forward(self, x):
                return self.mod3(self.mod2(self.mod1(x)))

        mod = Module()

        ep = torch.export.export(mod, (torch.randn(3, 4),))

        unflattened = torch.export.unflatten(ep)
        fqn_list = [x for x, _ in unflattened.named_modules(remove_duplicate=False)]
        self.assertEqual(len(fqn_list), 4)
        self.assertEqual(
            [x for x, _ in mod.named_modules(remove_duplicate=False)],
            fqn_list,
        )

    def test_duplicate_placeholder(self):
        N, C, H, W = 1, 2, 2, 3

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                layer = torch.nn.LayerNorm([C, H, W])
                self.norms = torch.nn.ModuleList(
                    [
                        layer,  # reuse layer norm
                        layer,
                        layer,
                    ]
                )

            def forward(self, input_):
                for i in range(len(self.norms)):
                    output = self.norms[i](input_)
                    input_ = output
                return output

        mod = MyModule()
        input_ = torch.randn(N, C, H, W)

        ep_strict = export(copy.deepcopy(mod), (input_,), strict=True)
        umod = unflatten(ep_strict)
        self.assertTrue(torch.allclose(umod(input_), mod(input_)))

        ep_non_strict = export(copy.deepcopy(mod), (input_,), strict=False)
        umod = unflatten(ep_non_strict)
        self.assertTrue(torch.allclose(umod(input_), mod(input_)))

    def test_simple_alias(self):
        # Two scenarios around parameter aliasing (weight sharing):
        #   1. both aliases are used in forward -> after unflattening, the two
        #      attribute paths must resolve to the very same object;
        #   2. only one alias is used in forward -> the unflattened module must
        #      still run and match eager output.

        # handle weight sharing, check tensor identity after unflattening
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # alias param: self.bias and self.m.bias are the same Parameter
                self.bias = torch.nn.Parameter(torch.randn(4))
                self.m = torch.nn.Linear(4, 4)
                self.m.bias = self.bias

            def forward(self, x):
                return self.m(x) + self.bias

        m = Foo()
        inps = (torch.randn(4, 4),)
        ep = export(m, inps)
        unep = unflatten(ep)
        # assertIs performs the same identity check as comparing id() values,
        # but reports both objects when the check fails.
        self.assertIs(unep.m.bias, unep.bias)

        # handle aliasing where one alias is unused
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bias = torch.nn.Parameter(torch.randn(4))
                self.m = torch.nn.Linear(4, 4)
                self.m.bias = (
                    self.bias
                )  # self.bias is unused, aliasing should be handled

            def forward(self, x):
                return self.m(x)

        m = Foo()
        inps = (torch.randn(4, 4),)
        ep = export(m, inps)
        unep = unflatten(ep)
        self.assertTrue(torch.allclose(unep(*inps), m(*inps)))

    def test_attr_as_submod_input(self):
        # The root module owns a buffer and hands it to each child layer as a
        # call argument; the unflattened module must produce the eager result.
        class Layer(torch.nn.Module):
            def forward(self, x, const) -> torch.Tensor:
                return x + const

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.const = torch.nn.Buffer(torch.ones(4, 8))
                self.layers = torch.nn.ModuleList([Layer(), Layer()])

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                # The buffer lives on the root and is passed into every submodule.
                for sub in self.layers:
                    x = sub(x, self.const)
                return x

        eager = M()
        sample = torch.randn(4, 8)
        unflattened = unflatten(export(eager, (sample,)))

        actual = unflattened(sample)
        expected = eager(sample)
        torch.testing.assert_close(actual, expected)

    def test_dedup_sym_size(self):
        # sym_size and floor-div are needed by three subgraphs (top-level, m1, m2),
        # yet the initial export graph holds only a single sym_size node.
        #   - m1 keeps its call signature, so sym_size & floordiv are copied into
        #     it and recomputed there.
        #   - m2 does not, so floordiv is expected to arrive as a placeholder.
        # Verify that split is preserved and that the unflattened module runs.
        class M1(torch.nn.Module):
            def forward(self, x, y):
                half = x.size(0) // 2
                return y[:half]

        class M2(torch.nn.Module):
            def forward(self, x, y):
                half = x.size(0) // 2
                return y[:half]

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.m1 = M1()
                self.m2 = M2()

            def forward(self, x, y):
                half = x.size(0) // 2
                from_m1 = self.m1(x, y)
                from_m2 = self.m2(x, y)
                return y[half:] + from_m1 + from_m2

        inputs = (torch.ones(10), torch.ones(10))
        base_dim = torch.export.Dim("foo", max=2048)
        even_dim = 2 * base_dim
        ep = torch.export.export(
            M(),
            inputs,
            dynamic_shapes=((even_dim,), (even_dim,)),
            strict=False,
            preserve_module_call_signature=("m1",),
        )
        unflat = unflatten(ep)
        # Running the module is itself part of the check.
        unflat(*inputs)

        def count_sym_size(graph):
            # Number of nodes in `graph` whose target is aten.sym_size.int.
            targets = [node.target for node in graph.nodes]
            return targets.count(torch.ops.aten.sym_size.int)

        self.assertEqual(count_sym_size(unflat.graph), 1)
        self.assertEqual(count_sym_size(unflat.m1.graph), 1)
        self.assertEqual(count_sym_size(unflat.m2.graph), 0)

    def test_unflatten_eager(self):
        # Unflatten inside _disable_interpreter() and check the result against
        # the original eager module: outputs at every level of the hierarchy,
        # state dict contents, symbolic tracing, and torch.compile on a submodule.
        class NestedChild(torch.nn.Module):
            def forward(self, x):
                return x / x

        class Child1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nested = NestedChild()
                self.child1param = torch.nn.Parameter(torch.ones(2, 3))

            def forward(self, x):
                return self.nested(x) + self.child1param

        class Child2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.child2buffer = torch.nn.Buffer(torch.ones(2, 3))

            def forward(self, x):
                return x - self.child2buffer

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Child1()
                self.bar = Child2()
                self.rootparam = torch.nn.Parameter(torch.ones(2, 3))

            def forward(self, x):
                scaled = x * self.rootparam
                return self.bar(self.foo(scaled))

        orig_eager = MyModule()
        export_module = export(orig_eager, (torch.rand(2, 3),), {})
        with _disable_interpreter():
            unflattened = unflatten(export_module)

        # Both the root and a child must report that the interpreter is off.
        for produced in (unflattened, unflattened.foo):
            self.assertEqual(produced._run_with_interpreter, False)

        inputs = (torch.rand(2, 3),)

        # Compare the root module and every submodule against its eager twin.
        module_pairs = (
            (orig_eager, unflattened),
            (orig_eager.foo, unflattened.foo),
            (orig_eager.bar, unflattened.bar),
            (orig_eager.foo.nested, unflattened.foo.nested),
        )
        for eager_mod, unflat_mod in module_pairs:
            self.compare_outputs(eager_mod, unflat_mod, inputs)

        # Every entry of the eager state dict must be matched by the exported one.
        orig_state_dict = orig_eager.state_dict()
        exported_state_dict = unflattened.state_dict()
        for key in orig_state_dict:
            self.assertTrue(
                torch.allclose(orig_state_dict[key], exported_state_dict[key])
            )

        # Composability with symbolic trace matters because torchrec ddp relies
        # on the symbolic tracer.
        symbolic_traced = torch.fx.symbolic_trace(unflattened, concrete_args=inputs)
        eager_out = orig_eager(*inputs)
        traced_out = symbolic_traced(*inputs)
        self.assertTrue(torch.allclose(eager_out, traced_out))

        # torch.compile a single submodule and re-check the whole module.
        unflattened.foo = torch.compile(unflattened.foo, fullgraph=True)
        self.compare_outputs(orig_eager, unflattened, inputs)


# Script entry point: invoke the test runner only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    run_tests()