File: gen_selected_mobile_ops_header.py

package info (click to toggle)
pytorch-cuda 2.6.0+dfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (184 lines) | stat: -rw-r--r-- 6,097 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
#!/usr/bin/env python3

from __future__ import annotations

import argparse
import os

import yaml

from torchgen.code_template import CodeTemplate
from torchgen.selective_build.selector import SelectiveBuilder


# Safely load fast C Yaml loader/dumper if they are available
try:
    from yaml import CSafeLoader as Loader
except ImportError:
    from yaml import SafeLoader as Loader  # type: ignore[assignment, misc]


# C++ `if` fragment emitted once per kernel tag by
# get_selected_kernel_dtypes_code(). `$kernel_tag_name` is compared against the
# tag passed to should_include_kernel_dtype(), and `$dtype_checks` is the
# boolean expression over `scalar_type` that is returned when the tag matches.
if_condition_template_str = """if (kernel_tag_sv.compare("$kernel_tag_name") == 0) {
  return $dtype_checks;
}"""
if_condition_template = CodeTemplate(if_condition_template_str)

# C++ source defining at::should_include_kernel_dtype(). `$body` is filled in by
# get_selected_kernel_dtypes_code() with either `return true;` or a chain of
# `if (...) {...} else if (...) {...}` blocks built from if_condition_template
# (or an empty string when there are no kernel tags). The trailing
# `return false;` is reached only when no emitted branch returns.
# kernel_tag_sv is marked [[maybe_unused]] presumably because it is never
# referenced when `$body` is just `return true;`.
selected_kernel_dtypes_h_template_str = """
#include <c10/core/ScalarType.h>
#include <c10/util/string_view.h>
#include <c10/macros/Macros.h>

namespace at {
inline constexpr bool should_include_kernel_dtype(
  const char *kernel_tag_str,
  at::ScalarType scalar_type
) {
  [[maybe_unused]] c10::string_view kernel_tag_sv =
      c10::string_view(kernel_tag_str);
  $body return false;
}
}
"""
selected_kernel_dtypes_h_template = CodeTemplate(selected_kernel_dtypes_h_template_str)

# Text placed at the very top of every generated selected_mobile_ops.h:
# `#pragma once` followed by a "generated by" banner comment.
selected_mobile_ops_preamble = """#pragma once
/**
 * Generated by gen_selected_mobile_ops_header.py
 */

"""


def extract_root_operators(selective_builder: SelectiveBuilder) -> set[str]:
    """Return the names of all root operators selected by ``selective_builder``.

    Args:
        selective_builder: Selector whose ``operators`` mapping (operator name
            -> operator info object) is scanned; entries whose
            ``is_root_operator`` attribute is truthy are kept.

    Returns:
        The set of root operator names (empty if none is marked as root).
    """
    return {
        op_name
        for op_name, op in selective_builder.operators.items()
        if op.is_root_operator
    }


def get_selected_kernel_dtypes_code(
    selective_builder: SelectiveBuilder,
) -> str:
    """Return C++ source defining ``at::should_include_kernel_dtype``.

    Two shapes of generated code are possible:

    * If ``include_all_operators`` or ``include_all_non_op_selectives`` is
      set, the function body is simply ``return true;``, so every
      (kernel tag, dtype) pair is kept.
    * Otherwise one ``if`` block is emitted per entry of
      ``selective_builder.kernel_metadata`` (kernel tag -> dtype names),
      chained with ``else``. Each block returns whether ``scalar_type`` is
      one of the dtypes recorded for that tag; unknown tags fall through to
      the template's trailing ``return false;``.

    Args:
        selective_builder: Selector whose ``include_all_operators``,
            ``include_all_non_op_selectives`` and ``kernel_metadata`` drive
            the generated code.

    Returns:
        The text produced by ``selected_kernel_dtypes_h_template``.
    """
    body = "return true;"
    if (
        selective_builder.include_all_operators is False
        and selective_builder.include_all_non_op_selectives is False
    ):
        body_parts = []
        for kernel_tag, dtypes in selective_builder.kernel_metadata.items():
            conditions = ["scalar_type == at::ScalarType::" + x for x in dtypes]
            # A tag with no recorded dtypes would otherwise render as
            # `return ;`, which is ill-formed in a function returning bool.
            # `false` keeps the header compilable and selects no dtype.
            dtype_checks = " || ".join(conditions) or "false"
            body_parts.append(
                if_condition_template.substitute(
                    kernel_tag_name=kernel_tag,
                    dtype_checks=dtype_checks,
                ),
            )
        body = " else ".join(body_parts)

    return selected_kernel_dtypes_h_template.substitute(body=body)


