1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337
|
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Test op correctness by comparing with PyTorch results.
Usage:
pytest onnxscript/tests/function_libs/torch_lib/ops_test.py
To run tests on a specific operator (e.g. torch.ceil):
pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k ceil
To run tests on an nn operator (e.g. nn.functional.scaled_dot_product_attention):
pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k nn_functional_scaled_dot_product_attention
## Environment variables
1. Set environment variable `CATCH_ORT_SEGFAULT=1` to catch segmentation faults
in onnxruntime by running the inference sessions in a separate process.
2. Set `CREATE_REPRODUCTION_REPORT=1` to create markdown files for reproduction of
errors.
"""
from __future__ import annotations
import os
import unittest
from typing import Callable, Optional, Sequence, Tuple
import numpy as np
import onnx
import onnxruntime as ort
import parameterized
import torch
from torch.testing._internal import common_device_type
from torch.testing._internal.opinfo import core as opinfo_core
from torch.utils import _pytree as pytree
import onnxscript
from tests.function_libs.torch_lib import (
error_reproduction,
ops_test_common,
ops_test_data,
)
# Dtypes exercised against every generated symbolic function. Complex dtypes are
# covered separately by COMPLEX_TYPES; complex64 results are flattened to float32.
TESTED_DTYPES: tuple[torch.dtype, ...] = (
    torch.float16,
    torch.float32,
    # Further dtypes are disabled; uncomment an entry once it genuinely needs coverage.
    # torch.bfloat16,
    # torch.float64,
    torch.bool,
    # torch.int8,
    # torch.int16,
    torch.int32,
    torch.int64,
    # torch.uint8,
)

# Complex dtypes, exercised against the ops in COMPLEX_FUNCTION_MAPPING.
# NOTE: torch.complex32 is experimental in torch, so it is not listed here.
COMPLEX_TYPES: tuple[torch.dtype, ...] = (torch.complex64,)
def dtypes_except(*dtypes: torch.dtype) -> Sequence[torch.dtype]:
    """Return the entries of ``TESTED_DTYPES`` that are not among ``dtypes``.

    The result is a tuple that keeps the ordering of ``TESTED_DTYPES``.
    """

    def _is_kept(candidate: torch.dtype) -> bool:
        return candidate not in dtypes

    return tuple(filter(_is_kept, TESTED_DTYPES))
def _should_skip_xfail_test_sample(
    op_name: str, sample, dtype: torch.dtype, device_type: str
) -> Tuple[Optional[str], Optional[str]]:
    """Look up the skip/xfail behavior registered for a single test sample.

    An entry of ``ops_test_data.SKIP_XFAIL_SUBTESTS`` applies when its ``op_name``
    matches, it is enabled, its ``dtypes`` (if set) contain ``dtype``, its
    ``device_type`` (if set) equals ``device_type``, and its ``matcher`` accepts
    ``sample``.

    Returns:
        ``(test_behavior, reason)`` from the first applicable entry, or
        ``(None, None)`` when no entry applies.
    """
    no_match: Tuple[Optional[str], Optional[str]] = (None, None)
    if op_name not in ops_test_data.OP_WITH_SKIPPED_XFAIL_SUBTESTS:
        return no_match

    # A linear scan over SKIP_XFAIL_SUBTESTS is acceptable because the list is small.
    for meta in ops_test_data.SKIP_XFAIL_SUBTESTS:
        if meta.op_name != op_name:
            continue
        assert meta.matcher is not None, "Matcher must be defined"
        applicable = (
            meta.enabled_if
            and (meta.dtypes is None or dtype in meta.dtypes)
            and (meta.device_type is None or meta.device_type == device_type)
        )
        if applicable and meta.matcher(sample):
            return meta.test_behavior, meta.reason
    return no_match
class TestFunctionValidity(unittest.TestCase):
    # Structural checks run once per entry of ops_test_data.TESTED_TORCHLIB_OPS.
    # Each case receives the op name (not read by the test body) and its TorchLibOpInfo.

    @parameterized.parameterized.expand(
        [(info.op.name, info) for info in ops_test_data.TESTED_TORCHLIB_OPS]
    )
    def test_script_function_passes_checker(
        self, _, torchlib_op_info: ops_test_data.TorchLibOpInfo
    ):
        # Only scripted OnnxFunctions expose a FunctionProto; other ops are skipped.
        scripted = torchlib_op_info.op
        if isinstance(scripted, onnxscript.OnnxFunction):
            onnx.checker.check_function(  # type: ignore[attr-defined]
                scripted.to_function_proto()
            )
        else:
            self.skipTest("Traced functions does not have a function proto")

    @parameterized.parameterized.expand(
        [(info.op.name, info) for info in ops_test_data.TESTED_TORCHLIB_OPS]
    )
    def test_function_has_op_schema(self, _, torchlib_op_info: ops_test_data.TorchLibOpInfo):
        # Every tested op must expose an op schema whose name matches the op's own name.
        op = torchlib_op_info.op
        op_schema = op.op_schema
        self.assertIsNotNone(op_schema)
        self.assertEqual(op_schema.name, op.name)
def run_test_output_match(
    test_suite: unittest.TestCase,
    device: str,
    dtype: torch.dtype,
    op: opinfo_core.OpInfo,
    function_executor: Callable,
    tested_op_mapping: dict[
        str,
        ops_test_data.TorchLibOpInfo,
    ],
):
    """Shared body of the output-consistency tests, used by instantiate_device_type_tests.

    For every sample from ``op.sample_inputs``, the op is run eagerly in PyTorch and the
    matching torchlib function is run through ``function_executor``. The flattened
    outputs are then compared with ``torch.testing.assert_close`` (or by shape and dtype
    only for nondeterministic / shape-only outputs). Each sample runs in its own
    ``subTest`` and under the skip/xfail behavior registered for it.

    Args:
        test_suite: The test class instance.
        device: The PyTorch device. instantiate_device_type_tests provides this.
        dtype: The PyTorch dtype. instantiate_device_type_tests provides this.
        op: The OpInfo instance. instantiate_device_type_tests provides this.
        function_executor: A factory called as
            ``function_executor(test_name, expected_flat_outputs)``. It returns a
            callable that takes ``(onnx_function, inputs, kwargs)`` and returns the
            function's outputs.
        tested_op_mapping: The mapping of op name to the tested op.

    Raises:
        unittest.SkipTest: Via ``test_suite.skipTest`` when ``dtype`` is neither
            allowed by the function's op schema nor a complex type.
        AssertionError: When an output mismatches the PyTorch reference. With several
            outputs, the error names the mismatching output index.
    """
    samples = op.sample_inputs(
        device,
        dtype,
        requires_grad=False,
    )
    torchlib_op_info = tested_op_mapping[op.name]
    # Obtain the input_wrangler that manipulates the OpInfo inputs to match the aten
    # operator signature. An example is nn.functional.upsample_nearest2d, which has a
    # different signature than the aten operator upsample_nearest2d.
    onnx_function = torchlib_op_info.op
    input_wrangler = torchlib_op_info.input_wrangler
    if (
        not ops_test_common.dtype_op_schema_compatible(dtype, onnx_function.op_schema)
        and dtype not in COMPLEX_TYPES
    ):
        test_suite.skipTest(
            f"dtype '{dtype}' is not supported by the op '{op.name}'. "
            f"Type constraints: {onnx_function.op_schema.type_constraints}"
        )

    # Obtain the tolerance for the op
    rtol, atol = torchlib_op_info.get_tolerance(dtype)
    # TestCase.id() does not change between subtests, so compute it once.
    test_name = test_suite.id()
    # Hack for split, chunk, unbind and the atleast_*d_Sequence variants. These rely on
    # SplitToSequence, so their whole output is a single Sequence value.
    # TODO(justinchuby): Find a more general solution
    wraps_sequence_output = op.name.startswith(("split", "chunk", "unbind")) or op.name in {
        "atleast_1d_Sequence",
        "atleast_2d_Sequence",
        "atleast_3d_Sequence",
    }

    for i, cpu_sample in enumerate(samples):
        inputs = (cpu_sample.input, *cpu_sample.args)
        # Provide the repr to subtest because tensors are not serializable in parallel test runs
        with test_suite.subTest(
            sample_num=i,
            inputs=repr(
                [
                    f"Tensor<{inp.shape}, dtype={inp.dtype}>"
                    if isinstance(inp, torch.Tensor)
                    else inp
                    for inp in inputs
                ]
            ),
            kwargs=repr(cpu_sample.kwargs),
        ):
            # Device type is read from the first positional arg; falls back to "cpu"
            # when there is none or it has no `.device`.
            try:
                device_type = cpu_sample.args[0].device.type
            except (AttributeError, IndexError):
                device_type = "cpu"
            test_behavior, reason = _should_skip_xfail_test_sample(
                op.name, cpu_sample, dtype, device_type
            )

            # Everything below, including the assertions, runs inside the skip/xfail
            # context so that expected failures are reported as such.
            with ops_test_common.normal_xfail_skip_test_behaviors(test_behavior, reason):
                input_onnx = [ops_test_common.convert_tensor_to_numpy(x) for x in inputs]
                kwargs_onnx = ops_test_common.convert_kwargs_for_onnx(cpu_sample.kwargs)
                if input_wrangler:
                    input_onnx, kwargs_onnx = input_wrangler(input_onnx, kwargs_onnx)
                torch_output = op(*inputs, **cpu_sample.kwargs)

                # A single complex tensor output is compared through its real view.
                if isinstance(torch_output, torch.Tensor) and torch.is_complex(torch_output):
                    torch_output = torch.view_as_real(torch_output.resolve_conj())

                reference_torch_outputs, _ = pytree.tree_flatten(torch_output)
                if wraps_sequence_output:
                    # The Sequence is treated as a single value, so wrap it in a list.
                    reference_torch_outputs = [reference_torch_outputs]

                function_output = function_executor(test_name, reference_torch_outputs)(
                    onnx_function, input_onnx, kwargs_onnx
                )

                # Finally we re-flatten everything
                # TODO: add pytree structure comparison.
                flattened_torch_outputs, _ = pytree.tree_flatten(torch_output)
                flattened_function_outputs, _ = pytree.tree_flatten(function_output)
                assert flattened_torch_outputs
                assert len(flattened_torch_outputs) == len(flattened_function_outputs)

                for j, (expected_raw, actual_raw) in enumerate(
                    zip(flattened_torch_outputs, flattened_function_outputs)
                ):
                    if not isinstance(actual_raw, np.ndarray):
                        # An onnxscript tensor
                        actual_raw = actual_raw.value
                    actual = torch.tensor(actual_raw)
                    expected = (
                        expected_raw
                        if isinstance(expected_raw, torch.Tensor)
                        else torch.tensor(expected_raw)
                    )

                    if (
                        op.name in ops_test_data.NONDETERMINISTIC_OPS
                        or j in ops_test_data.COMPARE_SHAPE_ONLY_OPS[op.name]
                    ):
                        # Values cannot be compared: the op is nondeterministic, or output
                        # j is registered as shape-only. Check shape and dtype only.
                        test_suite.assertEqual(actual.shape, expected.shape)
                        test_suite.assertEqual(actual.dtype, expected.dtype)
                        continue

                    # Use torch.testing as opposed to np.testing to ensure dtypes and shapes match
                    try:
                        torch.testing.assert_close(
                            actual,
                            expected,
                            rtol=rtol,
                            atol=atol,
                            equal_nan=True,
                            check_device=False,
                        )
                    except AssertionError as e:
                        if os.environ.get("CREATE_REPRODUCTION_REPORT") == "1":
                            error_reproduction.create_mismatch_report(
                                test_name, i, inputs, cpu_sample.kwargs, actual, expected, e
                            )
                        if len(flattened_torch_outputs) > 1:
                            raise AssertionError(f"Output {j} mismatch") from e
                        raise
class TestOutputConsistencyFullGraph(unittest.TestCase):
    """Compare PyTorch eager results with the exported ONNX ops run as a full graph.

    The test methods are templates that ``common_device_type.ops`` parameterizes over
    OpInfo entries and dtypes.
    """

    def setUp(self) -> None:
        # Seed torch, numpy and onnxruntime identically so that runs are reproducible.
        seed = 42
        for set_seed in (torch.manual_seed, np.random.seed, ort.set_seed):
            set_seed(seed)

    @ops_test_common.add_decorate_info(
        ops_test_data.OPS_DB,
        "TestOutputConsistencyFullGraph",
        "test_output_match_opinfo_",
        skip_or_xfails=ops_test_data.EXPECTED_SKIPS_OR_FAILS,
    )
    @common_device_type.ops(  # type: ignore[misc]
        [opinfo for opinfo in ops_test_data.OPS_DB if opinfo.name in ops_test_data.TESTED_OPS],
        allowed_dtypes=TESTED_DTYPES,
    )
    def test_output_match_opinfo_(
        self, device: str, dtype: torch.dtype, op: opinfo_core.OpInfo
    ):
        # Template test: run the op's full ONNX graph for the TESTED_DTYPES and compare
        # it with PyTorch, using TORCHLIB_OPINFO_MAPPING to find the torchlib function.
        run_test_output_match(
            test_suite=self,
            device=device,
            dtype=dtype,
            op=op,
            function_executor=ops_test_common.graph_executor,
            tested_op_mapping=ops_test_data.TORCHLIB_OPINFO_MAPPING,
        )

    @ops_test_common.add_decorate_info(
        ops_test_data.OPS_DB,
        "TestOutputConsistencyFullGraph",
        "test_complex_output_match_opinfo_",
        skip_or_xfails=ops_test_data.EXPECTED_SKIPS_OR_FAILS,
    )
    @common_device_type.ops(  # type: ignore[misc]
        [
            opinfo
            for opinfo in ops_test_data.OPS_DB
            if opinfo.name in ops_test_data.COMPLEX_FUNCTION_MAPPING
        ],
        allowed_dtypes=COMPLEX_TYPES,
    )
    def test_complex_output_match_opinfo_(
        self, device: str, dtype: torch.dtype, op: opinfo_core.OpInfo
    ):
        """Run the full ONNX graph for complex dtypes via COMPLEX_FUNCTION_MAPPING."""
        run_test_output_match(
            test_suite=self,
            device=device,
            dtype=dtype,
            op=op,
            function_executor=ops_test_common.graph_executor,
            tested_op_mapping=ops_test_data.COMPLEX_FUNCTION_MAPPING,
        )
# Expand the TestOutputConsistencyFullGraph template into device-specific test
# classes, restricted to the "cpu" and "cuda" device types. The classes are
# registered in this module's globals() so the test runner can discover them.
common_device_type.instantiate_device_type_tests(
    TestOutputConsistencyFullGraph, globals(), only_for=["cpu", "cuda"]
)
# Allow running this file directly as a script through unittest's runner.
if __name__ == "__main__":
    unittest.main()
|