1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151
|
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import argparse
import contextlib
import os
import sys
import typing
# the import of FbsTypeInfo sets up the path so we can import ort_flatbuffers_py
from util.ort_format_model.types import FbsTypeInfo # isort:skip
import ort_flatbuffers_py.fbs as fbs # isort:skip
class OrtFormatModelDumper:
    "Class to dump an ORT format model."

    def __init__(self, model_path: str):
        """
        Initialize ORT format model dumper
        :param model_path: Path to model
        :raises RuntimeError: if the file does not carry the ORT format InferenceSession identifier
        """
        # Read through a context manager so the file handle is closed deterministically
        # instead of being left for the garbage collector to clean up.
        with open(model_path, "rb") as model_file:
            self._file = model_file.read()

        self._buffer = bytearray(self._file)
        if not fbs.InferenceSession.InferenceSession.InferenceSessionBufferHasIdentifier(self._buffer, 0):
            raise RuntimeError(f"File does not appear to be a valid ORT format model: '{model_path}'")

        self._inference_session = fbs.InferenceSession.InferenceSession.GetRootAsInferenceSession(self._buffer, 0)
        self._model = self._inference_session.Model()

    def _dump_initializers(self, graph: fbs.Graph):
        """
        Print each initializer in `graph` as "<name> data_type=<DataType()> dims=<list of dims>",
        then a separator line.
        """
        print("Initializers:")
        for idx in range(graph.InitializersLength()):
            tensor = graph.Initializers(idx)
            dims = [tensor.Dims(dim) for dim in range(tensor.DimsLength())]
            print(f"{tensor.Name().decode()} data_type={tensor.DataType()} dims={dims}")

        print("--------")

    @staticmethod
    def _shape_to_dims(shape) -> typing.List[str]:
        """
        Convert an fbs Shape into a list of printable dimension strings.

        Fixed dimensions become their numeric value, symbolic dimensions become their parameter
        name, and anything else (including a dimension whose value field is absent) becomes "?".
        """
        dims = []
        for dim_idx in range(shape.DimLength()):
            dim_value = shape.Dim(dim_idx).Value()
            # flatbuffers returns None for an absent table field; treat it as an unknown dim
            if dim_value is None:
                dims.append("?")
            elif dim_value.DimType() == fbs.DimensionValueType.DimensionValueType.VALUE:
                dims.append(str(dim_value.DimValue()))
            elif dim_value.DimType() == fbs.DimensionValueType.DimensionValueType.PARAM:
                dims.append(dim_value.DimParam().decode())
            else:
                dims.append("?")
        return dims

    def _dump_nodeargs(self, graph: fbs.Graph):
        """
        Print each NodeArg in `graph` with its type string and dims, then a separator line.

        dims is a list of strings for tensor-typed NodeArgs that have a shape, and None otherwise.
        NodeArgs without type info (placeholders for missing optional values) are skipped.
        """
        print("NodeArgs:")
        for idx in range(graph.NodeArgsLength()):
            node_arg = graph.NodeArgs(idx)
            type_info = node_arg.Type()  # named to avoid shadowing the builtin `type`
            if not type_info:
                # NodeArg for optional value that does not exist
                continue

            type_str = FbsTypeInfo.typeinfo_to_str(type_info)
            dims = None
            if type_info.ValueType() == fbs.TypeInfoValue.TypeInfoValue.tensor_type:
                # The TypeInfo value is a union; re-interpret its table as a TensorTypeAndShape.
                value = type_info.Value()
                tensor_type_and_shape = fbs.TensorTypeAndShape.TensorTypeAndShape()
                tensor_type_and_shape.Init(value.Bytes, value.Pos)
                shape = tensor_type_and_shape.Shape()
                if shape:
                    dims = self._shape_to_dims(shape)

            print(f"{node_arg.Name().decode()} type={type_str} dims={dims}")

        print("--------")

    def _dump_node(self, node: fbs.Node):
        """
        Print a one-line summary of `node`:
        "<index>:<name>(<domain>:<op_type>:<since_version>) inputs=[...] outputs=[...]".
        """
        optype = node.OpType().decode()
        domain = node.Domain().decode() or "ai.onnx"  # empty domain defaults to ai.onnx
        since_version = node.SinceVersion()

        inputs = [node.Inputs(i).decode() for i in range(node.InputsLength())]
        outputs = [node.Outputs(i).decode() for i in range(node.OutputsLength())]
        print(
            f"{node.Index()}:{node.Name().decode()}({domain}:{optype}:{since_version}) "
            f"inputs=[{','.join(inputs)}] outputs=[{','.join(outputs)}]"
        )

    def _dump_graph(self, graph: fbs.Graph):
        """
        Process one level of the Graph, descending into any subgraphs when they are found
        """
        self._dump_initializers(graph)
        self._dump_nodeargs(graph)
        print("Nodes:")
        for i in range(graph.NodesLength()):
            node = graph.Nodes(i)
            self._dump_node(node)

            # Read all the attributes, recursing into any graph-valued ones
            for j in range(node.AttributesLength()):
                attr = node.Attributes(j)
                attr_type = attr.Type()
                if attr_type == fbs.AttributeType.AttributeType.GRAPH:
                    label = f"{node.OpType().decode()}.{attr.Name().decode()}"
                    print(f"## Subgraph for {label} ##")
                    self._dump_graph(attr.G())
                    print(f"## End {label} Subgraph ##")
                elif attr_type == fbs.AttributeType.AttributeType.GRAPHS:
                    # the ONNX spec doesn't currently define any operators that have multiple graphs in an attribute
                    # so entering this 'elif' isn't currently possible
                    print(f"## Subgraphs for {node.OpType().decode()}.{attr.Name().decode()} ##")
                    for k in range(attr.GraphsLength()):
                        print(f"## Subgraph {k} ##")
                        self._dump_graph(attr.Graphs(k))
                        print(f"## End Subgraph {k} ##")

    def dump(self, output: typing.IO):
        """
        Write a textual dump of the model to `output`.

        The dump contains the ORT format version followed by the main graph and its subgraphs.
        :param output: Text stream to write to. stdout is redirected to it for the duration of
                       the dump, so the helper methods can simply print().
        """
        with contextlib.redirect_stdout(output):
            print(f"ORT format version: {self._inference_session.OrtVersion().decode()}")
            print("--------")

            graph = self._model.Graph()
            self._dump_graph(graph)
