1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199
|
# mypy: allow-untyped-defs
from __future__ import annotations
import dataclasses
import re
from typing import TYPE_CHECKING
from torch.onnx._internal.exporter import _analysis, _registration, _verification
if TYPE_CHECKING:
import os
from onnxscript import ir
import torch
@dataclasses.dataclass
class ExportStatus:
    """Per-stage outcome of an ONNX export attempt.

    Every field is tri-state: ``True`` (the stage succeeded), ``False`` (the
    stage failed) or ``None`` (no result recorded; this is the default).
    """

    # Whether torch.export.export(..., strict=True) succeeds
    torch_export: bool | None = None
    # Whether torch.export.export(..., strict=False) succeeds
    torch_export_non_strict: bool | None = None
    # Whether torch.jit.trace succeeds
    torch_jit: bool | None = None
    # Whether decomposition succeeds
    decomposition: bool | None = None
    # Whether ONNX translation succeeds
    onnx_translation: bool | None = None
    # Whether ONNX model passes onnx.checker.check_model
    onnx_checker: bool | None = None
    # Whether ONNX model runs successfully with ONNX Runtime
    onnx_runtime: bool | None = None
    # Whether the output of the ONNX model is accurate
    output_accuracy: bool | None = None
def _status_emoji(status: bool | None) -> str:
if status is None:
return "⚪"
return "✅" if status else "❌"
def _format_export_status(status: ExportStatus) -> str:
    """Render ``status`` as a fenced checklist with one line per export stage.

    Each line is prefixed by the emoji ``_status_emoji`` returns for the
    corresponding field. The fenced block is followed by a blank line.
    """
    checklist = [
        (status.torch_export_non_strict, "Obtain model graph with `torch.export.export(..., strict=False)`"),
        (status.torch_export, "Obtain model graph with `torch.export.export(..., strict=True)`"),
        (status.torch_jit, "Obtain model graph with `torch.jit.trace`"),
        (status.decomposition, "Decompose operators for ONNX compatibility"),
        (status.onnx_translation, "Translate the graph into ONNX"),
        (status.onnx_checker, "Run `onnx.checker` on the ONNX model"),
        (status.onnx_runtime, "Execute the model with ONNX Runtime"),
        (status.output_accuracy, "Validate model output accuracy"),
    ]
    rows = "".join(f"{_status_emoji(state)} {label}\n" for state, label in checklist)
    return "```\n" + rows + "```\n\n"
def _strip_color_from_string(text: str) -> str:
# This regular expression matches ANSI escape codes
# https://github.com/pytorch/pytorch/blob/9554a9af8788c57e1c5222c39076a5afcf0998ae/torch/_dynamo/utils.py#L2785-L2788
ansi_escape = re.compile(r"\x1B[@-_][0-?]*[ -/]*[@-~]")
return ansi_escape.sub("", text)
def _format_exported_program(exported_program: torch.export.ExportedProgram) -> str:
    """Render ``exported_program`` as a colorless fenced ``python`` block.

    The block is followed by a blank line.
    """
    # Adapted from https://github.com/pytorch/pytorch/pull/128476 to remove colors.
    # graph_module.print_readable could be called directly, but its `colored`
    # option is recent and may be missing from the user's PyTorch version, so
    # the ExportedProgram is stringified and ANSI codes are stripped instead.
    plain = _strip_color_from_string(str(exported_program))
    return "```python\n" + plain + "\n```\n\n"
def construct_report_file_name(timestamp: str, status: ExportStatus) -> str:
    """Build the Markdown report file name for an export attempt.

    The postfix is chosen by checking the stages in a fixed order and naming
    the first one that failed.

    Args:
        timestamp: Inserted verbatim into the file name.
        status: The per-stage export status to summarize.

    Returns:
        ``"onnx_export_{timestamp}_{postfix}.md"``, where ``postfix`` is one of
        ``pt_export``, ``decomp``, ``conversion``, ``checker``, ``runtime``,
        ``accuracy``, ``strategies`` or ``success``.
    """
    # Fields are tri-state (True / False / None), so failures are detected with
    # an explicit ``is False`` check; ``None`` means no result was recorded.
    if not (status.torch_export or status.torch_export_non_strict or status.torch_jit):
        # No capture strategy succeeded (all failed or none recorded)
        postfix = "pt_export"
    elif status.decomposition is False:
        postfix = "decomp"
    elif status.onnx_translation is False:
        postfix = "conversion"
    elif status.onnx_checker is False:
        postfix = "checker"
    elif status.onnx_runtime is False:
        postfix = "runtime"
    elif status.output_accuracy is False:
        postfix = "accuracy"
    elif (
        status.torch_export is False
        or status.torch_export_non_strict is False
        or status.torch_jit is False
    ):
        # A graph was obtained, but at least one capture strategy failed
        postfix = "strategies"
    else:
        postfix = "success"
    return f"onnx_export_{timestamp}_{postfix}.md"
