#!/usr/bin/env python3
from __future__ import annotations
import argparse
import os
import yaml
from torchgen.code_template import CodeTemplate
from torchgen.selective_build.selector import SelectiveBuilder
# Safely load fast C Yaml loader/dumper if they are available
try:
from yaml import CSafeLoader as Loader
except ImportError:
from yaml import SafeLoader as Loader # type: ignore[assignment, misc]
# C++ snippet for a single kernel tag: when kernel_tag_sv equals
# $kernel_tag_name, return the value of the $dtype_checks expression.
# Both placeholders are filled in by get_selected_kernel_dtypes_code().
if_condition_template_str = """if (kernel_tag_sv.compare("$kernel_tag_name") == 0) {
return $dtype_checks;
}"""
if_condition_template = CodeTemplate(if_condition_template_str)
# C++ header text that defines at::should_include_kernel_dtype(kernel_tag_str,
# scalar_type). $body is substituted by get_selected_kernel_dtypes_code();
# whatever is substituted is followed by a fallback `return false;`.
selected_kernel_dtypes_h_template_str = """
#include <c10/core/ScalarType.h>
#include <c10/util/string_view.h>
#include <c10/macros/Macros.h>
namespace at {
inline constexpr bool should_include_kernel_dtype(
const char *kernel_tag_str,
at::ScalarType scalar_type
) {
[[maybe_unused]] c10::string_view kernel_tag_sv =
c10::string_view(kernel_tag_str);
$body return false;
}
}
"""
selected_kernel_dtypes_h_template = CodeTemplate(selected_kernel_dtypes_h_template_str)
# Text emitted first in every generated header: the `#pragma once` directive
# followed by a "generated by" notice.
selected_mobile_ops_preamble = """#pragma once
/**
* Generated by gen_selected_mobile_ops_header.py
*/
"""
def extract_root_operators(selective_builder: SelectiveBuilder) -> set[str]:
    """Return the names of the operators flagged as root operators.

    Args:
        selective_builder: Object whose ``operators`` mapping goes from
            operator name to an entry exposing ``is_root_operator``.

    Returns:
        The set of names whose entry has a truthy ``is_root_operator``.
    """
    return {
        name
        for name, operator in selective_builder.operators.items()
        if operator.is_root_operator
    }
def get_selected_kernel_dtypes_code(
    selective_builder: SelectiveBuilder,
) -> str:
    """Render the C++ header text for ``should_include_kernel_dtype``.

    See https://www.internalfb.com/intern/paste/P153411698/ for an example of
    the generated code in both cases (all kernel dtypes selected, and only
    some kernel dtypes selected).

    Args:
        selective_builder: Supplies ``include_all_operators``,
            ``include_all_non_op_selectives`` and ``kernel_metadata``
            (kernel tag -> iterable of dtype names).

    Returns:
        ``selected_kernel_dtypes_h_template`` with ``$body`` filled in.
    """
    restrict_dtypes = (
        selective_builder.include_all_operators is False
        and selective_builder.include_all_non_op_selectives is False
    )
    if not restrict_dtypes:
        # Unless both flags are exactly False, every dtype is accepted.
        return selected_kernel_dtypes_h_template.substitute(body="return true;")

    # One `if` branch per kernel tag; each branch ORs together that tag's dtypes.
    branches = [
        if_condition_template.substitute(
            kernel_tag_name=tag,
            dtype_checks=" || ".join(
                ["scalar_type == at::ScalarType::" + dtype for dtype in tag_dtypes]
            ),
        )
        for tag, tag_dtypes in selective_builder.kernel_metadata.items()
    ]
    return selected_kernel_dtypes_h_template.substitute(
        body=" else ".join(branches)
    )
