1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308
|
import inspect
import logging
import os
from enum import IntEnum
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch._logging._internal
import torch._logging.structured
from torch._export.passes.insert_custom_op_guards import insert_custom_op_guards
from torch.export import ExportedProgram
from torch.export._trace import _export
from torch.export.dynamic_shapes import refine_dynamic_shapes_from_suggested_fixes
log = logging.getLogger(__name__)
class FailureType(IntEnum):
    """Categories of issues that a draft export run can report.

    The members are integer-valued (``IntEnum``); ``str()`` of a member is its
    bare name, e.g. ``"MISSING_FAKE_KERNEL"``.
    """

    # A custom op was hit that has no fake kernel.
    MISSING_FAKE_KERNEL = 1
    # An expression could not be decided symbolically.
    DATA_DEPENDENT_ERROR = 2
    # The user-provided dynamic_shapes spec was found to be incorrect.
    CONSTRAINT_VIOLATION_ERROR = 3
    # A custom op's fake kernel disagrees with its real kernel.
    MISMATCHED_FAKE_KERNEL = 4

    def __str__(self) -> str:
        """Return only the member name (no class prefix, no value)."""
        member_name = self.name
        return member_name
def prettify_stack(stack: List[Dict[str, str]], str_to_filename: Dict[str, str]) -> str:
    """Format the known frames of ``stack`` as text, one line per frame.

    Args:
        stack: Frame dicts; each must provide ``"filename"``, ``"line"`` and
            ``"name"`` entries.
        str_to_filename: Maps a frame's ``"filename"`` entry to the file name
            that should be displayed.

    Returns:
        The concatenation of one newline-prefixed ``File ..., lineno ..., in ...``
        entry per frame whose ``"filename"`` is a key of ``str_to_filename``.
        Frames with an unknown ``"filename"`` are skipped; the result is the
        empty string when nothing qualifies.
    """
    return "".join(
        f"""
File {str_to_filename[entry['filename']]}, lineno {entry['line']}, in {entry['name']}"""
        for entry in stack
        if entry["filename"] in str_to_filename
    )
def filter_stack(
    stack: List[Dict[str, str]], str_to_filename: Dict[str, str]
) -> List[Dict[str, str]]:
    """Trim ``stack`` to the (up to) three frames ending at the last non-torch frame.

    The stack is scanned from its last frame backwards. The first frame found
    whose ``"filename"`` is a key of ``str_to_filename`` and whose mapped path
    does not contain the torch package directory ends the returned window; the
    window also holds up to two frames immediately before it.

    Side effect: the ``"filename"`` entry of every frame in ``stack`` is
    converted to ``str`` in place, so that all returned frames can be looked up
    in ``str_to_filename`` (whose keys are strings).

    Args:
        stack: Frame dicts, each with a ``"filename"`` entry.
        str_to_filename: Maps stringified ``"filename"`` entries to file paths.

    Returns:
        A slice of ``stack`` with at most three frames. When no frame
        qualifies, the last three frames of ``stack`` are returned.
    """
    # Loop-invariant: directory prefix shared by every file inside torch.
    torch_filepath = os.path.dirname(inspect.getfile(torch)) + os.path.sep

    # Normalize all frames up front. Doing this lazily during the backward
    # scan would leave the frames *before* the match (which are part of the
    # returned window) with un-normalized filenames.
    for frame in stack:
        frame["filename"] = str(frame["filename"])

    # ``end`` is the exclusive end of the candidate window, i.e. index + 1 of
    # the frame currently being inspected.
    for end in range(len(stack), 0, -1):
        filename = stack[end - 1]["filename"]
        if filename not in str_to_filename:
            continue
        if torch_filepath not in str_to_filename[filename]:
            # Clamp the start at 0: a negative start would wrap around and
            # yield the wrong (possibly empty) slice when the matching frame
            # is one of the first two frames of the stack.
            return stack[max(0, end - 3) : end]
    return stack[-3:]
def hash_stack(stack: List[Dict[str, str]]) -> str:
    """Build a string key identifying ``stack`` by its line/filename pairs.

    Each frame contributes ``line: <line> filename: <filename>``; the frame
    entries are joined with ``;``. An empty stack yields the empty string.
    """
    frame_keys = []
    for frame in stack:
        frame_keys.append(f'line: {frame["line"]} filename: {frame["filename"]}')
    return ";".join(frame_keys)
