1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125
|
# mypy: ignore-errors
import argparse
import torchgen.model as model
from torchgen.gen import FileManager, parse_native_yaml
def num_leading_spaces(line: str) -> int:
return len(line) - len(line.lstrip())
def deindent(code: str) -> str:
lines = code.split("\n")
min_leading_spaces = min(map(num_leading_spaces, lines))
lines = [line[min_leading_spaces:] for line in lines]
return "\n".join(lines)
def gen_external(native_functions_path, tags_path, external_path):
native_functions = parse_native_yaml(native_functions_path, tags_path)
func_decls = []
func_registrations = []
for func in native_functions:
schema = func.func
name = schema.name.name.base
args = schema.arguments
# Only supports extern calls for functions with out variants
if not schema.is_out_fn():
continue
# Doesn't currently support functions with more than one out parameter
if len(args.out) > 1:
continue
# Doesn't currently support kwarg arguments
if (
len(args.pre_tensor_options_kwarg_only) > 0
or len(args.post_tensor_options_kwarg_only) > 0
):
continue
self_arg = [args.self_arg.argument] if args.self_arg is not None else []
args = (
list(args.pre_self_positional) + self_arg + list(args.post_self_positional)
)
tensor_args = [
arg
for arg in args
if isinstance(arg.type, model.BaseType)
and arg.type.name == model.BaseTy.Tensor
]
if len(tensor_args) != len(args):
continue
arg_names = [None] * len(args)
tensor_decls = []
for idx, arg in enumerate(tensor_args):
s = f"const at::Tensor& {arg.name} = tensors[{idx + 1}];"
tensor_decls.append(s)
arg_names[idx] = arg.name
nl = "\n"
# print(tensor_decls, name, arg_names)
func_decl = f"""\
void nnc_aten_{name}(
int64_t bufs_num,
void** buf_data,
int64_t* buf_ranks,
int64_t* buf_dims,
int64_t* buf_strides,
int8_t* buf_dtypes,
int64_t args_num,
int64_t* extra_args) {{
std::vector<at::Tensor> tensors =
constructTensors(bufs_num, buf_data, buf_ranks, buf_dims, buf_strides, buf_dtypes);
at::Tensor& r = tensors[0];
{nl.join(tensor_decls)}
try {{
at::{name}_out({', '.join(['r'] + arg_names)});
}} catch (...) {{
}}
}}"""
func_registration = f"""\
const static RegisterNNCExternalFunction nnc_{name}(
"nnc_aten_{name}",
nnc_aten_{name});"""
func_decls.append(func_decl)
func_registrations.append(func_registration)
fm = FileManager(install_dir=".", template_dir=".", dry_run=False)
fm.write_with_template(
"external_functions_codegen.cpp",
external_path,
lambda: {
"external_registrations": func_registrations,
"external_functions": func_decls,
},
)
def main() -> None:
parser = argparse.ArgumentParser(description="Generate annotated_fn_args script")
parser.add_argument(
"--native-functions",
"--native_functions",
help="path to native_functions.yaml",
default="../../../../aten/src/ATen/native/native_functions.yaml",
)
parser.add_argument(
"--tags",
help="path to tags.yaml",
default="../../../../aten/src/ATen/native/tags.yaml",
)
parser.add_argument(
"--template-path",
"--template_path",
help="path to external_functions_codegen_template.cpp",
default="../../../../tools/jit/templates/external_functions_codegen_template.cpp",
)
args = parser.parse_args()
gen_external(args.native_functions, args.tags, args.template_path)
if __name__ == "__main__":
main()
|