File: dump_ort_model.py

package info (click to toggle)
onnxruntime 1.23.2+dfsg-6
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 340,756 kB
  • sloc: cpp: 3,222,136; python: 188,267; ansic: 114,318; asm: 37,927; cs: 36,849; java: 10,962; javascript: 6,811; pascal: 4,126; sh: 2,996; xml: 705; objc: 281; makefile: 67
file content (151 lines) | stat: -rw-r--r-- 6,063 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import argparse
import contextlib
import os
import sys
import typing

# the import of FbsTypeInfo sets up the path so we can import ort_flatbuffers_py
from util.ort_format_model.types import FbsTypeInfo  # isort:skip
import ort_flatbuffers_py.fbs as fbs  # isort:skip


class OrtFormatModelDumper:
    "Class to dump an ORT format model."

    def __init__(self, model_path: str):
        """
        Initialize ORT format model dumper.

        :param model_path: Path to model.
        :raises RuntimeError: If the buffer does not carry the ORT format file identifier.
        """
        # Read the whole file into memory and close the handle immediately.
        # (Previously the raw file object was kept open for the dumper's lifetime.)
        with open(model_path, "rb") as model_file:
            self._buffer = bytearray(model_file.read())

        if not fbs.InferenceSession.InferenceSession.InferenceSessionBufferHasIdentifier(self._buffer, 0):
            raise RuntimeError(f"File does not appear to be a valid ORT format model: '{model_path}'")
        self._inference_session = fbs.InferenceSession.InferenceSession.GetRootAsInferenceSession(self._buffer, 0)
        self._model = self._inference_session.Model()

    def _dump_initializers(self, graph: fbs.Graph):
        """Print name, data type and dims for every initializer in the graph."""
        print("Initializers:")
        for idx in range(graph.InitializersLength()):
            tensor = graph.Initializers(idx)
            dims = [tensor.Dims(dim) for dim in range(tensor.DimsLength())]
            print(f"{tensor.Name().decode()} data_type={tensor.DataType()} dims={dims}")
        print("--------")

    def _dump_nodeargs(self, graph: fbs.Graph):
        """Print name, type and (when available) shape for every NodeArg in the graph."""
        print("NodeArgs:")
        for idx in range(graph.NodeArgsLength()):
            node_arg = graph.NodeArgs(idx)
            arg_type = node_arg.Type()  # renamed from `type` to avoid shadowing the builtin
            if not arg_type:
                # NodeArg for optional value that does not exist
                continue

            type_str = FbsTypeInfo.typeinfo_to_str(arg_type)
            value_type = arg_type.ValueType()
            value = arg_type.Value()
            # dims stays None for non-tensor types and for tensors with no shape info
            dims = None
            if value_type == fbs.TypeInfoValue.TypeInfoValue.tensor_type:
                tensor_type_and_shape = fbs.TensorTypeAndShape.TensorTypeAndShape()
                tensor_type_and_shape.Init(value.Bytes, value.Pos)
                shape = tensor_type_and_shape.Shape()
                if shape:
                    dims = []
                    for dim in range(shape.DimLength()):
                        d = shape.Dim(dim).Value()
                        if d.DimType() == fbs.DimensionValueType.DimensionValueType.VALUE:
                            # concrete dimension
                            dims.append(str(d.DimValue()))
                        elif d.DimType() == fbs.DimensionValueType.DimensionValueType.PARAM:
                            # symbolic dimension
                            dims.append(d.DimParam().decode())
                        else:
                            # unknown dimension
                            dims.append("?")

            print(f"{node_arg.Name().decode()} type={type_str} dims={dims}")
        print("--------")

    def _dump_node(self, node: fbs.Node):
        """Print a one-line summary of a node: index, name, op, inputs and outputs."""
        optype = node.OpType().decode()
        domain = node.Domain().decode() or "ai.onnx"  # empty domain defaults to ai.onnx
        since_version = node.SinceVersion()

        inputs = [node.Inputs(i).decode() for i in range(node.InputsLength())]
        outputs = [node.Outputs(i).decode() for i in range(node.OutputsLength())]
        print(
            f"{node.Index()}:{node.Name().decode()}({domain}:{optype}:{since_version}) "
            f"inputs=[{','.join(inputs)}] outputs=[{','.join(outputs)}]"
        )

    def _dump_graph(self, graph: fbs.Graph):
        """
        Process one level of the Graph, descending into any subgraphs when they are found
        """

        self._dump_initializers(graph)
        self._dump_nodeargs(graph)
        print("Nodes:")
        for i in range(graph.NodesLength()):
            node = graph.Nodes(i)
            self._dump_node(node)

            # Read all the attributes, recursing into GRAPH/GRAPHS attributes
            for j in range(node.AttributesLength()):
                attr = node.Attributes(j)
                attr_type = attr.Type()
                if attr_type == fbs.AttributeType.AttributeType.GRAPH:
                    print(f"## Subgraph for {node.OpType().decode()}.{attr.Name().decode()} ##")
                    self._dump_graph(attr.G())
                    print(f"## End {node.OpType().decode()}.{attr.Name().decode()} Subgraph ##")
                elif attr_type == fbs.AttributeType.AttributeType.GRAPHS:
                    # the ONNX spec doesn't currently define any operators that have multiple graphs in an attribute
                    # so entering this 'elif' isn't currently possible
                    print(f"## Subgraphs for {node.OpType().decode()}.{attr.Name().decode()} ##")
                    for k in range(attr.GraphsLength()):
                        print(f"## Subgraph {k} ##")
                        self._dump_graph(attr.Graphs(k))
                        print(f"## End Subgraph {k} ##")

    def dump(self, output: typing.IO):
        """Dump the model to `output`, redirecting all of the print calls there."""
        with contextlib.redirect_stdout(output):
            print(f"ORT format version: {self._inference_session.OrtVersion().decode()}")
            print("--------")

            graph = self._model.Graph()
            self._dump_graph(graph)


def parse_args():
    """Parse command-line arguments, erroring out when model_path is not an existing file."""
    arg_parser = argparse.ArgumentParser(
        os.path.basename(__file__), description="Dump an ORT format model. Output is to <model_path>.txt"
    )
    arg_parser.add_argument("--stdout", action="store_true", help="Dump to stdout instead of writing to file.")
    arg_parser.add_argument("model_path", help="Path to ORT format model")
    parsed = arg_parser.parse_args()

    if not os.path.isfile(parsed.model_path):
        arg_parser.error(f"{parsed.model_path} is not a file.")

    return parsed


def main():
    """Entry point: parse args, then dump the model to stdout or <model_path>.txt."""
    args = parse_args()
    dumper = OrtFormatModelDumper(args.model_path)

    if args.stdout:
        dumper.dump(sys.stdout)
        return

    # default: write alongside the model with a .txt suffix appended
    with open(args.model_path + ".txt", "w", encoding="utf-8") as out_file:
        dumper.dump(out_file)


if __name__ == "__main__":
    main()