File: __main__.py

package info (click to toggle)
onnxscript 0.2.0+dfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: experimental
  • size: 12,384 kB
  • sloc: python: 75,957; sh: 41; makefile: 6
file content (101 lines) | stat: -rw-r--r-- 3,052 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

import argparse
import contextlib
import shutil
import subprocess
import textwrap
from pathlib import Path

from onnx_opset_builder import (
    OpsetId,
    OpsetsBuilder,
    format_opsetid,
    parse_opsetid,
)

# Default for --min-opset-version (the parsed value is passed to
# OpsetsBuilder as min_default_opset_version).
MIN_REQUIRED_ONNX_OPSET_VERSION = 14

# Paths are resolved relative to this script: the repository root is taken to
# be the parent of the directory containing this file.
self_dir = Path(__file__).parents[0]
repo_root = self_dir.parent

# Components of the generated package's dotted module name; joined onto the
# repository root they also locate that package's directory on disk.
module_base_names = "onnxscript.onnx_opset".split(".")
opsets_path = Path(repo_root, *module_base_names)

# Command-line interface. The repeatable -x/-i flags collect raw opset strings
# (left as None when the flag is never given); --min-opset-version is parsed
# as an int and falls back to MIN_REQUIRED_ONNX_OPSET_VERSION.
argparser = argparse.ArgumentParser(prog="opgen")
argparser.add_argument(
    "-x",
    "--exclude",
    dest="exclude_opsets",
    metavar="OPSET",
    action="append",
    help="exclude an opset from generation; example: -x 19 -x ai.onnx.ml/3",
)
argparser.add_argument(
    "-i",
    "--include-only",
    dest="include_opsets",
    metavar="OPSET",
    action="append",
    help="include only these opsets; example: -i 19",
)
argparser.add_argument(
    "--min-opset-version",
    type=int,
    default=MIN_REQUIRED_ONNX_OPSET_VERSION,
    help="the minimum supported ONNX opset version",
)
args = argparser.parse_args()

# Start from a clean slate by removing any previously generated opset package.
# A missing directory (e.g. on the first run) is fine, so only that specific
# error is ignored; any other failure still aborts the script.
with contextlib.suppress(FileNotFoundError):
    shutil.rmtree(opsets_path)

# Need to generate a blank onnx_opset module since onnxscript/__init__.py
# will import it (and we deleted it above); it will be overridden with the
# correct code as part of the generation below.
opsets_path.mkdir(parents=True)
opsets_path.joinpath("__init__.py").write_text("", encoding="utf-8")


# Run the generator with the requested opset filters. The repeatable CLI flags
# are None when never given, so fall back to an empty sequence before parsing.
builder = OpsetsBuilder(
    module_base_name=".".join(module_base_names),
    min_default_opset_version=args.min_opset_version,
    include_opsets=set(map(parse_opsetid, args.include_opsets or ())),
    exclude_opsets=set(map(parse_opsetid, args.exclude_opsets or ())),
)
result = builder.build()

# Write the generated modules under the repository root, then normalize their
# formatting with black followed by isort (a non-zero exit from either raises
# CalledProcessError and aborts the script).
paths = result.write(repo_root)
subprocess.check_call(["black", "--quiet", *paths])
subprocess.check_call(["isort", "--quiet", *paths])

print(
    f"🎉 Generated Ops: {result.all_ops_count}",
    f"   Minimum Opset Version: {args.min_opset_version}",
    "",
    sep="\n",
)


def print_opsets(label: str, opsets: set[OpsetId]) -> None:
    """Print *label* followed by a sorted, comma-separated list of *opsets*.

    Each opset is rendered with ``format_opsetid``. The list is word-wrapped
    with a three-space indent on every line and followed by a blank line.
    Nothing at all is printed when *opsets* is empty.
    """
    # Check the container itself: ``any(opsets)`` would test the truthiness of
    # the elements and could skip a non-empty set whose members are falsy.
    if not opsets:
        return
    print(label)
    summary = ", ".join(format_opsetid(opset) for opset in sorted(opsets))
    print(textwrap.fill(summary, initial_indent="   ", subsequent_indent="   "))
    print()


print_opsets("🟢 Included Opsets:", result.included_opsets)
print_opsets("🔴 Excluded Opsets:", result.excluded_opsets)

# Unsupported ops are grouped by reason. Check the mapping itself for
# emptiness: ``any(mapping)`` would test the truthiness of its keys and hide
# the whole section if the only reason key were falsy (e.g. an empty string).
if result.unsupported_ops:
    print("🟠 Unsupported Ops:")
    for key, unsupported_ops in sorted(result.unsupported_ops.items()):
        print(f"   reason: {key}:")
        for unsupported_op in unsupported_ops:
            print(f"     - {unsupported_op.op}")
            print(f"       {unsupported_op.op.docuri}")