# Write the file selected_mobile_ops.h with optionally:
# 1. The selected root operators
# 2. The selected kernel dtypes
def write_selected_mobile_ops(
    output_file_path: str,
    selective_builder: SelectiveBuilder,
) -> None:
    """Write the selected-mobile-ops header described by ``selective_builder``.

    The file always starts with the preamble and ends with the kernel dtype
    selection code. The allowlist ``#define`` lines are emitted only in
    selective build.

    Args:
        output_file_path: Path of the header file to (over)write.
        selective_builder: Source of the root operators, custom classes,
            build features and kernel dtype metadata.
    """

    def define_line(prefix: str, names) -> str:
        # `<prefix><sorted names joined by ';'>;` followed by a blank line.
        return prefix + ";".join(sorted(names)) + ";\n\n"

    root_ops = extract_root_operators(selective_builder)
    custom_classes = selective_builder.custom_classes
    build_features = selective_builder.build_features

    with open(output_file_path, "wb") as header_file:
        sections = [selected_mobile_ops_preamble]

        # Selective build only: when these lists are not defined, the
        # corresponding selective build macros trivially report the item in
        # question as selected.
        if not selective_builder.include_all_operators:
            sections.append(define_line("#define TORCH_OPERATOR_WHITELIST ", root_ops))

            # Tracing based selective build additionally restricts custom
            # classes and build features.
            if selective_builder.include_all_non_op_selectives is False:
                sections.append(
                    define_line("#define TORCH_CUSTOM_CLASS_ALLOWLIST ", custom_classes)
                )
                sections.append(
                    define_line("#define TORCH_BUILD_FEATURE_ALLOWLIST ", build_features)
                )

        sections.append(get_selected_kernel_dtypes_code(selective_builder))
        header_file.write("".join(sections).encode("utf-8"))
# root_ops: a set of selected root operators for selective build
# Write the file selected_mobile_ops.h with optionally:
# 1. The selected root operators from root_ops
# 2. All kernel dtypes
def write_selected_mobile_ops_with_all_dtypes(
    output_file_path: str,
    root_ops: set[str],
) -> None:
    """Write the selected-mobile-ops header for an explicit set of root operators.

    The header contains the preamble, a ``TORCH_OPERATOR_WHITELIST`` define
    listing ``root_ops`` in sorted order, and the kernel dtype selection code
    generated from ``SelectiveBuilder.get_nop_selector()`` (intended to keep
    all kernel dtypes).

    Args:
        output_file_path: Path of the header file to (over)write.
        root_ops: Selected root operator names for selective build.
    """
    with open(output_file_path, "wb") as header_file:
        sections = [selected_mobile_ops_preamble]

        operator_names = ";".join(sorted(root_ops))
        sections.append("#define TORCH_OPERATOR_WHITELIST " + operator_names + ";\n\n")

        nop_selector = SelectiveBuilder.get_nop_selector()
        sections.append(get_selected_kernel_dtypes_code(nop_selector))

        encoded_header = "".join(sections).encode("utf-8")
        header_file.write(encoded_header)
def main() -> None:
    """Command-line entry point: generate ``selected_mobile_ops.h``.

    Parses ``-p/--yaml-file-path`` (YAML file listing the operators used by
    the model) and ``-o/--output-file-path`` (destination folder), loads the
    YAML file, turns its top-level value into a set of root operators, and
    writes ``selected_mobile_ops.h`` into the destination folder via
    ``write_selected_mobile_ops_with_all_dtypes``.
    """
    parser = argparse.ArgumentParser(
        description="Generate selected_mobile_ops.h for selective build."
    )
    parser.add_argument(
        "-p",
        "--yaml-file-path",
        "--yaml_file_path",
        type=str,
        required=True,
        help="Path to the yaml file with a list of operators used by the model.",
    )
    parser.add_argument(
        "-o",
        "--output-file-path",
        "--output_file_path",
        type=str,
        required=True,
        # Adjacent string literals are concatenated verbatim, so the first
        # piece must end with a space to keep the words separated.
        help="Path to destination "
        "folder where selected_mobile_ops.h will be written.",
    )
    parsed_args = parser.parse_args()
    model_file_name = parsed_args.yaml_file_path

    print("Loading yaml file: ", model_file_name)
    with open(model_file_name, "rb") as model_file:
        # Loader is CSafeLoader/SafeLoader (see the import fallback at the top
        # of the file), so this does not construct arbitrary Python objects.
        loaded_model = yaml.load(model_file, Loader=Loader)
    root_operators_set = set(loaded_model)

    print("Writing header file selected_mobile_ops.h: ", parsed_args.output_file_path)
    write_selected_mobile_ops_with_all_dtypes(
        os.path.join(parsed_args.output_file_path, "selected_mobile_ops.h"),
        root_operators_set,
    )
# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
main()