class FailureReport:
    """A single issue recorded during draft export.

    Attributes are plain instance fields:
      * ``failure_type`` - the ``FailureType`` category of the issue.
      * ``data`` - payload describing the issue; which keys are read depends
        on ``failure_type`` (see ``print``).
      * ``xfail`` - flag stored as given (defaults to ``False``).
    """

    def __init__(
        self, failure_type: FailureType, data: Dict[str, Any], xfail: bool = False
    ) -> None:
        self.failure_type: FailureType = failure_type
        self.data: Dict[str, Any] = data
        self.xfail: bool = xfail

    def __repr__(self) -> str:
        return f"FailureReport(failure_type={self.failure_type}, xfail={self.xfail}, data={self.data})"

    def print(self, str_to_filename: Dict[str, str]) -> str:
        """Return a human-readable description of this failure.

        Keys read from ``self.data`` per failure type:
          * ``MISSING_FAKE_KERNEL``: ``"op"``
          * ``CONSTRAINT_VIOLATION_ERROR``: ``"expr"``, ``"symbol_to_sources"``,
            ``"stack"``, ``"new_dynamic_shapes"``
          * ``DATA_DEPENDENT_ERROR``: ``"expr"``, ``"stack"``, ``"result"``
          * ``MISMATCHED_FAKE_KERNEL``: ``"op"``, ``"reason"``

        Args:
            str_to_filename: Passed to ``prettify_stack`` to resolve the
                ``"filename"`` entries of stack frames.

        Returns:
            The formatted message.

        Raises:
            ValueError: If ``self.failure_type`` is none of the types above.
        """
        if self.failure_type == FailureType.MISSING_FAKE_KERNEL:
            op = self.data["op"]

            return f"""Missing fake kernel.
torch.ops.{op} is missing a fake kernel implementation.
Please refer to https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ahugy69p2jmz for more detailed instructions on how to write a meta implementation.
"""  # noqa: B950

        elif self.failure_type == FailureType.CONSTRAINT_VIOLATION_ERROR:
            # "occurred" was previously misspelled in this user-facing message.
            return f"""Constraint violation error.
The specified input dynamic_shapes spec was found to be incorrect during tracing.
Specifically, this guard was added: {self.data["expr"]}, where {self.data["symbol_to_sources"]}.
This occurred at the following stacktrace: {prettify_stack(self.data["stack"], str_to_filename)}.
Because of this, we have modified the dynamic shapes structure to be the
following. You can also use torch.export.Dim.AUTO instead to specify your
dynamic shapes, and we will automatically infer the dynamism for you.
```
dynamic_shapes = {self.data["new_dynamic_shapes"]}
```
"""

        elif self.failure_type == FailureType.DATA_DEPENDENT_ERROR:
            return f"""Data dependent error.
When exporting, we were unable to figure out if the expression `{self.data["expr"]}` always holds.
This occurred at the following stacktrace: {prettify_stack(self.data["stack"], str_to_filename)}.
As a result, it was specialized to evaluate to `{self.data["result"]}`, and asserts were inserted into the graph.
Please add `torch._check(...)` to the original code to assert this data-dependent assumption.
Please refer to https://docs.google.com/document/d/1kZ_BbB3JnoLbUZleDT6635dHs88ZVYId8jT-yTFgf3A/edit#heading=h.boi2xurpqa0o for more details.
"""  # noqa: B950

        elif self.failure_type == FailureType.MISMATCHED_FAKE_KERNEL:
            op = self.data["op"]
            reason = self.data["reason"]
            return f"""Mismatched fake kernel.
torch.ops.{op} has a fake kernel implementation, but it has incorrect behavior, based on the real kernel.
The reason for the mismatch is: {reason}.
Please refer to https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ahugy69p2jmz for more detailed instructions on how to write a fake implementation.
"""  # noqa: B950

        else:
            raise ValueError(f"Unknown failure type: {self.failure_type}")
