# mypy: allow-untyped-defs
"""Utilities for manipulating the onnx and onnx-script dependencies and ONNX proto."""
from __future__ import annotations
import glob
import os
import shutil
from typing import Any, Mapping, TYPE_CHECKING
import torch
import torch.jit._trace
import torch.serialization
from torch.onnx import errors
from torch.onnx._internal import jit_utils, registration
if TYPE_CHECKING:
import io
def export_as_test_case(
model_bytes: bytes, inputs_data, outputs_data, name: str, dir: str
) -> str:
"""Export an ONNX model as a self contained ONNX test case.
The test case contains the model and the inputs/outputs data. The directory structure
is as follows:
    dir
    ├── test_<name>
    │   ├── model.onnx
    │   └── test_data_set_0
    │       ├── input_0.pb
    │       ├── input_1.pb
    │       ├── output_0.pb
    │       └── output_1.pb
    Args:
        model_bytes: The ONNX model in bytes.
        inputs_data: The inputs data, nested data structure of numpy.ndarray.
        outputs_data: The outputs data, nested data structure of numpy.ndarray.
        name: The name of the test case.
        dir: The directory in which to create the ``test_<name>`` folder.
Returns:
The path to the test case directory.
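
    Example (a minimal sketch; assumes ``model_bytes`` holds a serialized model
    with one input and one output, and that onnx and numpy are installed)::

        import numpy as np

        inputs = [np.random.rand(1, 3).astype(np.float32)]
        outputs = [np.random.rand(1, 2).astype(np.float32)]
        test_dir = export_as_test_case(
            model_bytes, inputs, outputs, name="my_model", dir="/tmp/onnx_tests"
        )
        # test_dir == "/tmp/onnx_tests/test_my_model"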
"""
try:
import onnx
except ImportError as exc:
raise ImportError(
"Export test case to ONNX format failed: Please install ONNX."
) from exc
test_case_dir = os.path.join(dir, "test_" + name)
os.makedirs(test_case_dir, exist_ok=True)
_export_file(
model_bytes,
os.path.join(test_case_dir, "model.onnx"),
{},
)
data_set_dir = os.path.join(test_case_dir, "test_data_set_0")
if os.path.exists(data_set_dir):
shutil.rmtree(data_set_dir)
os.makedirs(data_set_dir)
proto = onnx.load_model_from_string(model_bytes) # type: ignore[attr-defined]
    for i, (input_proto, input_data) in enumerate(zip(proto.graph.input, inputs_data)):
        export_data(input_data, input_proto, os.path.join(data_set_dir, f"input_{i}.pb"))
for i, (output_proto, output) in enumerate(zip(proto.graph.output, outputs_data)):
export_data(output, output_proto, os.path.join(data_set_dir, f"output_{i}.pb"))
return test_case_dir
def load_test_case(dir: str) -> tuple[bytes, Any, Any]:
"""Load a self contained ONNX test case from a directory.
The test case must contain the model and the inputs/outputs data. The directory structure
should be as follows:
    dir
    ├── model.onnx
    └── test_data_set_0
        ├── input_0.pb
        ├── input_1.pb
        ├── output_0.pb
        └── output_1.pb
Args:
        dir: The directory containing the test case, i.e. the ``test_<name>``
            directory created by ``export_as_test_case``.
    Returns:
        model_bytes: The ONNX model in bytes.
        inputs: The inputs data, mapping from input name to numpy.ndarray.
        outputs: The outputs data, mapping from output name to numpy.ndarray.
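
    Example (a minimal sketch; the path is illustrative and assumes a test case
    previously written by ``export_as_test_case``)::

        model_bytes, inputs, outputs = load_test_case(
            "/tmp/onnx_tests/test_my_model"
        )
        for name, array in inputs.items():
            print(name, array.shape)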
"""
try:
import onnx
from onnx import numpy_helper # type: ignore[attr-defined]
except ImportError as exc:
raise ImportError(
"Load test case from ONNX format failed: Please install ONNX."
) from exc
with open(os.path.join(dir, "model.onnx"), "rb") as f:
model_bytes = f.read()
test_data_dir = os.path.join(dir, "test_data_set_0")
inputs = {}
input_files = glob.glob(os.path.join(test_data_dir, "input_*.pb"))
for input_file in input_files:
tensor = onnx.load_tensor(input_file) # type: ignore[attr-defined]
inputs[tensor.name] = numpy_helper.to_array(tensor)
outputs = {}
output_files = glob.glob(os.path.join(test_data_dir, "output_*.pb"))
for output_file in output_files:
tensor = onnx.load_tensor(output_file) # type: ignore[attr-defined]
outputs[tensor.name] = numpy_helper.to_array(tensor)
return model_bytes, inputs, outputs
def export_data(data, value_info_proto, f: str) -> None:
"""Export data to ONNX protobuf format.
Args:
data: The data to export, nested data structure of numpy.ndarray.
value_info_proto: The ValueInfoProto of the data. The type of the ValueInfoProto
determines how the data is stored.
f: The file to write the data to.
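
    Example (a minimal sketch; assumes onnx and numpy are installed and that
    ``model.onnx`` has a plain tensor as its first graph input)::

        import numpy as np
        import onnx

        proto = onnx.load("model.onnx")
        export_data(
            np.zeros((1, 3), dtype=np.float32),
            proto.graph.input[0],
            "input_0.pb",
        )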
"""
try:
from onnx import numpy_helper # type: ignore[attr-defined]
except ImportError as exc:
raise ImportError(
"Export data to ONNX format failed: Please install ONNX."
) from exc
with open(f, "wb") as opened_file:
if value_info_proto.type.HasField("map_type"):
opened_file.write(
numpy_helper.from_dict(data, value_info_proto.name).SerializeToString()
)
elif value_info_proto.type.HasField("sequence_type"):
opened_file.write(
numpy_helper.from_list(data, value_info_proto.name).SerializeToString()
)
elif value_info_proto.type.HasField("optional_type"):
opened_file.write(
numpy_helper.from_optional(
data, value_info_proto.name
).SerializeToString()
)
else:
assert value_info_proto.type.HasField("tensor_type")
opened_file.write(
numpy_helper.from_array(data, value_info_proto.name).SerializeToString()
)
def _export_file(
model_bytes: bytes,
f: io.BytesIO | str,
export_map: Mapping[str, bytes],
) -> None:
"""export/write model bytes into directory/protobuf/zip"""
assert len(export_map) == 0
with torch.serialization._open_file_like(f, "wb") as opened_file:
opened_file.write(model_bytes)
def _add_onnxscript_fn(
model_bytes: bytes,
custom_opsets: Mapping[str, int],
) -> bytes:
"""Insert model-included custom onnx-script function into ModelProto"""
try:
import onnx
except ImportError as e:
raise errors.OnnxExporterError("Module onnx is not installed!") from e
    # For models > 2GB, onnx.load_model_from_string would fail. However,
    # _export_onnx saves tensors separately whenever the proto exceeds 2GB, and
    # if for some reason it did not, serialization would already have failed
    # because of the protobuf limit. So models > 2GB never reach this point.
model_proto = onnx.load_model_from_string(model_bytes) # type: ignore[attr-defined]
# Iterate graph nodes to insert only the included custom
# function_proto into model_proto
onnx_function_list = [] # type: ignore[var-annotated]
included_node_func: set[str] = set()
# onnx_function_list and included_node_func are expanded in-place
_find_onnxscript_op(
model_proto.graph, included_node_func, custom_opsets, onnx_function_list
)
if onnx_function_list:
model_proto.functions.extend(onnx_function_list)
model_bytes = model_proto.SerializeToString()
return model_bytes
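
# A minimal usage sketch for _add_onnxscript_fn; the "com.custom" domain and
# opset version below are illustrative, not real registrations:
#
#     model_bytes = _add_onnxscript_fn(model_bytes, custom_opsets={"com.custom": 1})
#
# The returned bytes hold the same model with the FunctionProto of every
# referenced onnx-script op appended to model_proto.functions.
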
def _find_onnxscript_op(
graph_proto,
included_node_func: set[str],
custom_opsets: Mapping[str, int],
onnx_function_list: list,
):
"""Recursively iterate ModelProto to find ONNXFunction op as it may contain control flow Op."""
for node in graph_proto.node:
node_kind = node.domain + "::" + node.op_type
        # Recursion is needed for control-flow nodes (If/Loop), whose attributes
        # carry inner GraphProtos.
for attr in node.attribute:
            # In protobuf Python, `attr.g` is never None (an unset message field
            # yields a default instance), so check field presence explicitly.
            if attr.HasField("g"):
_find_onnxscript_op(
attr.g, included_node_func, custom_opsets, onnx_function_list
)
        # Only custom ops backed by an onnx-script function (and aten ops with a
        # symbolic_fn) should be found in the registry.
        onnx_function_group = registration.registry.get_function_group(node_kind)
        # Skip the built-in aten/prim/onnx domains, which may also appear in the registry.
if (
node.domain
and not jit_utils.is_aten(node.domain)
and not jit_utils.is_prim(node.domain)
and not jit_utils.is_onnx(node.domain)
and onnx_function_group is not None
and node_kind not in included_node_func
):
specified_version = custom_opsets.get(node.domain, 1)
onnx_fn = onnx_function_group.get(specified_version)
if onnx_fn is not None:
if hasattr(onnx_fn, "to_function_proto"):
onnx_function_proto = onnx_fn.to_function_proto() # type: ignore[attr-defined]
onnx_function_list.append(onnx_function_proto)
included_node_func.add(node_kind)
continue
raise errors.UnsupportedOperatorError(
node_kind,
specified_version,
onnx_function_group.get_min_supported()
if onnx_function_group
else None,
)
return onnx_function_list, included_node_func
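
# A minimal sketch of calling _find_onnxscript_op directly (normally it is only
# driven by _add_onnxscript_fn above; names are illustrative):
#
#     import onnx
#     model_proto = onnx.load_model_from_string(model_bytes)
#     functions: list = []
#     seen: set[str] = set()
#     _find_onnxscript_op(model_proto.graph, seen, {"com.custom": 1}, functions)
#     # `functions` now holds the FunctionProto of every custom onnx-script op
#     # reachable from the graph, including ops nested in If/Loop subgraphs.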