File: dump_subgraphs.py

package info (click to toggle)
onnxruntime 1.23.2%2Bdfsg-6
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 340,756 kB
  • sloc: cpp: 3,222,136; python: 188,267; ansic: 114,318; asm: 37,927; cs: 36,849; java: 10,962; javascript: 6,811; pascal: 4,126; sh: 2,996; xml: 705; objc: 281; makefile: 67
file content (54 lines) | stat: -rw-r--r-- 1,806 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import argparse
import os

import onnx


def export_and_recurse(node, attribute, output_dir, level):
    """Save the subgraph held by ``attribute`` as its own model, then descend into it.

    The graph in ``attribute.g`` is merged into the ``graph`` field of a newly
    created ``onnx.ModelProto``; no other model field is populated here. The
    file is written to ``output_dir`` as
    ``L<level>_<op_type>_<attribute name>_<node name>.onnx``.

    Args:
        node: the node owning the subgraph; its ``op_type`` and ``name`` go
            into the file name.
        attribute: the node attribute whose ``g`` field holds the subgraph.
        output_dir: directory the ``.onnx`` file is written to.
        level: nesting depth of ``node``; the subgraph's own nodes are
            processed at ``level + 1``.
    """
    # "/" in a node name would otherwise be read as a path separator.
    safe_name = node.name.replace("/", "_")
    filename = "L{}_{}_{}_{}.onnx".format(level, node.op_type, attribute.name, safe_name)

    sub_model = onnx.ModelProto()
    sub_model.graph.MergeFrom(attribute.g)

    target_path = os.path.join(output_dir, filename)
    onnx.save_model(sub_model, target_path)

    # Subgraphs can themselves contain control-flow nodes, so keep going.
    dump_subgraph(sub_model, output_dir, level + 1)


# Names of the attributes that hold subgraphs, keyed by operator type.
# The tuple order is the order in which the subgraphs are exported.
_SUBGRAPH_ATTRIBUTES = {
    "Scan": ("body",),
    "Loop": ("body",),
    "If": ("then_branch", "else_branch"),
}


def _find_attribute(node, attr_name):
    """Return the first attribute of ``node`` whose name is ``attr_name``.

    Raises:
        ValueError: if ``node`` has no attribute with that name.
    """
    for attr in node.attribute:
        if attr.name == attr_name:
            return attr
    raise ValueError(
        "{} node {!r} has no {!r} attribute".format(node.op_type, node.name, attr_name)
    )


def dump_subgraph(model, output_dir, level=0):
    """Export every subgraph of the Scan, Loop and If nodes in ``model.graph``.

    Each subgraph is handed to ``export_and_recurse``, which writes it to
    ``output_dir`` and processes the subgraph's own nodes in turn. Nodes of any
    other operator type are ignored.

    Args:
        model: object whose ``graph.node`` is iterated (an ONNX model).
        output_dir: directory passed through to ``export_and_recurse``.
        level: nesting depth of ``model``; forwarded unchanged for its nodes.

    Raises:
        ValueError: if a Scan/Loop node has no ``body`` attribute, or an If
            node lacks ``then_branch`` or ``else_branch``.
    """
    for node in model.graph.node:
        attr_names = _SUBGRAPH_ATTRIBUTES.get(node.op_type)
        if attr_names is None:
            continue

        # Resolve every attribute before exporting anything, so a malformed
        # node is rejected without writing only some of its subgraphs.
        attributes = [_find_attribute(node, attr_name) for attr_name in attr_names]
        for attribute in attributes:
            export_and_recurse(node, attribute, output_dir, level)


def parse_args():
    """Parse the command line and return the resulting namespace.

    Two options are defined, both required: ``-m/--model`` (model file) and
    ``-o/--out`` (output directory). The program name shown in usage text is
    the base name of this file.
    """
    arg_parser = argparse.ArgumentParser(
        prog=os.path.basename(__file__),
        description="Dump all subgraphs from an ONNX model into separate onnx files.",
    )

    required_options = (
        ("-m", "--model", "model file"),
        ("-o", "--out", "output directory"),
    )
    for short_flag, long_flag, help_text in required_options:
        arg_parser.add_argument(short_flag, long_flag, required=True, help=help_text)

    return arg_parser.parse_args()


def main():
    """Script entry point: load the model named on the command line and dump its subgraphs.

    The output directory is resolved to an absolute path and created if it
    does not exist yet, then the model is loaded with ``onnx.load_model`` and
    passed to ``dump_subgraph``.

    Raises:
        FileExistsError: if the output path exists but is not a directory.
    """
    args = parse_args()

    out = os.path.abspath(args.out)

    # exist_ok=True replaces the separate exists() check: no window between
    # the check and the creation, and a path that is an existing regular file
    # is reported here instead of failing later when a subgraph is saved.
    os.makedirs(out, exist_ok=True)

    model = onnx.load_model(args.model)
    dump_subgraph(model, out)

# Run the tool only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()