File: test_db.py

package info (click to toggle)
pytorch-cuda 2.6.0+dfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (107 lines) | stat: -rw-r--r-- 3,374 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# Owner(s): ["oncall: export"]

import copy
import unittest

import torch._dynamo as torchdynamo
from torch._export.db.case import ExportCase, SupportLevel
from torch._export.db.examples import (
    filter_examples_by_support_level,
    get_rewrite_cases,
)
from torch.export import export_for_training
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    IS_WINDOWS,
    parametrize,
    run_tests,
    TestCase,
)


@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
class ExampleTests(TestCase):
    """Drive every ExportDB example through ``export_for_training``.

    * Examples tagged ``SupportLevel.SUPPORTED`` must export, and the exported
      module must produce the same outputs as the eager model (also for
      ``extra_args`` when the example provides them).
    * Examples tagged ``SupportLevel.NOT_SUPPORTED_YET`` must make the export
      raise one of the accepted exception types.
    * Every rewrite case returned by ``get_rewrite_cases`` for a
      not-yet-supported example must export without raising.
    """

    # TODO: Maybe these tests should actually show up in a file?
    @parametrize(
        "name,case",
        filter_examples_by_support_level(SupportLevel.SUPPORTED).items(),
        name_fn=lambda example_name, _example: f"case_{example_name}",
    )
    def test_exportdb_supported(self, name: str, case: ExportCase) -> None:
        """A supported example exports and agrees with eager execution."""
        eager_model = case.model

        # The exporter and the eager model each get their own copy of the
        # inputs, so in-place mutation by one run cannot affect the other.
        export_args, export_kwargs = case.example_args, case.example_kwargs
        eager_args = copy.deepcopy(export_args)
        eager_kwargs = copy.deepcopy(export_kwargs)

        program = export_for_training(
            eager_model,
            export_args,
            export_kwargs,
            dynamic_shapes=case.dynamic_shapes,
        )
        program.graph_module.print_readable()

        # The exported module runs first, then the eager model; the
        # comparison is exported-vs-eager in that order.
        exported_out = program.module()(*export_args, **export_kwargs)
        eager_out = eager_model(*eager_args, **eager_kwargs)
        self.assertEqual(exported_out, eager_out)

        # Optional second set of positional inputs to check the exported
        # program on something other than the tracing example.
        if case.extra_args is None:
            return
        extra_inputs = case.extra_args
        extra_inputs_eager = copy.deepcopy(extra_inputs)
        exported_extra_out = program.module()(*extra_inputs)
        eager_extra_out = eager_model(*extra_inputs_eager)
        self.assertEqual(exported_extra_out, eager_extra_out)

    @parametrize(
        "name,case",
        filter_examples_by_support_level(SupportLevel.NOT_SUPPORTED_YET).items(),
        name_fn=lambda example_name, _example: f"case_{example_name}",
    )
    def test_exportdb_not_supported(self, name: str, case: ExportCase) -> None:
        """A not-yet-supported example must make the export raise."""
        unsupported_model = case.model
        accepted_failures = (
            torchdynamo.exc.Unsupported,
            AssertionError,
            RuntimeError,
        )
        # pyre-ignore
        with self.assertRaises(accepted_failures):
            export_for_training(
                unsupported_model,
                case.example_args,
                case.example_kwargs,
                dynamic_shapes=case.dynamic_shapes,
            )

    # Flatten (example name, rewrite case) pairs across every
    # not-yet-supported example; an example may contribute several rewrites.
    exportdb_not_supported_rewrite_cases = [
        (example_name, rewritten)
        for example_name, example in filter_examples_by_support_level(
            SupportLevel.NOT_SUPPORTED_YET
        ).items()
        for rewritten in get_rewrite_cases(example)
    ]

    # The parametrized rewrite test is only defined when at least one
    # rewrite case exists.
    if exportdb_not_supported_rewrite_cases:

        @parametrize(
            "name,rewrite_case",
            exportdb_not_supported_rewrite_cases,
            name_fn=lambda ex_name, rw: f"case_{ex_name}_{rw.name}",
        )
        def test_exportdb_not_supported_rewrite(
            self, name: str, rewrite_case: ExportCase
        ) -> None:
            """A rewrite of an unsupported example must export cleanly."""
            # pyre-ignore
            export_for_training(
                rewrite_case.model,
                rewrite_case.example_args,
                rewrite_case.example_kwargs,
                dynamic_shapes=rewrite_case.dynamic_shapes,
            )


# Expand the @parametrize-decorated templates on ExampleTests into concrete
# per-case test methods before the test runner collects the class.
ExampleTests = instantiate_parametrized_tests(ExampleTests)


if __name__ == "__main__":
    run_tests()