class DraftExportReport:
    """Collection of the failures recorded by one draft export run.

    Attributes are plain instance fields:
      * ``failures`` - the recorded ``FailureReport`` objects.
      * ``str_to_filename`` - mapping handed to each failure's ``print`` so
        stack frames can be resolved to file names.
    """

    def __init__(
        self, failures: List["FailureReport"], str_to_filename: Dict[str, str]
    ):
        self.failures: List["FailureReport"] = failures
        self.str_to_filename = str_to_filename

    def successful(self) -> bool:
        """Return True when there are no failures or every failure is ``xfail``."""
        # ``all`` of an empty iterable is True, which covers "no failures".
        return all(failure.xfail for failure in self.failures)

    def __repr__(self) -> str:
        return f"DraftExportReport({self.failures})"

    def __str__(self) -> str:
        """Render the report as ANSI-colored text.

        A successful report yields a green confirmation banner. Otherwise a
        yellow warning banner is followed by one numbered entry (starting at
        1) per failure, produced by that failure's ``print`` method.
        """
        WARNING_COLOR = "\033[93m"
        GREEN_COLOR = "\033[92m"
        END_COLOR = "\033[0m"

        if self.successful():
            return f"""{GREEN_COLOR}
##############################################################################################
Congratulations: No issues are found during export, and it was able to soundly produce a graph.
You can now change back to torch.export.export()
##############################################################################################
{END_COLOR}"""

        header = f"""{WARNING_COLOR}
###################################################################################################
WARNING: {len(self.failures)} issue(s) found during export, and it was not able to soundly produce a graph.
Please follow the instructions to fix the errors.
###################################################################################################
"""
        entries = "".join(
            f"{number}. {failure.print(self.str_to_filename)}\n"
            for number, failure in enumerate(self.failures, start=1)
        )
        return header + entries + END_COLOR

    def apply_suggested_fixes(self) -> None:
        """Placeholder; always raises ``NotImplementedError``."""
        raise NotImplementedError("Not implemented yet")
