1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323
|
# mypy: allow-untyped-defs
from __future__ import annotations
import abc
import contextlib
import dataclasses
import difflib
import io
import logging
import sys
from typing import Any, Callable, TYPE_CHECKING
import torch
import torch.fx
from torch._subclasses.fake_tensor import unset_fake_temporarily
from torch.onnx._internal.fx import diagnostics, onnxfunction_dispatcher
if TYPE_CHECKING:
from torch._subclasses import fake_tensor
@dataclasses.dataclass
class PackageInfo:
    """Name, version and commit hash of a Python package.

    Attributes:
        package_name: Top-level package name, e.g. ``"torch"``.
        version: The package's ``__version__`` attribute, or ``None`` when the
            package does not define one.
        commit_hash: Commit hash of the package, or ``None`` when unknown.
    """

    package_name: str
    version: str | None
    commit_hash: str | None

    def to_onnx_domain_string(self) -> str:
        """Return the dot-separated ONNX domain string for this package.

        The string always starts with ``"pkg"``, followed by the package name,
        version and commit hash. Components that are ``None`` or empty are
        skipped, e.g. ``"pkg.torch.2.1.0"`` when no commit hash is known.
        """
        components = ("pkg", self.package_name, self.version, self.commit_hash)
        return ".".join(component for component in components if component)

    @classmethod
    def from_python_class(cls, python_class_name: type | str) -> PackageInfo:
        """Create a `PackageInfo` for the package that owns a class.

        Args:
            python_class_name: Either a class object (its ``__module__`` is
                used) or a dotted module/class path given as a string.

        Returns:
            A `PackageInfo` whose name is the first dotted component of the
            path and whose version is read from the imported top-level
            package. The commit hash is always ``None`` for now.
        """
        dotted_path = (
            python_class_name.__module__
            if isinstance(python_class_name, type)
            else python_class_name
        )
        # Only the root package is imported; it is the one expected to carry
        # the `__version__` attribute.
        root_name = dotted_path.partition(".")[0]
        root_package = __import__(root_name)
        # TODO: Figure out how to retrieve commit hash.
        return cls(root_name, getattr(root_package, "__version__", None), None)
@dataclasses.dataclass
class GraphModuleOnnxMeta:
    """Metadata record carrying package information.

    NOTE(review): judging by the class name, instances are meant to be
    associated with a graph module during ONNX export; confirm against callers.

    Attributes:
        package_info: The `PackageInfo` stored in this record.
    """

    package_info: PackageInfo
@contextlib.contextmanager
def _patch_difflib_sequence_matcher_init():
    """Temporarily force ``autojunk=False`` on `difflib.SequenceMatcher`.

    While the context is active, `difflib.SequenceMatcher.__init__` is replaced
    by a wrapper that ignores the caller's `autojunk` value and always passes
    `False` to the original initializer. The original `__init__` is put back
    when the context exits, including when the body raises.

    Why this is needed: stacktrace messages in fx readable graphs tend to be
    long (>200) and repeated many times, which matches the automatic junk
    heuristic, so the matcher would otherwise treat them as junk.

    The helpers built on top of `SequenceMatcher` (`difflib.unified_diff`,
    `difflib.ndiff`, `difflib.context_diff`) expose no `autojunk` argument and
    default to `True`, hence the patch. Every one of them is affected while the
    context is active.

    `Reference: Automatic junk heuristic <https://docs.python.org/3/library/difflib.html>`_
    """
    unpatched_init = difflib.SequenceMatcher.__init__

    def init_with_autojunk_disabled(self, isjunk=None, a="", b="", autojunk=True):
        # `autojunk` is accepted only for signature compatibility; its value
        # is deliberately discarded.
        unpatched_init(self, isjunk, a, b, autojunk=False)

    difflib.SequenceMatcher.__init__ = init_with_autojunk_disabled  # type: ignore[assignment]
    try:
        yield
    finally:
        # Always undo the patch so code outside the context sees stock difflib.
        difflib.SequenceMatcher.__init__ = unpatched_init  # type: ignore[assignment]