def parse_args():
    """
    Parse the command line for this script.

    :return: argparse.Namespace with `stdout` (bool) and `model_path` (str).
    If `model_path` is not an existing regular file, reports the error through the parser,
    which exits the process.
    """
    prog_name = os.path.basename(__file__)
    arg_parser = argparse.ArgumentParser(
        prog_name,
        description="Dump an ORT format model. Output is to <model_path>.txt",
    )
    arg_parser.add_argument(
        "--stdout",
        action="store_true",
        help="Dump to stdout instead of writing to file.",
    )
    arg_parser.add_argument("model_path", help="Path to ORT format model")

    parsed = arg_parser.parse_args()
    model_path = parsed.model_path
    path_is_file = os.path.isfile(model_path)
    if path_is_file is False:
        # ArgumentParser.error prints usage to stderr and exits with status 2.
        arg_parser.error(f"{model_path} is not a file.")

    return parsed
def main():
    """
    Script entry point.

    Parses the command line, loads the ORT format model and writes its dump either to stdout
    (with --stdout) or to "<model_path>.txt" as UTF-8 text.
    """
    cli_args = parse_args()
    dumper = OrtFormatModelDumper(cli_args.model_path)

    if not cli_args.stdout:
        txt_path = f"{cli_args.model_path}.txt"
        with open(txt_path, "w", encoding="utf-8") as out_stream:
            dumper.dump(out_stream)
        return

    dumper.dump(sys.stdout)
# Run the dumper only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
|