File: gen_diagnostics.py

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (257 lines) | stat: -rw-r--r-- 7,698 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
#!/usr/bin/env python3

"""Generates PyTorch ONNX export diagnostic rules for C++, Python, and documentation.
The rules are defined in torch/onnx/_internal/diagnostics/rules.yaml.

Usage:

python -m tools.onnx.gen_diagnostics \
    torch/onnx/_internal/diagnostics/rules.yaml \
    torch/onnx/_internal/diagnostics \
    torch/csrc/onnx/diagnostics/generated \
    torch/docs/source
"""

import argparse
import os
import string
import subprocess
import textwrap
from typing import Any, Mapping, Sequence

import yaml

from torchgen import utils as torchgen_utils
from torchgen.yaml_utils import YamlLoader


# Header text placed in the generated files; passed as "generated_comment" to
# both the Python (rules.py.in) and the C++ (rules.h.in) templates.
_RULES_GENERATED_COMMENT = """\
GENERATED CODE - DO NOT EDIT DIRECTLY
This file is generated by gen_diagnostics.py.
See tools/onnx/gen_diagnostics.py for more information.

Diagnostic rules for PyTorch ONNX export.
"""

# Explanatory text passed as "generated_rule_class_comment" to the Python
# template (rules.py.in).
_PY_RULE_CLASS_COMMENT = """\
GENERATED CODE - DO NOT EDIT DIRECTLY
The purpose of generating a class for each rule is to override the `format_message`
method to provide more details in the signature about the format arguments.
"""

# Source template for one generated rule class; filled in by
# _format_rule_for_python_class via str.format. The placeholders are
# pascal_case_name, short_description, message_template, message_arguments and
# message_arguments_assigned.
_PY_RULE_CLASS_TEMPLATE = """\
class _{pascal_case_name}(infra.Rule):
    \"\"\"{short_description}\"\"\"
    def format_message(  # type: ignore[override]
        self,
        {message_arguments}
    ) -> str:
        \"\"\"Returns the formatted default message of this Rule.

        Message template: {message_template}
        \"\"\"
        return self.message_default_template.format({message_arguments_assigned})

    def format(  # type: ignore[override]
        self,
        level: infra.Level,
        {message_arguments}
    ) -> Tuple[infra.Rule, infra.Level, str]:
        \"\"\"Returns a tuple of (Rule, Level, message) for this Rule.

        Message template: {message_template}
        \"\"\"
        return self, level, self.format_message({message_arguments_assigned})

"""

# Source template for one dataclass field of the generated rule collection;
# filled in by _format_rule_for_python_field. `sarif_dict` receives the rule
# mapping itself, so its string form is what ends up after the `**`.
_PY_RULE_COLLECTION_FIELD_TEMPLATE = """\
{snake_case_name}: _{pascal_case_name} = dataclasses.field(
    default=_{pascal_case_name}.from_sarif(**{sarif_dict}),
    init=False,
)
\"\"\"{short_description}\"\"\"
"""

# Source template for one documented C++ entry (a Doxygen-style comment
# followed by `name,`); filled in by _format_rule_for_cpp.
_CPP_RULE_TEMPLATE = """\
/**
 * @brief {short_description}
 */
{name},
"""

# One rule as loaded from the rules YAML file. The formatters in this file read
# rule["name"], rule["short_description"]["text"] and
# rule["message_strings"]["default"]["text"].
_RuleType = Mapping[str, Any]


def _kebab_case_to_snake_case(name: str) -> str:
    return name.replace("-", "_")


def _kebab_case_to_pascal_case(name: str) -> str:
    return "".join(word.capitalize() for word in name.split("-"))


