# Owner(s): ["oncall: export"]
import copy
import unittest
import torch._dynamo as torchdynamo
from torch._export.db.case import ExportCase, SupportLevel
from torch._export.db.examples import (
filter_examples_by_support_level,
get_rewrite_cases,
)
from torch.export import export_for_training
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
IS_WINDOWS,
parametrize,
run_tests,
TestCase,
)
@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
class ExampleTests(TestCase):
    """Run the ExportDB example cases through ``export_for_training``.

    Cases come from ``filter_examples_by_support_level`` and are fed to the
    test methods via ``@parametrize``. The whole class is skipped on Windows
    and whenever dynamo reports that it is not supported.
    """

    # TODO Maybe we should make these tests actually show up in a file?

    @parametrize(
        "name,case",
        filter_examples_by_support_level(SupportLevel.SUPPORTED).items(),
        name_fn=lambda name, case: f"case_{name}",
    )
    def test_exportdb_supported(self, name: str, case: ExportCase) -> None:
        # A SUPPORTED case must export, and the exported module must produce
        # the same result as the original model on the example inputs (and on
        # ``extra_args`` when the case provides them).
        eager_model = case.model
        export_args = case.example_args
        export_kwargs = case.example_kwargs

        # The eager run gets its own deep copies, taken before exporting, so
        # it does not share input objects with the export / exported run.
        eager_args = copy.deepcopy(export_args)
        eager_kwargs = copy.deepcopy(export_kwargs)

        exported_program = export_for_training(
            eager_model,
            export_args,
            export_kwargs,
            dynamic_shapes=case.dynamic_shapes,
        )
        exported_program.graph_module.print_readable()

        # Run the exported module first, then the eager model, and compare.
        exported_out = exported_program.module()(*export_args, **export_kwargs)
        eager_out = eager_model(*eager_args, **eager_kwargs)
        self.assertEqual(exported_out, eager_out)

        extra_args = case.extra_args
        if extra_args is None:
            return

        # Second comparison on the case's additional positional inputs.
        eager_extra_args = copy.deepcopy(extra_args)
        exported_extra_out = exported_program.module()(*extra_args)
        eager_extra_out = eager_model(*eager_extra_args)
        self.assertEqual(exported_extra_out, eager_extra_out)

    @parametrize(
        "name,case",
        filter_examples_by_support_level(SupportLevel.NOT_SUPPORTED_YET).items(),
        name_fn=lambda name, case: f"case_{name}",
    )
    def test_exportdb_not_supported(self, name: str, case: ExportCase) -> None:
        # A NOT_SUPPORTED_YET case must fail to export with one of the error
        # types listed below.
        unsupported_model = case.model
        expected_errors = (
            torchdynamo.exc.Unsupported,
            AssertionError,
            RuntimeError,
        )
        # pyre-ignore
        with self.assertRaises(expected_errors):
            export_for_training(
                unsupported_model,
                case.example_args,
                case.example_kwargs,
                dynamic_shapes=case.dynamic_shapes,
            )

    # One (original case name, rewrite case) pair for every rewrite attached
    # to a NOT_SUPPORTED_YET case.
    exportdb_not_supported_rewrite_cases = [
        (case_name, rewritten)
        for case_name, unsupported_case in filter_examples_by_support_level(
            SupportLevel.NOT_SUPPORTED_YET
        ).items()
        for rewritten in get_rewrite_cases(unsupported_case)
    ]

    # The rewrite test is only defined when at least one rewrite case exists.
    if exportdb_not_supported_rewrite_cases:

        @parametrize(
            "name,rewrite_case",
            exportdb_not_supported_rewrite_cases,
            name_fn=lambda name, case: f"case_{name}_{case.name}",
        )
        def test_exportdb_not_supported_rewrite(
            self, name: str, rewrite_case: ExportCase
        ) -> None:
            # The rewritten variant is expected to export without raising;
            # nothing is checked beyond that.
            rewritten_model = rewrite_case.model
            rewritten_args = rewrite_case.example_args
            rewritten_kwargs = rewrite_case.example_kwargs
            # pyre-ignore
            export_for_training(
                rewritten_model,
                rewritten_args,
                rewritten_kwargs,
                dynamic_shapes=rewrite_case.dynamic_shapes,
            )
# Hand the class to the parametrization helper so the @parametrize-decorated
# methods on ExampleTests are processed before the test runner starts.
instantiate_parametrized_tests(ExampleTests)


# Run the suite only when this file is executed as a script.
if __name__ == "__main__":
    run_tests()