def _unified_diff(a: str, b: str) -> str:
    """Compute the unified diff between two multi-line strings.

    Both inputs are split into lines (line endings kept) and compared with
    `difflib.unified_diff` while `_patch_difflib_sequence_matcher_init` is
    active, i.e. with `autojunk` disabled for `difflib.SequenceMatcher`.

    Args:
        a: The first string.
        b: The second string.

    Returns:
        The unified diff of the two strings, or the literal ``"<no diff>"``
        when the diff is empty.

    Example::

        >>> a = '''class GraphModule(torch.nn.Module):
        ...     def forward(self, input_ids : torch.Tensor, attention_mask : torch.Tensor):
        ...         # File: /modeling.py:770, code: input_ids = input_ids.view(-1, input_shape[-1])
        ...         view = input_ids.view(-1, 3); input_ids = None
        ... '''
        >>> b = '''class <lambda>(torch.nn.Module):
        ...     def forward(self, input_ids: i64[1, 3], attention_mask: i64[1, 3]):
        ...         # File: /modeling.py:770, code: input_ids = input_ids.view(-1, input_shape[-1])
        ...         view: i64[1, 3] = torch.ops.aten.view.default(input_ids, [-1, 3]); input_ids = None
        ... '''
        >>> print(_unified_diff(a, b))
        ---
        +++
        @@ -1,4 +1,4 @@
        -class GraphModule(torch.nn.Module):
        -    def forward(self, input_ids : torch.Tensor, attention_mask : torch.Tensor):
        +class <lambda>(torch.nn.Module):
        +    def forward(self, input_ids: i64[1, 3], attention_mask: i64[1, 3]):
                 # File: /modeling.py:770, code: input_ids = input_ids.view(-1, input_shape[-1])
        -        view = input_ids.view(-1, 3); input_ids = None
        +        view: i64[1, 3] = torch.ops.aten.view.default(input_ids, [-1, 3]); input_ids = None
    """
    before_lines, after_lines = (
        text.splitlines(keepends=True) for text in (a, b)
    )
    with _patch_difflib_sequence_matcher_init():
        # `difflib.unified_diff` is lazy, so it must be fully consumed while
        # the patch is still in place.
        # Set `n` to `sys.maxsize` to show entire graph when there is a diff.
        diff_text = "".join(
            difflib.unified_diff(before_lines, after_lines, n=sys.maxsize)
        )
    return diff_text or "<no diff>"
def _transform_diagnose_call_message_formatter(
    run: Callable,
    self: Transform,
    *args: Any,
    **kwargs: Any,
) -> str:
    """Build the diagnostic message announcing that a pass is running.

    Args:
        run: The decorated callable; unused.
        self: The object whose class name is reported as the pass name.
        *args: Positional arguments of the decorated call; unused.
        **kwargs: Keyword arguments of the decorated call; unused.

    Returns:
        ``"Running <ClassName> pass. "`` (note the trailing space).
    """
    pass_name = self.__class__.__name__
    return f"Running {pass_name} pass. "
def maybe_fx_graph_tabular(graph: torch.fx.Graph) -> str | None:
    """Capture what `graph.print_tabular()` writes to stdout.

    Args:
        graph: The Graph to print.

    Returns:
        The text printed by `graph.print_tabular()`, or `None` when that call
        raises `ImportError` (i.e. when `tabulate` is not installed).
    """
    captured = io.StringIO()
    try:
        with contextlib.redirect_stdout(captured):
            graph.print_tabular()
    except ImportError:
        # Anything printed before the failure is discarded on purpose.
        return None
    return captured.getvalue()