def _format_rule_for_python_class(rule: _RuleType) -> str:
    """Render the source of the generated Python rule class for one rule.

    The rule's default message template is parsed for ``{keyword}``
    placeholders. Each distinct placeholder becomes a parameter of the
    generated ``format_message`` / ``format`` methods and is forwarded as a
    keyword argument when the template is formatted.

    Args:
        rule: Rule mapping providing ``name``, ``short_description.text`` and
            ``message_strings.default.text``.

    Returns:
        ``_PY_RULE_CLASS_TEMPLATE`` filled in for this rule.

    Raises:
        AssertionError: If the message template contains a positional
            placeholder (``{}`` or ``{0}``); only keyword placeholders can be
            turned into named parameters.
    """
    pascal_case_name = _kebab_case_to_pascal_case(rule["name"])
    short_description = rule["short_description"]["text"]
    message_template = rule["message_strings"]["default"]["text"]
    # A placeholder may appear several times in one template. De-duplicate
    # (dict.fromkeys keeps first-seen order) so the generated signature does
    # not repeat a parameter name, which would be a SyntaxError.
    field_names = list(
        dict.fromkeys(
            field_name
            for _, field_name, _, _ in string.Formatter().parse(message_template)
            if field_name is not None
        )
    )
    for field_name in field_names:
        # The messages are parenthesized so both halves belong to the assert
        # and the template is actually interpolated into the error text.
        assert isinstance(field_name, str), (
            f"Unexpected field type {type(field_name)} from {field_name}. "
            f"Field name must be string.\nFull message template: {message_template}"
        )
        # An empty name comes from a bare `{}` placeholder; like a numeric
        # name it is positional and cannot become a named parameter.
        assert field_name and not field_name.isnumeric(), (
            f"Unexpected empty or numeric field name {field_name!r}. "
            "Only keyword name formatting is supported.\n"
            f"Full message template: {message_template}"
        )
    message_arguments = ", ".join(field_names)
    message_arguments_assigned = ", ".join(
        [f"{field_name}={field_name}" for field_name in field_names]
    )
    return _PY_RULE_CLASS_TEMPLATE.format(
        pascal_case_name=pascal_case_name,
        short_description=short_description,
        message_template=repr(message_template),
        message_arguments=message_arguments,
        message_arguments_assigned=message_arguments_assigned,
    )


def _format_rule_for_python_field(rule: _RuleType) -> str:
    """Render the generated rule-collection field for one rule.

    Fills ``_PY_RULE_COLLECTION_FIELD_TEMPLATE`` with the snake_case and
    PascalCase forms of the rule name, the rule mapping itself (formatted via
    its string representation) and the rule's short description.
    """
    rule_name = rule["name"]
    return _PY_RULE_COLLECTION_FIELD_TEMPLATE.format(
        snake_case_name=_kebab_case_to_snake_case(rule_name),
        pascal_case_name=_kebab_case_to_pascal_case(rule_name),
        sarif_dict=rule,
        short_description=rule["short_description"]["text"],
    )


def _format_rule_for_cpp(rule: _RuleType) -> str:
    """Render the generated C++ entry for one rule.

    The C++ identifier is the PascalCase rule name prefixed with ``k``; it is
    emitted together with the rule's short description through
    ``_CPP_RULE_TEMPLATE``.
    """
    cpp_identifier = "k" + _kebab_case_to_pascal_case(rule["name"])
    return _CPP_RULE_TEMPLATE.format(
        name=cpp_identifier,
        short_description=rule["short_description"]["text"],
    )


def gen_diagnostics_python(
    rules: Sequence[_RuleType], out_py_dir: str, template_dir: str
) -> None:
    """Generate ``_rules.py`` in ``out_py_dir`` and lint it.

    Args:
        rules: Rule mappings to render.
        out_py_dir: Directory passed to the file manager as ``install_dir``.
        template_dir: Directory in which ``rules.py.in`` is looked up.
    """
    class_sources = list(map(_format_rule_for_python_class, rules))
    field_sources = list(map(_format_rule_for_python_field, rules))

    def _template_env() -> Mapping[str, str]:
        # Substitution values handed to the rules.py.in template.
        return {
            "generated_comment": _RULES_GENERATED_COMMENT,
            "generated_rule_class_comment": _PY_RULE_CLASS_COMMENT,
            "rule_classes": "\n".join(class_sources),
            # The field block is indented by four spaces in the template.
            "rules": textwrap.indent("\n".join(field_sources), " " * 4),
        }

    file_manager = torchgen_utils.FileManager(
        install_dir=out_py_dir, template_dir=template_dir, dry_run=False
    )
    file_manager.write_with_template("_rules.py", "rules.py.in", _template_env)

    generated_path = os.path.join(out_py_dir, "_rules.py")
    _lint_file(generated_path)


