File: build_opsets.py

package info (click to toggle)
pytorch 2.9.1+dfsg-1~exp2
  • links: PTS, VCS
  • area: main
  • in suites: experimental
  • size: 180,096 kB
  • sloc: python: 1,473,255; cpp: 942,030; ansic: 79,796; asm: 7,754; javascript: 2,502; java: 1,962; sh: 1,809; makefile: 628; xml: 8
file content (75 lines) | stat: -rw-r--r-- 2,043 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import csv
import os
from collections import OrderedDict
from pathlib import Path

import torch
import torch._prims as prims
from torchgen.gen import parse_native_yaml


# Repository root. parents[3] is the fourth ancestor of this file's absolute
# path, so this script is assumed to live three directories below the root
# (i.e. <root>/a/b/c/build_opsets.py) -- update the index if the script moves.
ROOT = Path(__file__).absolute().parents[3]
# YAML inputs handed to torchgen's parse_native_yaml() by get_aten().
NATIVE_FUNCTION_YAML_PATH = ROOT / "aten/src/ATen/native/native_functions.yaml"
TAGS_YAML_PATH = ROOT / "aten/src/ATen/native/tags.yaml"

# Output directory and file names for the generated CSV tables. BUILD_DIR is a
# relative path, so it is resolved against the current working directory at
# run time, not against ROOT.
BUILD_DIR = "build/ir"
ATEN_OPS_CSV_FILE = "aten_ops.csv"
PRIMS_OPS_CSV_FILE = "prims_ops.csv"


def get_aten():
    """Return ``(name, schema)`` pairs for every ATen op tagged ``core``.

    Native functions are parsed from ``native_functions.yaml`` together with
    ``tags.yaml`` via torchgen. Each pair is ``("aten.<op name>", schema)``,
    where ``schema`` is the string form of the function signature with every
    asterisk backslash-escaped -- presumably so reStructuredText does not
    read it as emphasis markup once the CSV is rendered in the docs.
    The pairs are sorted by op name.
    """
    parsed = parse_native_yaml(NATIVE_FUNCTION_YAML_PATH, TAGS_YAML_PATH)

    # Index the core-tagged functions by str(fn.func.name); if a name repeats,
    # the last function seen wins.
    core_by_name = {
        str(fn.func.name): fn
        for fn in parsed.native_functions
        if "core" in fn.tags
    }

    return [
        (f"aten.{name}", str(core_by_name[name].func).replace("*", r"\*"))
        for name in sorted(core_by_name)
    ]


def get_prims():
    """Return ``(name, schema)`` pairs for the operators exported by ``prims``.

    Every name in ``prims.__all__`` is looked up on the module; entries that
    are not ``torch._ops.OpOverload`` instances are skipped. For each operator,
    ``name`` is ``str(op_overload)`` with a trailing ``.default`` overload
    suffix removed, and ``schema`` is the ``schema`` string stored on its
    overload packet with every asterisk backslash-escaped (matching the
    escaping used for the ATen table). Order follows ``prims.__all__``.
    """
    default_suffix = ".default"
    op_schema_pairs = []
    for attr_name in prims.__all__:
        op_overload = getattr(prims, attr_name, None)

        # __all__ may export names that are not operators; only OpOverloads
        # carry the overload packet / schema needed for the table.
        if not isinstance(op_overload, torch._ops.OpOverload):
            continue

        qualified_name = str(op_overload)
        # Strip only the trailing overload suffix; a blanket str.replace would
        # also remove ".default" if it ever appeared elsewhere in the name.
        if qualified_name.endswith(default_suffix):
            qualified_name = qualified_name[: -len(default_suffix)]

        schema = op_overload.overloadpacket.schema.replace("*", r"\*")

        op_schema_pairs.append((qualified_name, schema))

    return op_schema_pairs


def _write_ops_csv(path, rows):
    """Write ``(operator name, schema)`` rows to *path* as a two-column CSV.

    The ``Operator,Schema`` header is written verbatim. Every data row is
    fully quoted, with the operator name wrapped in double backticks.
    """
    # newline="" lets the csv module control line endings, as its docs require.
    with open(path, "w", encoding="utf-8", newline="") as f:
        f.write("Operator,Schema\n")
        # QUOTE_ALL + "\n" reproduces the previous hand-written row format,
        # but csv.writer also doubles any embedded '"' so a schema containing
        # a double-quoted string default can no longer corrupt the row.
        writer = csv.writer(f, quoting=csv.QUOTE_ALL, lineterminator="\n")
        for name, schema in rows:
            writer.writerow([f"``{name}``", schema])


def main():
    """Generate the ATen core-op and prims-op CSV tables.

    Writes ATEN_OPS_CSV_FILE and PRIMS_OPS_CSV_FILE into BUILD_DIR, creating
    the directory if needed. BUILD_DIR is relative, so the files land under
    the current working directory.
    """
    aten_ops_list = get_aten()
    prims_ops_list = get_prims()

    os.makedirs(BUILD_DIR, exist_ok=True)

    _write_ops_csv(os.path.join(BUILD_DIR, ATEN_OPS_CSV_FILE), aten_ops_list)
    _write_ops_csv(os.path.join(BUILD_DIR, PRIMS_OPS_CSV_FILE), prims_ops_list)


if __name__ == "__main__":
    # Generate the CSV tables only when run as a script, not on import.
    main()