class Transform(abc.ABC):
    """Abstract base for FX graph transformations used by the FX-ONNX exporter.

    Like the `FX Interpreter <https://pytorch.org/docs/stable/fx.html#torch.fx.Interpreter>`_,
    specializations of this class execute the FX graph Node-by-Node. Subclasses
    customize behavior by overriding methods, which is useful for writing code
    transformations as well as analysis passes.

    The following methods can be overridden::

        _run()
            +-- run_node()
                +-- placeholder()
                +-- get_attr()
                +-- call_function()
                +-- call_method()
                +-- call_module()
                +-- output()

    If a transformation changes the model input and/or output signature (e.g.
    additional inputs/outputs are added to the model), :class:`InputAdaptStep`
    and/or :class:`OutputAdaptStep` are needed to reconcile
    :attr:`ONNXProgram.model_proto`, because the model signature and the model
    representation must match.

    The class also records transformations through the diagnostics system.
    How much is recorded follows the granularity of overriding: overriding
    only `_run()` yields graph level information, whereas overriding
    `call_function()` additionally yields node level information for
    `call_function()`.

    TODO(bowbao): Add more overridable methods in call hierarchy
    TODO(bowbao): Create an example once more overridable methods are added.
    """

    diagnostic_context: diagnostics.DiagnosticContext
    """The diagnostic context for recording diagnostics."""

    module: torch.fx.GraphModule
    """The module to be transformed."""

    fake_mode: fake_tensor.FakeTensorMode | None
    """The existing fake mode detected from `self.module`."""

    def __init__(
        self,
        diagnostic_context: diagnostics.DiagnosticContext,
        module: torch.fx.GraphModule,
    ):
        """Store the inputs and detect the fake mode of `module`.

        Args:
            diagnostic_context: The diagnostic context for recording diagnostics.
            module: The module to be transformed.
        """
        self.diagnostic_context = diagnostic_context
        self.module = module
        # Must come after `self.module` is set: detection reads the graph.
        self.fake_mode = self._detect_fake_mode()

    def _detect_fake_mode(self) -> fake_tensor.FakeTensorMode | None:
        """Detect the fake mode used by the graph of `self.module`.

        Collects `meta["val"]` from every node (`None` for nodes without it)
        and hands the list to `torch._dynamo.utils.detect_fake_mode`.
        """
        node_values = [node.meta.get("val") for node in self.module.graph.nodes]
        with unset_fake_temporarily():
            return torch._dynamo.utils.detect_fake_mode(node_values)

    def _maybe_fakefy_args(
        self, fake_mode: fake_tensor.FakeTensorMode | None, *args: Any
    ) -> tuple[Any, ...]:
        """Convert tensor arguments to fake tensors when a fake mode is given.

        Args:
            fake_mode: The fake mode used for conversion. When `None`, `args`
                is returned unchanged.
            *args: Arguments to convert. Non-tensor values pass through as-is.

        Returns:
            A tuple with the same length and order as `args`.
        """
        if fake_mode is None:
            return args

        def to_fake(value: Any) -> Any:
            if isinstance(value, torch.Tensor):
                return fake_mode.from_tensor(value)
            return value

        # NB: This should hit the cache if tensors were fakefied before.
        # E.g., when the fx graph is produced by Dynamo.
        return tuple(map(to_fake, args))

    @abc.abstractmethod
    def _run(self, *args, **kwargs) -> torch.fx.GraphModule:
        """Perform the transformation; implemented by subclasses."""

    @diagnostics.diagnose_call(
        diagnostics.rules.fx_pass,
        diagnostic_message_formatter=_transform_diagnose_call_message_formatter,
    )
    def run(self, *args, **kwargs) -> torch.fx.GraphModule:
        """Run the transform on `self.module` and log a before/after diff.

        Note that this method may or may not mutate `self.module`, and the
        returned `GraphModule` could be either `self.module` or a new
        `GraphModule`.

        Args:
            *args: Positional arguments for `self.module` to run.
            **kwargs: Keyword arguments for `self.module` to run.

        Returns:
            The `GraphModule` returned by `self._run`.
        """
        diagnostic = self.diagnostic_context.inflight_diagnostic(
            rule=diagnostics.rules.fx_pass
        )
        diagnostic.info(
            "For detailed logging of graph modifications by this pass, either set "
            "`DiagnosticOptions.verbosity_level` to `logging.DEBUG` or use the environment variable "
            "`TORCH_LOGS='onnx_diagnostics'`."
        )

        diff_level = logging.DEBUG

        # Snapshot of the graph before the transform. The empty-string
        # placeholders only avoid unbound names; they are never read unless
        # the diff level is enabled.
        readable_before = ""
        tabular_before: str | None = ""
        if diagnostic.logger.isEnabledFor(diff_level):
            # Cannot use LazyString because the graph may have been mutated at evaluation time.
            readable_before = self.module.print_readable(print_output=False)
            tabular_before = maybe_fx_graph_tabular(self.module.graph)

        transformed = self._run(*args, **kwargs)

        # The log level is queried again after the transform, as before.
        if not diagnostic.logger.isEnabledFor(diff_level):
            return transformed

        # Snapshot of the graph after the transform.
        readable_after = transformed.print_readable(print_output=False)
        tabular_after = maybe_fx_graph_tabular(transformed.graph)

        with diagnostic.log_section(diff_level, "Graph diff:"):
            diagnostic.log(
                diff_level,
                "```\n%s\n```",
                diagnostics.LazyString(
                    _unified_diff, readable_before, readable_after
                ),
            )

        with diagnostic.log_section(diff_level, "Tabular diff:"):
            if tabular_before is not None and tabular_after is not None:
                diagnostic.log(
                    diff_level,
                    "```\n%s\n```",
                    diagnostics.LazyString(
                        _unified_diff, tabular_before, tabular_after
                    ),
                )
            else:
                diagnostic.log(
                    diff_level,
                    "Tabular diff is not available because `tabulate` is not installed.",
                )

        return transformed
class AnalysisResult(abc.ABC):  # noqa: B024
    """Base class for the result type returned by `Analysis.analyze`.

    Declares no members and no abstract methods.
    """
class Analysis(abc.ABC):
    """Abstract base for analyses over an FX `GraphModule`.

    Subclasses implement `analyze`; the constructor only stores its arguments
    as attributes of the same names.
    """

    def __init__(
        self,
        diagnostic_context: diagnostics.DiagnosticContext,
        module: torch.fx.GraphModule,
        onnxfunction_dispatcher: onnxfunction_dispatcher.OnnxFunctionDispatcher,
    ):
        """Store the analysis inputs.

        Args:
            diagnostic_context: Stored as `self.diagnostic_context`.
            module: Stored as `self.module`.
            onnxfunction_dispatcher: Stored as `self.onnxfunction_dispatcher`.
        """
        self.diagnostic_context = diagnostic_context
        self.module = module
        self.onnxfunction_dispatcher = onnxfunction_dispatcher

    @abc.abstractmethod
    def analyze(self, diagnostic_level: diagnostics.infra.Level) -> AnalysisResult:
        """Run the analysis and return its result; implemented by subclasses.

        Args:
            diagnostic_level: A diagnostics level supplied by the caller.
        """
|