1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132
|
"""
For procedural tests needed for __torch_function__, we use this function
to export method names and signatures as needed by the tests in
test/test_overrides.py.
python -m tools.autograd.gen_annotated_fn_args \
aten/src/ATen/native/native_functions.yaml \
aten/src/ATen/native/tags.yaml \
$OUTPUT_DIR \
tools/autograd
Where $OUTPUT_DIR is where you would like the files to be
generated. In the full build system, OUTPUT_DIR is
torch/testing/_internal/generated
"""
from __future__ import annotations
import argparse
import os
import textwrap
from collections import defaultdict
from typing import Any, Sequence, TYPE_CHECKING
import torchgen.api.python as python
from torchgen.context import with_native_function
from torchgen.gen import parse_native_yaml
from torchgen.utils import FileManager
from .gen_python_functions import (
is_py_fft_function,
is_py_linalg_function,
is_py_nn_function,
is_py_special_function,
is_py_torch_function,
is_py_variable_method,
should_generate_py_binding,
)
if TYPE_CHECKING:
from torchgen.model import Argument, BaseOperatorName, NativeFunction
def gen_annotated(
native_yaml_path: str, tags_yaml_path: str, out: str, autograd_dir: str
) -> None:
native_functions = parse_native_yaml(
native_yaml_path, tags_yaml_path
).native_functions
mappings = (
(is_py_torch_function, "torch._C._VariableFunctions"),
(is_py_nn_function, "torch._C._nn"),
(is_py_linalg_function, "torch._C._linalg"),
(is_py_special_function, "torch._C._special"),
(is_py_fft_function, "torch._C._fft"),
(is_py_variable_method, "torch.Tensor"),
)
annotated_args: list[str] = []
for pred, namespace in mappings:
groups: dict[BaseOperatorName, list[NativeFunction]] = defaultdict(list)
for f in native_functions:
if not should_generate_py_binding(f) or not pred(f):
continue
groups[f.func.name.name].append(f)
for group in groups.values():
for f in group:
annotated_args.append(f"{namespace}.{gen_annotated_args(f)}")
template_path = os.path.join(autograd_dir, "templates")
fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
fm.write_with_template(
"annotated_fn_args.py",
"annotated_fn_args.py.in",
lambda: {
"annotated_args": textwrap.indent("\n".join(annotated_args), " "),
},
)
@with_native_function
def gen_annotated_args(f: NativeFunction) -> str:
def _get_kwargs_func_exclusion_list() -> list[str]:
# functions that currently don't work with kwargs in test_overrides.py
return [
"diagonal",
"round_",
"round",
"scatter_",
]
def _add_out_arg(
out_args: list[dict[str, Any]], args: Sequence[Argument], *, is_kwarg_only: bool
) -> None:
for arg in args:
if arg.default is not None:
continue
out_arg: dict[str, Any] = {}
out_arg["is_kwarg_only"] = str(is_kwarg_only)
out_arg["name"] = arg.name
out_arg["simple_type"] = python.argument_type_str(
arg.type, simple_type=True
)
size_t = python.argument_type_size(arg.type)
if size_t:
out_arg["size"] = size_t
out_args.append(out_arg)
out_args: list[dict[str, Any]] = []
_add_out_arg(out_args, f.func.arguments.flat_positional, is_kwarg_only=False)
if f"{f.func.name.name}" not in _get_kwargs_func_exclusion_list():
_add_out_arg(out_args, f.func.arguments.flat_kwarg_only, is_kwarg_only=True)
return f"{f.func.name.name}: {repr(out_args)},"
def main() -> None:
parser = argparse.ArgumentParser(description="Generate annotated_fn_args script")
parser.add_argument(
"native_functions", metavar="NATIVE", help="path to native_functions.yaml"
)
parser.add_argument("tags", metavar="TAGS", help="path to tags.yaml")
parser.add_argument("out", metavar="OUT", help="path to output directory")
parser.add_argument(
"autograd", metavar="AUTOGRAD", help="path to template directory"
)
args = parser.parse_args()
gen_annotated(args.native_functions, args.tags, args.out, args.autograd)
if __name__ == "__main__":
main()
|