1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140
|
# Owner(s): ["module: onnx"]
from __future__ import annotations
import os
import random
from typing import Any, Mapping, Type
import numpy as np
import onnxruntime
import torch
from torch.onnx import _constants, verification
from torch.testing._internal import common_utils
# Root directory of the ONNX backend test data, resolved relative to the real
# location of this file: <this dir>/../repos/onnx/onnx/backend/test/data.
onnx_model_dir = os.path.join(
    os.path.dirname(os.path.realpath(__file__)),
    os.pardir,
    *("repos", "onnx", "onnx", "backend", "test", "data"),
)
# Sub-directories of the backend test data holding PyTorch-derived models.
pytorch_converted_dir, pytorch_operator_dir = (
    os.path.join(onnx_model_dir, sub_dir)
    for sub_dir in ("pytorch-converted", "pytorch-operator")
)
# Execution providers passed to the verifier as ``ort_providers`` by
# ``run_model_test``.
_ORT_PROVIDERS = ("CPUExecutionProvider",)
def run_model_test(test_suite: _TestONNXRuntime, *args, **kwargs):
    """Run ``verification.verify`` configured from ``test_suite``.

    The export settings ``ort_providers``, ``opset_version`` and
    ``keep_initializers_as_inputs`` are always taken from the suite (or the
    module defaults) and override any caller-supplied values of the same
    name. ``check_shape`` and ``check_dtype`` are forwarded only when the
    suite defines them. Every other positional and keyword argument is passed
    through unchanged.

    Returns:
        Whatever ``verification.verify`` returns.
    """
    kwargs.update(
        ort_providers=_ORT_PROVIDERS,
        opset_version=test_suite.opset_version,
        keep_initializers_as_inputs=test_suite.keep_initializers_as_inputs,
    )
    # Optional verifier switches: only forward what the suite actually has.
    for optional_flag in ("check_shape", "check_dtype"):
        if hasattr(test_suite, optional_flag):
            kwargs[optional_flag] = getattr(test_suite, optional_flag)
    return verification.verify(*args, **kwargs)
def parameterize_class_name(cls: Type, idx: int, input_dicts: Mapping[Any, Any]):
    """Build the name of one parameterized variant of ``cls``.

    Meant to be handed to ``parameterized.parameterized_class`` as its
    ``class_name_func``. The result is ``cls.__name__``, an underscore, and
    then ``<key>_<value>`` for each entry of ``input_dicts`` (in iteration
    order), joined by underscores. An empty mapping therefore yields
    ``"<ClassName>_"``. ``idx`` is accepted for signature compatibility and
    is not used.
    """
    name_parts = [f"{key}_{value}" for key, value in input_dicts.items()]
    suffix = "_".join(name_parts)
    return f"{cls.__name__}_{suffix}"
def set_rng_seed(seed):
    """Seed the global RNGs of PyTorch, Python's ``random`` and NumPy.

    The generators are seeded in that order, all with the same ``seed``.
    CUDA generators are not touched here.
    """
    for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
        seed_fn(seed)
class _TestONNXRuntime(common_utils.TestCase):
    """Base test case that exports PyTorch models to ONNX and verifies them.

    ``run_model_test`` reads the export/verification settings from the class
    attributes below, so subclasses override them to change the configuration.
    """

    # ONNX opset handed to the exporter.
    opset_version = _constants.ONNX_DEFAULT_OPSET
    keep_initializers_as_inputs = True  # For IR version 3 type export.
    # Selects the export path used by ``run_test``: scripting when True,
    # the non-scripted (tracing) path when False.
    is_script = False
    # Forwarded to ``verification.verify`` by ``run_model_test``.
    check_shape = True
    check_dtype = True

    def setUp(self):
        """Prepare deterministic RNG state and the environment for a test."""
        # Let the base TestCase do its own per-test setup first. The seeds
        # below are applied afterwards so they are not overridden by
        # anything the base setUp does.
        super().setUp()
        set_rng_seed(0)
        onnxruntime.set_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)
        # Process-wide setting that is not restored after the test.
        # NOTE(review): presumably allows exporting to opsets that are not
        # yet released — confirm.
        os.environ["ALLOW_RELEASED_ONNX_OPSET_ONLY"] = "0"
        # Checked by run_test: the scripting path only runs while this is True.
        self.is_script_test_enabled = True

    def run_test(
        self,
        model,
        input_args,
        input_kwargs=None,
        rtol=1e-3,
        atol=1e-7,
        do_constant_folding=True,
        dynamic_axes=None,
        additional_test_inputs=None,
        input_names=None,
        output_names=None,
        fixed_batch_size=False,
        training=torch.onnx.TrainingMode.EVAL,
        remained_onnx_input_idx=None,
        verbose=False,
    ):
        """Export ``model`` to ONNX and verify it through ``run_model_test``.

        The path taken depends on ``self.is_script``:

        * True (and ``self.is_script_test_enabled`` is True): the model is
          scripted with ``torch.jit.script`` unless it already is a
          ``torch.jit.ScriptModule``/``torch.jit.ScriptFunction``. It is then
          verified with ``flatten=False`` and ``ignore_none=False``.
        * False: the model is passed as-is (the tracing path) with
          ``flatten=True`` and ``ignore_none=True``. Models that are already
          scripted are not run on this path.

        All other arguments are forwarded unchanged to ``run_model_test``.

        Args:
            remained_onnx_input_idx: Indices of the PyTorch model inputs that
                remain inputs of the exported ONNX model. The exported model
                may have fewer inputs than the PyTorch model because of
                constant folding. This mostly happens in unit tests that rely
                on ``torch.size``/``tensor.shape``, where the output depends
                only on the input shape, not its value. May also be a dict with
                ``"scripting"`` and ``"tracing"`` keys to give a separate value
                for each path.

        Returns:
            None.
        """

        def _run_test(m, remained_onnx_input_idx, flatten=True, ignore_none=True):
            # Shared call into the verifier; only the model, the remaining
            # input indices and the flatten/ignore_none switches vary per path.
            return run_model_test(
                self,
                m,
                input_args=input_args,
                input_kwargs=input_kwargs,
                rtol=rtol,
                atol=atol,
                do_constant_folding=do_constant_folding,
                dynamic_axes=dynamic_axes,
                additional_test_inputs=additional_test_inputs,
                input_names=input_names,
                output_names=output_names,
                fixed_batch_size=fixed_batch_size,
                training=training,
                remained_onnx_input_idx=remained_onnx_input_idx,
                flatten=flatten,
                ignore_none=ignore_none,
                verbose=verbose,
            )

        if isinstance(remained_onnx_input_idx, dict):
            scripting_remained_onnx_input_idx = remained_onnx_input_idx["scripting"]
            tracing_remained_onnx_input_idx = remained_onnx_input_idx["tracing"]
        else:
            scripting_remained_onnx_input_idx = remained_onnx_input_idx
            tracing_remained_onnx_input_idx = remained_onnx_input_idx

        is_model_script = isinstance(
            model, (torch.jit.ScriptModule, torch.jit.ScriptFunction)
        )

        if self.is_script_test_enabled and self.is_script:
            script_model = model if is_model_script else torch.jit.script(model)
            _run_test(
                script_model,
                scripting_remained_onnx_input_idx,
                flatten=False,
                ignore_none=False,
            )

        if not is_model_script and not self.is_script:
            _run_test(model, tracing_remained_onnx_input_idx)
|