def write_selected_mobile_ops(
    output_file_path: str,
    selective_builder: SelectiveBuilder,
) -> None:
    """Write ``selected_mobile_ops.h`` for the selection in ``selective_builder``.

    The header always contains the preamble and the
    ``at::should_include_kernel_dtype`` definition. When not all operators
    are included it also defines ``TORCH_OPERATOR_WHITELIST`` (the root
    operators). When, in addition, ``include_all_non_op_selectives`` is
    False, it defines ``TORCH_CUSTOM_CLASS_ALLOWLIST`` and
    ``TORCH_BUILD_FEATURE_ALLOWLIST``. Every list is sorted and
    ``;``-separated, so the output is deterministic.

    Args:
        output_file_path: Path of the header file to (over)write.
        selective_builder: Selector describing the operators, custom
            classes, build features and kernel dtypes to keep.
    """
    root_ops = extract_root_operators(selective_builder)
    custom_classes = selective_builder.custom_classes
    build_features = selective_builder.build_features

    body_parts = [selected_mobile_ops_preamble]
    # This condition checks if we are in selective build. If these lists are
    # not defined, the corresponding selective build macros trivially report
    # that the item in question was selected.
    if not selective_builder.include_all_operators:
        body_parts.append(
            "#define TORCH_OPERATOR_WHITELIST "
            + (";".join(sorted(root_ops)))
            + ";\n\n"
        )
        # This condition checks if we are in tracing based selective build
        if selective_builder.include_all_non_op_selectives is False:
            body_parts.append(
                "#define TORCH_CUSTOM_CLASS_ALLOWLIST "
                + (";".join(sorted(custom_classes)))
                + ";\n\n"
            )
            body_parts.append(
                "#define TORCH_BUILD_FEATURE_ALLOWLIST "
                + (";".join(sorted(build_features)))
                + ";\n\n"
            )

    body_parts.append(get_selected_kernel_dtypes_code(selective_builder))
    # Render and encode everything before opening the output: "wb" truncates
    # the file, so a failure during generation would otherwise leave an empty
    # or partial header on disk.
    header_bytes = "".join(body_parts).encode("utf-8")
    with open(output_file_path, "wb") as out_file:
        out_file.write(header_bytes)


def write_selected_mobile_ops_with_all_dtypes(
    output_file_path: str,
    root_ops: set[str],
) -> None:
    """Write ``selected_mobile_ops.h`` restricting only the root operators.

    The header contains the preamble and a ``TORCH_OPERATOR_WHITELIST``
    built from ``root_ops`` (sorted, ``;``-separated). It also contains the
    kernel-dtype function generated for ``SelectiveBuilder.get_nop_selector()``,
    which is intended to keep all kernel dtypes.

    Args:
        output_file_path: Path of the header file to (over)write.
        root_ops: Set of selected root operator names for selective build.
    """
    body_parts = [selected_mobile_ops_preamble]
    body_parts.append(
        "#define TORCH_OPERATOR_WHITELIST " + (";".join(sorted(root_ops))) + ";\n\n"
    )

    selective_builder = SelectiveBuilder.get_nop_selector()
    body_parts.append(get_selected_kernel_dtypes_code(selective_builder))

    # Build the full contents first so that opening with "wb" (which
    # truncates) never leaves an empty header behind if generation fails.
    header_bytes = "".join(body_parts).encode("utf-8")
    with open(output_file_path, "wb") as out_file:
        out_file.write(header_bytes)


def main() -> None:
    """Command-line entry point.

    Reads the YAML file given by ``-p/--yaml-file-path`` (presumably a list
    of operator names used by the model; each top-level element is treated
    as one root operator). Then writes ``selected_mobile_ops.h`` into the
    directory given by ``-o/--output-file-path``, keeping all kernel dtypes.
    """
    parser = argparse.ArgumentParser(
        description="Generate selected_mobile_ops.h for selective build."
    )
    parser.add_argument(
        "-p",
        "--yaml-file-path",
        "--yaml_file_path",
        type=str,
        required=True,
        help="Path to the yaml file with a list of operators used by the model.",
    )
    parser.add_argument(
        "-o",
        "--output-file-path",
        "--output_file_path",
        type=str,
        required=True,
        # The two literals are concatenated; the trailing space keeps the
        # words "destination" and "folder" apart in --help output.
        help="Path to destination "
        "folder where selected_mobile_ops.h will be written.",
    )
    parsed_args = parser.parse_args()
    model_file_name = parsed_args.yaml_file_path

    print("Loading yaml file: ", model_file_name)
    # `Loader` is CSafeLoader or SafeLoader (chosen at import time), so no
    # arbitrary Python objects are constructed from the file.
    with open(model_file_name, "rb") as model_file:
        loaded_model = yaml.load(model_file, Loader=Loader)

    # An empty YAML document loads as None; treat it as "no operators"
    # instead of failing with an opaque TypeError from set(None).
    root_operators_set = set() if loaded_model is None else set(loaded_model)
    print("Writing header file selected_mobile_ops.h: ", parsed_args.output_file_path)
    write_selected_mobile_ops_with_all_dtypes(
        os.path.join(parsed_args.output_file_path, "selected_mobile_ops.h"),
        root_operators_set,
    )


if __name__ == "__main__":
    main()