def gen_diagnostics_cpp(
    rules: Sequence[_RuleType], out_cpp_dir: str, template_dir: str
) -> None:
    """Generate ``rules.h`` in ``out_cpp_dir`` and lint it.

    Args:
        rules: Rule mappings to render.
        out_cpp_dir: Directory passed to the file manager as ``install_dir``.
        template_dir: Directory in which ``rules.h.in`` is looked up.
    """
    cpp_entries = list(map(_format_rule_for_cpp, rules))
    # Quoted snake_case rule names, each followed by a comma.
    quoted_names = []
    for rule in rules:
        snake_name = _kebab_case_to_snake_case(rule["name"])
        quoted_names.append(f'"{snake_name}",')

    def _keep_every_line(_line: str) -> bool:
        # Don't ignore empty line
        return True

    def _template_env() -> Mapping[str, str]:
        # Substitution values handed to the rules.h.in template.
        return {
            "generated_comment": textwrap.indent(
                _RULES_GENERATED_COMMENT, " * ", predicate=_keep_every_line
            ),
            "rules": textwrap.indent("\n".join(cpp_entries), " " * 2),
            "py_rule_names": textwrap.indent("\n".join(quoted_names), " " * 4),
        }

    file_manager = torchgen_utils.FileManager(
        install_dir=out_cpp_dir, template_dir=template_dir, dry_run=False
    )
    file_manager.write_with_template("rules.h", "rules.h.in", _template_env)

    generated_path = os.path.join(out_cpp_dir, "rules.h")
    _lint_file(generated_path)


def gen_diagnostics_docs(
    rules: Sequence[_RuleType], out_docs_dir: str, template_dir: str
) -> None:
    """Placeholder for documentation generation; currently does nothing.

    All arguments are accepted and ignored.
    """
    # TODO: Add doc generation in a follow-up PR.
    return None


def _lint_file(file_path: str) -> None:
    """Run ``lintrunner -a`` on ``file_path`` and wait for it to finish.

    The exit status of the lintrunner process is not inspected.
    """
    lint_command = ["lintrunner", "-a", file_path]
    subprocess.Popen(lint_command).wait()


def gen_diagnostics(
    rules_path: str,
    out_py_dir: str,
    out_cpp_dir: str,
    out_docs_dir: str,
) -> None:
    """Load the rules file and run the Python, C++ and docs generators.

    Args:
        rules_path: Path to the YAML file that defines the rules.
        out_py_dir: Output directory for the generated Python file.
        out_cpp_dir: Output directory for the generated C++ header.
        out_docs_dir: Output directory handed to the docs generator.
    """
    with open(rules_path) as rules_file:
        rules = yaml.load(rules_file, Loader=YamlLoader)

    # Templates live in a "templates" directory next to this script.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    template_dir = os.path.join(script_dir, "templates")

    gen_diagnostics_python(rules, out_py_dir, template_dir)
    gen_diagnostics_cpp(rules, out_cpp_dir, template_dir)
    gen_diagnostics_docs(rules, out_docs_dir, template_dir)


def main() -> None:
    """Parse the command-line arguments and run ``gen_diagnostics``.

    Four positional arguments are expected, in order: the rules YAML path and
    the output directories for Python, C++ and docs.
    """
    parser = argparse.ArgumentParser(description="Generate ONNX diagnostics files")

    # (destination, metavar, help text) for each positional argument, in order.
    positional_specs = (
        ("rules_path", "RULES", "path to rules.yaml"),
        ("out_py_dir", "OUT_PY", "path to output directory for Python"),
        ("out_cpp_dir", "OUT_CPP", "path to output directory for C++"),
        ("out_docs_dir", "OUT_DOCS", "path to output directory for docs"),
    )
    for dest, metavar, help_text in positional_specs:
        parser.add_argument(dest, metavar=metavar, help=help_text)

    parsed = parser.parse_args()
    gen_diagnostics(
        parsed.rules_path,
        parsed.out_py_dir,
        parsed.out_cpp_dir,
        parsed.out_docs_dir,
    )


# Run the generator only when this file is executed as a script (or via
# `python -m`), not when it is imported.
if __name__ == "__main__":
    main()