File: ort_rewrite.py

package info (click to toggle)
onnxscript 0.2.0%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: experimental
  • size: 12,384 kB
  • sloc: python: 75,957; sh: 41; makefile: 6
file content (75 lines) | stat: -rw-r--r-- 2,230 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Run the onnxruntime rewriter to optimize the given ONNX model.

Input:
    <model-dir>/<model>/<compiler>/<model>_<compiler>.onnx

Output:
    <model-dir>/<model>/<compiler>_<rewritten_name>/<model>_<compiler>_<rewritten_name>.onnx
"""

import argparse
import contextlib
import logging
import os
import shutil

import onnx

from onnxscript.rewriter import onnxruntime as ort_rewriter

# Module logger. No level is set on it here, so it inherits the root level that
# main() configures through logging.basicConfig(level=...).
logger = logging.getLogger(__name__)


def ort_rewrite(model_name: str, compiler_name: str, model_dir: str):
    """Apply the onnxruntime rewriter to an exported model and save the result.

    Reads
    ``<model_dir>/<model_name>/<compiler_name>/<model_name>_<compiler_name>.onnx``
    (with its external data) and rewrites it with
    ``onnxscript.rewriter.onnxruntime.rewrite``. It then writes
    ``<model_name>_<compiler_name>_ort_rewritten.onnx`` into
    ``<model_dir>/<model_name>/<compiler_name>_ort_rewritten/``. Large tensors
    are stored in a single ``<same name>.onnx.data`` file next to it. The
    source folder's ``test_data_set_0`` directory is copied alongside.

    Any previously existing output folder is deleted first.

    Args:
        model_name: Name of the model sub-directory under ``model_dir``.
        compiler_name: Name of the exporter/compiler that produced the model.
        model_dir: Root directory that holds all models.

    Raises:
        Errors from loading, rewriting, or file operations propagate unchanged,
        e.g. when the input model or its ``test_data_set_0`` directory is
        missing.
    """
    old_model_folder = f"{model_dir}/{model_name}/{compiler_name}"
    old_model_name = f"{model_name}_{compiler_name}"

    post_process_name = "ort_rewritten"
    new_model_folder = f"{model_dir}/{model_name}/{compiler_name}_{post_process_name}"
    new_model_name = f"{old_model_name}_{post_process_name}"

    # Load and rewrite *before* touching the output folder, so that a failure
    # here leaves any previous result intact.
    model = onnx.load(f"{old_model_folder}/{old_model_name}.onnx", load_external_data=True)
    ort_rewritten_model = ort_rewriter.rewrite(model)

    # Start from a clean output folder. rmtree raises on every failure other
    # than the suppressed "does not exist", so afterwards the folder is known
    # to be absent and can be created unconditionally.
    with contextlib.suppress(FileNotFoundError):
        shutil.rmtree(new_model_folder)
    os.mkdir(new_model_folder)
    shutil.copytree(
        f"{old_model_folder}/test_data_set_0",
        f"{new_model_folder}/test_data_set_0",
    )

    logger.debug("Model size: %s", ort_rewritten_model.ByteSize())
    onnx.save(
        ort_rewritten_model,
        f"{new_model_folder}/{new_model_name}.onnx",
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        # Without an explicit location, onnx names the external-data file with a
        # random UUID. Use a stable, recognizable name next to the model instead.
        location=f"{new_model_name}.onnx.data",
    )


def _parse_log_level(value: str) -> int:
    """Convert a ``--log-level`` argument into a numeric logging level.

    Accepts an integer (e.g. ``"10"``, the previously only supported form) or a
    standard level name in any case (e.g. ``"debug"``, ``"WARNING"``).

    Raises:
        argparse.ArgumentTypeError: if ``value`` is neither an integer nor a
            registered logging level name.
    """
    try:
        return int(value)
    except ValueError:
        pass
    # For a registered level *name*, logging.getLevelName returns its number.
    # For an unknown name it returns a string such as "Level FOO".
    level = logging.getLevelName(value.upper())
    if not isinstance(level, int):
        raise argparse.ArgumentTypeError(f"invalid log level: {value!r}")
    return level


def main():
    """Command-line entry point: parse arguments, configure logging, rewrite the model."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True, help="Model name (sub-directory of --model-dir).")
    parser.add_argument(
        "--compiler", type=str, default="dynamo", help="Exporter/compiler that produced the model."
    )
    parser.add_argument(
        "--model-dir",
        "--model_dir",
        type=str,
        default="./onnx_models",
        help="Root directory holding all models.",
    )
    parser.add_argument(
        "--log-level",
        "--log_level",
        type=_parse_log_level,
        default=logging.WARNING,
        help="Logging level, as a number (e.g. 10) or a name (e.g. DEBUG).",
    )

    args = parser.parse_args()

    model_name = args.model
    compiler_name = args.compiler
    model_dir = args.model_dir

    log_level = args.log_level
    logging.basicConfig(level=log_level)

    ort_rewrite(model_name, compiler_name, model_dir)


if __name__ == "__main__":
    # Run the command-line tool only when executed as a script, not on import.
    main()