File: test_cpp_api_parity.py

package info (click to toggle)
pytorch 2.9.1%2Bdfsg-1~exp2
  • links: PTS, VCS
  • area: main
  • in suites: experimental
  • size: 180,096 kB
  • sloc: python: 1,473,255; cpp: 942,030; ansic: 79,796; asm: 7,754; javascript: 2,502; java: 1,962; sh: 1,809; makefile: 628; xml: 8
file content (88 lines) | stat: -rw-r--r-- 3,026 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# Owner(s): ["module: cpp"]


import os

from cpp_api_parity import (
    functional_impl_check,
    module_impl_check,
    sample_functional,
    sample_module,
)
from cpp_api_parity.parity_table_parser import parse_parity_tracker_table
from cpp_api_parity.utils import is_torch_nn_functional_test

import torch
import torch.testing._internal.common_nn as common_nn
import torch.testing._internal.common_utils as common


# NOTE: turn this on if you want to print source code of all C++ tests (e.g. for debugging purposes)
PRINT_CPP_SOURCE = False

# Device names handed to every `write_test_to_test_class` call in this file;
# the sanity-check assertions further down also scale by `len(devices)`.
devices = ["cpu", "cuda"]

# Path of the parity tracker markdown file, resolved relative to the directory
# that contains this test file.
PARITY_TABLE_PATH = os.path.join(
    os.path.dirname(__file__), "cpp_api_parity", "parity-tracker.md"
)

# Result of parsing the tracker file once at import time; the same object is
# passed to every `write_test_to_test_class` call.
parity_table = parse_parity_tracker_table(PARITY_TABLE_PATH)


@common.markDynamoStrictTest
class TestCppApiParity(common.TestCase):
    """Host class for the auto-generated C++ API parity tests.

    The class body defines no test methods itself; module-level code in this
    file passes the class to ``write_test_to_test_class`` and
    ``build_cpp_tests`` of the impl-check modules.
    """

    # Two distinct, initially empty class-level dicts.
    # NOTE(review): presumably filled in by the impl-check helpers — confirm.
    functional_test_params_map = {}
    module_test_params_map = {}


# Every test dict that was actually written into `TestCppApiParity`; the
# sanity checks after this loop compare against its length.
expected_test_params_dicts = []

# Each entry pairs a collection of test dicts with the `common_nn` test class
# forwarded as the third argument of `write_test_to_test_class`.
for params_dicts, instance_class in (
    (sample_module.module_tests, common_nn.NewModuleTest),
    (sample_functional.functional_tests, common_nn.NewModuleTest),
    (common_nn.module_tests, common_nn.NewModuleTest),
    (common_nn.get_new_module_tests(), common_nn.NewModuleTest),
    (common_nn.criterion_tests, common_nn.CriterionTest),
):
    for params in params_dicts:
        # A dict opts out with a falsy "test_cpp_api_parity"; a missing key
        # means the dict is included.
        if not params.get("test_cpp_api_parity", True):
            continue

        # Functional test dicts go through the functional checker, everything
        # else through the module checker.
        impl_check = (
            functional_impl_check
            if is_torch_nn_functional_test(params)
            else module_impl_check
        )
        impl_check.write_test_to_test_class(
            TestCppApiParity,
            params,
            instance_class,
            parity_table,
            devices,
        )
        expected_test_params_dicts.append(params)

def _count_generated_tests(marker):
    """Return how many attribute names of ``TestCppApiParity`` contain *marker*."""
    return sum(1 for name in TestCppApiParity.__dict__ if marker in name)


# Sanity checks on the generated test names, evaluated in order:
# 1. All NN module/functional test dicts appear in the parity test, once per
#    device.
# 2. There exist auto-generated tests for `SampleModule`:
#    2 (number of test dicts that are not skipped) * number of devices.
# 3. There exist auto-generated tests for `sample_functional`:
#    2 (number of test dicts that are not skipped) * number of devices.
# The expected counts are derived from `len(devices)` rather than hard-coded,
# so all three checks stay consistent if the device list changes.
for _marker, _expected_count in (
    ("test_torch_nn_", len(expected_test_params_dicts) * len(devices)),
    ("SampleModule", 2 * len(devices)),
    ("sample_functional", 2 * len(devices)),
):
    _found_count = _count_generated_tests(_marker)
    assert _found_count == _expected_count, (
        f"expected {_expected_count} generated tests with {_marker!r} in "
        f"their name, found {_found_count}"
    )

# Invoke `build_cpp_tests` on each impl-check module for the populated test
# class: module checks first, functional checks second.
for _impl_check_module in (module_impl_check, functional_impl_check):
    _impl_check_module.build_cpp_tests(
        TestCppApiParity,
        print_cpp_source=PRINT_CPP_SOURCE,
    )

if __name__ == "__main__":
    # Set the `_default_dtype_check_enabled` flag on the shared TestCase base
    # class before handing control to the test runner.
    common.TestCase._default_dtype_check_enabled = True
    common.run_tests()