class CaptureStructuredTrace(logging.Handler):
def __init__(self, specific_log_keys: List[str]):
super().__init__()
self.specific_log_keys = specific_log_keys
self.logs: List[Tuple[str, Dict[str, Any]]] = []
self.logger = logging.getLogger("torch.__trace")
self.prev_get_dtrace = False
def __enter__(self) -> "CaptureStructuredTrace":
self.logs = []
self.logger.addHandler(self)
self.prev_get_dtrace = torch._logging._internal.GET_DTRACE_STRUCTURED
torch._logging._internal.GET_DTRACE_STRUCTURED = True
return self
def __exit__(self, exc_type, exc_value, traceback) -> None: # type: ignore[no-untyped-def]
self.logs = []
self.logger.removeHandler(self)
torch._logging._internal.GET_DTRACE_STRUCTURED = self.prev_get_dtrace
self.prev_get_dtrace = False
def emit(self, record: Any) -> None:
metadata = record.metadata
for key in self.specific_log_keys:
if key in metadata:
self.logs.append((key, metadata[key]))
def draft_export(
    mod: torch.nn.Module,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
    preserve_module_call_signature: Tuple[str, ...] = (),
    strict: bool = False,
    pre_dispatch: bool = False,
) -> Tuple[ExportedProgram, DraftExportReport]:
    """Export ``mod`` and report the issues recorded along the way.

    ``_export`` is run with the functorch config options
    ``fake_tensor_propagate_real_tensors`` and
    ``generate_fake_kernels_from_real_mismatches`` patched to ``True`` while
    structured trace entries are captured. If the first attempt raises
    ``torch._dynamo.exc.UserError``, the dynamic shapes are refined from the
    error message and the export is retried once; an error raised by the retry
    propagates to the caller.

    The captured entries are converted into ``FailureReport`` objects
    (de-duplicated per trimmed stack for data-dependent errors, per op for
    missing fake kernels, and per ``(op, reason)`` for mismatched fake
    kernels), custom-op guards are inserted into the exported graph module,
    and the resulting report is attached to the program as ``_report``. A
    warning is logged when the report is not successful.

    Args:
        mod: Module to export.
        args: Positional example inputs forwarded to ``_export``.
        kwargs: Keyword example inputs forwarded to ``_export``; ``None`` or
            empty means ``{}``.
        dynamic_shapes: Dynamic shapes spec forwarded to ``_export``; ``None``
            or empty means ``{}``.
        preserve_module_call_signature: Forwarded to ``_export``.
        strict: Forwarded to ``_export``.
        pre_dispatch: Forwarded to ``_export``.

    Returns:
        ``(exported_program, report)``.

    Raises:
        RuntimeError: If a captured entry has an unexpected log name.
    """
    kwargs = kwargs or {}
    dynamic_shapes = dynamic_shapes or {}

    capture_structured_log = CaptureStructuredTrace(
        [
            "propagate_real_tensors",
            "guard_added",
            "missing_fake_kernel",
            "mismatched_fake_kernel",
        ]
    )

    # Options that are identical for the first attempt and the retry.
    export_options: Dict[str, Any] = {
        "strict": strict,
        "pre_dispatch": pre_dispatch,
        "preserve_module_call_signature": preserve_module_call_signature,
    }

    with torch._functorch.config.patch(
        fake_tensor_propagate_real_tensors=True,
        generate_fake_kernels_from_real_mismatches=True,
    ), capture_structured_log:
        try:
            new_shapes = None
            ep = _export(
                mod, args, kwargs, dynamic_shapes=dynamic_shapes, **export_options
            )
        except torch._dynamo.exc.UserError as exc:
            # Retry once with the shapes suggested by the error message.
            new_shapes = refine_dynamic_shapes_from_suggested_fixes(
                exc.msg, dynamic_shapes
            )
            ep = _export(
                mod, args, kwargs, dynamic_shapes=new_shapes, **export_options
            )

        # The captured logs must be consumed inside the ``with`` block: the
        # capture handler empties ``logs`` when the block exits.
        str_to_filename: Dict[str, str] = {
            str(index): filename
            for filename, index in torch._logging.structured.INTERN_TABLE.items()
        }

        failures: List[FailureReport] = []
        # Dedup custom ops
        custom_ops_logs: Dict[Any, Tuple[Dict[str, Any], FailureType]] = {}
        # Dedup data dependent errors based on stacktrace
        data_dependent_logs: Dict[str, Dict[str, Any]] = {}

        for log_name, log_contents in capture_structured_log.logs:
            failure_type: FailureType

            if log_name == "propagate_real_tensors":
                trimmed_stack = filter_stack(log_contents["stack"], str_to_filename)
                log_contents["stack"] = trimmed_stack
                stack_key = hash_stack(trimmed_stack)
                if stack_key in data_dependent_logs:
                    continue
                data_dependent_logs[stack_key] = log_contents
                failure_type = FailureType.DATA_DEPENDENT_ERROR

            elif log_name == "guard_added":
                # Guards only count as failures when the shapes had to be
                # refined. We also only want to include guards added that are
                # relevant to the symbolic shapes corresponding to the inputs
                # which were specified in the dynamic_shapes arg. These have a
                # source.
                if new_shapes is None or len(log_contents["symbol_to_sources"]) == 0:
                    continue
                log_contents["stack"] = filter_stack(
                    log_contents["stack"], str_to_filename
                )
                log_contents["new_dynamic_shapes"] = new_shapes
                failure_type = FailureType.CONSTRAINT_VIOLATION_ERROR

            elif log_name == "missing_fake_kernel":
                op_key = log_contents["op"]
                if op_key in custom_ops_logs:
                    continue
                failure_type = FailureType.MISSING_FAKE_KERNEL
                custom_ops_logs[op_key] = (log_contents, failure_type)

            elif log_name == "mismatched_fake_kernel":
                mismatch_key = (log_contents["op"], log_contents["reason"])
                if mismatch_key in custom_ops_logs:
                    continue
                failure_type = FailureType.MISMATCHED_FAKE_KERNEL
                custom_ops_logs[mismatch_key] = (log_contents, failure_type)

            else:
                raise RuntimeError(f"Unknown log name: {log_name}")

            failures.append(FailureReport(failure_type, log_contents))

        report = DraftExportReport(failures, str_to_filename)

        # Add asserts around custom ops
        # NOTE(review): the keys are op names for missing kernels but
        # (op, reason) tuples for mismatched kernels - confirm that
        # insert_custom_op_guards is meant to receive both kinds.
        insert_custom_op_guards(ep.graph_module, list(custom_ops_logs))

    ep._report = report
    if not report.successful():
        log.warning(report)
    return ep, report
|