# Owner(s): ["oncall: mobile"]
import torch
from test.jit.fixtures_srcs.generate_models import ALL_MODULES
from torch.testing._internal.common_utils import TestCase, run_tests
class TestUpgraderModelGeneration(TestCase):
    """Checks on the entries of ``ALL_MODULES`` imported from ``generate_models``."""

    def test_all_modules(self):
        """Assert that every key of ``ALL_MODULES`` is a ``torch.nn.Module`` instance."""
        # Only the keys (the module instances) are inspected; the values
        # mapped to them are not needed for this check.
        for a_module in ALL_MODULES:
            module_name = type(a_module).__name__
            # Adjacent string literals are concatenated verbatim, so each
            # fragment carries its own trailing space to keep words apart.
            self.assertIsInstance(
                a_module,
                torch.nn.Module,
                f"The module {module_name} "
                "is not a torch.nn.Module instance. "
                "Please ensure it's a subclass of torch.nn.Module in fixtures_src.py "
                "and it's registered as an instance in ALL_MODULES in generate_models.py",
            )
# Run the tests in this file when it is executed directly as a script.
if __name__ == "__main__":
    run_tests()
|