def format_decomp_comparison(
    pre_decomp_unique_ops: set[str],
    post_decomp_unique_ops: set[str],
) -> str:
    """Format the decomposition comparison result as Markdown.

    Args:
        pre_decomp_unique_ops: Ops found only in the ExportedProgram before
            decomposition.
        post_decomp_unique_ops: Ops found only in the ExportedProgram after
            decomposition.

    Returns:
        Two paragraphs, each listing one of the op collections in sorted order
        inside inline code.
    """
    return (
        f"Ops exist only in the ExportedProgram before decomposition: `{sorted(pre_decomp_unique_ops)}`\n\n"
        f"Ops exist only in the ExportedProgram after decomposition: `{sorted(post_decomp_unique_ops)}`\n"
    )
def format_verification_infos(
    verification_infos: list[_verification.VerificationInfo],
) -> str:
    """Render verification results as Markdown, one line per entry.

    Args:
        verification_infos: The verification entries to render.

    Returns:
        A newline-separated string. Each line shows the entry's name, its max
        absolute/relative differences in scientific notation, and its
        absolute/relative difference histograms. Empty input yields ``""``.
    """

    def render(info) -> str:
        return (
            f"`{info.name}`: `max_abs_diff={info.max_abs_diff:e}`, `max_rel_diff={info.max_rel_diff:e}`, "
            f"`abs_diff_hist={info.abs_diff_hist}`, `rel_diff_hist={info.rel_diff_hist}`"
        )

    rendered = [render(info) for info in verification_infos]
    return "\n".join(rendered)
def create_torch_export_error_report(
    filename: str | os.PathLike,
    formatted_traceback: str,
    *,
    export_status: ExportStatus,
    profile_result: str | None,
):
    """Write a Markdown error report to ``filename``.

    The report contains the export status checklist, the traceback in a
    ``pytb`` code fence and, optionally, the profiling output.

    Args:
        filename: Destination path; an existing file is overwritten.
        formatted_traceback: Traceback text to embed in the report.
        export_status: Per-stage status rendered as a checklist.
        profile_result: Profiler output; the section is omitted when ``None``.
    """
    with open(filename, "w", encoding="utf-8") as f:
        f.write("# PyTorch ONNX Conversion Error Report\n\n")
        f.write(_format_export_status(export_status))
        f.write("Error message:\n\n")
        f.write("```pytb\n")
        f.write(formatted_traceback)
        # The closing fence must start on its own line, otherwise the
        # Markdown code block never terminates.
        if not formatted_traceback.endswith("\n"):
            f.write("\n")
        f.write("```\n\n")
        if profile_result is not None:
            f.write("## Profiling result\n\n")
            f.write("```\n")
            f.write(profile_result)
            if not profile_result.endswith("\n"):
                f.write("\n")
            f.write("```\n")
def create_onnx_export_report(
    filename: str | os.PathLike,
    formatted_traceback: str,
    program: torch.export.ExportedProgram,
    *,
    decomp_comparison: str | None = None,
    export_status: ExportStatus,
    profile_result: str | None,
    model: ir.Model | None = None,
    registry: _registration.ONNXRegistry | None = None,
    verification_result: str | None = None,
):
    """Write a Markdown conversion report to ``filename``.

    Sections, in order: export status checklist, error messages, exported
    program, ONNX model (if given), analysis, decomposition comparison (if
    given), verification results (if given), and profiling result (if given).

    Args:
        filename: Destination path; an existing file is overwritten.
        formatted_traceback: Traceback text placed in a ``pytb`` code fence.
        program: The ExportedProgram to print and analyze.
        decomp_comparison: Pre-formatted decomposition comparison text.
        export_status: Per-stage status rendered as a checklist.
        profile_result: Profiler output; the section is omitted when ``None``.
        model: ONNX IR model whose ``str()`` is embedded, if provided.
        registry: Forwarded to ``_analysis.analyze``.
        verification_result: Pre-formatted verification results text.
    """
    with open(filename, "w", encoding="utf-8") as f:
        f.write("# PyTorch ONNX Conversion Report\n\n")
        f.write(_format_export_status(export_status))
        f.write("## Error messages\n\n")
        f.write("```pytb\n")
        f.write(formatted_traceback)
        f.write("\n```\n\n")
        f.write("## Exported program\n\n")
        f.write(_format_exported_program(program))
        if model is not None:
            f.write("## ONNX model\n\n")
            f.write("```python\n")
            f.write(str(model))
            f.write("\n```\n\n")
        f.write("## Analysis\n\n")
        # _analysis.analyze writes its findings directly into the open file.
        _analysis.analyze(program, file=f, registry=registry)
        if decomp_comparison is not None:
            f.write("\n## Decomposition comparison\n\n")
            f.write(decomp_comparison)
            f.write("\n")
        if verification_result is not None:
            f.write("\n## Verification results\n\n")
            f.write(verification_result)
            f.write("\n")
        if profile_result is not None:
            f.write("\n## Profiling result\n\n")
            f.write("```\n")
            f.write(profile_result)
            # The closing fence must start on its own line, otherwise the
            # Markdown code block never terminates.
            if not profile_result.endswith("\n"):
                f.write("\n")
            f.write("```\n")
|