1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
|
"""Test utilities for ONNX export."""
from __future__ import annotations
__all__ = ["assert_onnx_program"]
from typing import Any, TYPE_CHECKING
import torch
from torch.utils import _pytree
if TYPE_CHECKING:
from torch.onnx._internal.exporter import _onnx_program
def assert_onnx_program(
    program: _onnx_program.ONNXProgram,
    *,
    rtol: float | None = None,
    atol: float | None = None,
    args: tuple[Any, ...] | None = None,
    kwargs: dict[str, Any] | None = None,
) -> None:
    """Assert that the ONNX model produces the same output as the PyTorch ExportedProgram.

    The module obtained from ``program.exported_program`` is called with the
    inputs, its outputs are flattened with pytree, and the resulting leaves are
    compared against the outputs of calling ``program`` with the same inputs.

    Args:
        program: The ``ONNXProgram`` to verify.
        rtol: Relative tolerance.
        atol: Absolute tolerance.
        args: The positional arguments to pass to the program.
            If None, the default example inputs in the ExportedProgram will be used.
        kwargs: The keyword arguments to pass to the program.
            If None, the default example inputs in the ExportedProgram will be used.

    Raises:
        ValueError: If ``program.exported_program`` is None, or if neither
            ``args`` nor ``kwargs`` is given and the ExportedProgram has no
            example inputs.
        AssertionError: If the ONNX outputs and the PyTorch outputs are not
            close (raised by ``torch.testing.assert_close``).
    """
    exported_program = program.exported_program
    if exported_program is None:
        raise ValueError(
            "The ONNXProgram does not contain an ExportedProgram. "
            "To verify the ONNX program, initialize ONNXProgram with an ExportedProgram, "
            "or assign the ExportedProgram to the ONNXProgram.exported_program attribute."
        )
    if args is None and kwargs is None:
        # User did not provide example inputs, use the default example inputs
        if exported_program.example_inputs is None:
            raise ValueError(
                "No example inputs provided and the exported_program does not contain example inputs. "
                "Please provide arguments to verify the ONNX program."
            )
        args, kwargs = exported_program.example_inputs
    # Only one of args/kwargs may have been supplied; normalize the other one
    # so both can be unpacked below.
    if args is None:
        args = ()
    if kwargs is None:
        kwargs = {}

    torch_module = exported_program.module()
    torch_outputs, _ = _pytree.tree_flatten(torch_module(*args, **kwargs))

    expected_outputs = []
    for output in torch_outputs:
        # Flattened leaves may be plain Python scalars. torch.is_complex only
        # accepts tensors, so wrap non-tensor leaves before inspecting them.
        if not isinstance(output, torch.Tensor):
            output = torch.tensor(output)
        # ONNX outputs are always real, so we need to convert torch complex
        # outputs to real representations
        if torch.is_complex(output):
            output = torch.view_as_real(output)
        expected_outputs.append(output)

    onnx_outputs = program(*args, **kwargs)
    # TODO(justinchuby): Include output names in the error message
    torch.testing.assert_close(
        tuple(onnx_outputs),
        tuple(expected_outputs),
        rtol=rtol,
        atol=atol,
        equal_nan=True,
        check_device=False,
    )
|