# mypy: allow-untyped-defs
# NOTE: This file is referenced by name at
# /opt/pytorch/torch/_dynamo/eval_frame.py::DONT_WRAP_FILES.
# introduced by https://github.com/pytorch/pytorch/pull/98894.
# If this file is renamed, moved, etc please update the reference there!

from __future__ import annotations

import contextlib
import functools
import inspect
from typing import Any, Callable, Mapping, Sequence

import torch._dynamo
import torch.export as torch_export
import torch.fx
import torch.onnx
from torch.onnx._internal import _exporter_legacy, io_adapter
from torch.utils import _pytree as pytree
class _PyTreeExtensionContext:
    """Context manager to register PyTree extension."""

    _extensions: dict[type, tuple[pytree.FlattenFunc, pytree.UnflattenFunc]]

    def __init__(self) -> None:
        self._extensions = {}
        # Register PyTree extension for HuggingFace model output.
        self._register_huggingface_model_output_extension()

    def __enter__(self):
        for class_type, (flatten_func, unflatten_func) in self._extensions.items():
            pytree._private_register_pytree_node(
                class_type,
                flatten_func,
                unflatten_func,
            )
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        for class_type in self._extensions:
            pytree.SUPPORTED_NODES.pop(class_type)

    def register_pytree_node(
        self,
        class_type: type,
        flatten_func: pytree.FlattenFunc,
        unflatten_func: pytree.UnflattenFunc,
    ):
        """Register PyTree extension for a custom python type.

        If the type is already registered, either with pytree itself or with this
        context, the call is a no-op.

        Args:
            class_type: The custom python type.
            flatten_func: The flatten function.
            unflatten_func: The unflatten function.
        """
        if class_type in pytree.SUPPORTED_NODES or class_type in self._extensions:
            # PyTree node already registered.
            # E.g., `huggingface/transformer` registers `ModelOutput` as PyTree node after
            # https://github.com/huggingface/transformers/pull/25358.
            return
        self._extensions[class_type] = (flatten_func, unflatten_func)

    def _register_huggingface_model_output_extension(self):
        try:
            from transformers import modeling_outputs  # type: ignore[import]
        except ImportError:
            return

        def model_output_flatten(
            output: modeling_outputs.ModelOutput,
        ) -> tuple[list[Any], pytree.Context]:
            return list(output.values()), (type(output), list(output.keys()))

        def model_output_unflatten(
            values: list[Any], context: pytree.Context
        ) -> modeling_outputs.ModelOutput:
            output_type, keys = context
            return output_type(**dict(zip(keys, values)))

        # All 'ModelOutput' subclasses are defined under module 'modeling_outputs'.
        named_model_output_classes = inspect.getmembers(
            modeling_outputs,
            lambda x: (
                inspect.isclass(x)
                and issubclass(x, modeling_outputs.ModelOutput)
                and x is not modeling_outputs.ModelOutput
            ),
        )

        for _, class_type in named_model_output_classes:
            self.register_pytree_node(
                class_type,
                model_output_flatten,
                model_output_unflatten,  # type: ignore[arg-type]
            )
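# The sketch below is illustrative only and is not used by the exporter. It shows
# how ``_PyTreeExtensionContext`` is meant to be used for a custom container type.
# ``_ExamplePoint`` and its flatten/unflatten helpers are hypothetical names that
# exist only for this example.
def _example_pytree_extension_usage() -> None:
    class _ExamplePoint:
        def __init__(self, x, y):
            self.x = x
            self.y = y

    def _flatten(point) -> tuple[list[Any], pytree.Context]:
        # Children first, then a context carrying whatever is needed to rebuild the node.
        return [point.x, point.y], None

    def _unflatten(values: list[Any], context: pytree.Context):
        del context  # Unused; the node type is fixed in this example.
        return _ExamplePoint(*values)

    pytree_context = _PyTreeExtensionContext()
    pytree_context.register_pytree_node(_ExamplePoint, _flatten, _unflatten)
    with pytree_context:
        # Inside the context, pytree treats ``_ExamplePoint`` as an inner node
        # rather than a leaf, so its fields become flat elements.
        leaves, spec = pytree.tree_flatten(_ExamplePoint(torch.ones(2), torch.zeros(2)))
        rebuilt = pytree.tree_unflatten(leaves, spec)
        assert isinstance(rebuilt, _ExamplePoint)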
class DynamoFlattenOutputStep(io_adapter.FlattenOutputStep):
    """Flatten nested collections and custom python types and return a flat list of elements.

    Extended from :class:`io_adapter.FlattenOutputStep` to support flattening arbitrary
    types via pytree extension. By default this supports many common user defined python
    types such as :class:`ModelOutput` from HuggingFace transformers.

    The pytree extension can be customized by passing in a ``_PyTreeExtensionContext``
    object. See :meth:`_PyTreeExtensionContext.register_pytree_node`.
    """

    def __init__(self, pytree_extension_context: _PyTreeExtensionContext | None = None):
        super().__init__()
        self._pytree_extension_context = (
            pytree_extension_context or _PyTreeExtensionContext()
        )

    def apply(
        self,
        model_outputs: Any,
        model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
    ) -> Sequence[Any]:
        """Flatten the model outputs, under the context of pytree extension."""
        with self._pytree_extension_context:
            return super().apply(model_outputs, model=model)
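# A minimal illustrative sketch (not part of the exporter pipeline): applying
# ``DynamoFlattenOutputStep`` to a nested output structure. The dictionary and
# tuple below are arbitrary example data.
def _example_flatten_output_step() -> Sequence[Any]:
    flatten_step = DynamoFlattenOutputStep()
    nested_outputs = {
        "logits": torch.ones(2, 3),
        "aux": (torch.zeros(1), torch.ones(1)),
    }
    # The nested containers are flattened into a flat sequence of tensor leaves.
    return flatten_step.apply(nested_outputs)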
def _wrap_model_with_output_adapter(
    model: torch.nn.Module | Callable,
    output_adapter: DynamoFlattenOutputStep,
) -> Callable:
    """Wrap model with output adapter.

    This is a helper function to enable :func:`dynamo.export` on models that produce
    outputs of custom user defined types. It wraps the model with an output adapter that
    converts the outputs to :func:`dynamo.export` compatible types, i.e. :class:`torch.Tensor`.

    The adapting logic is controlled by ``output_adapter``.

    Args:
        model: PyTorch model or function.
        output_adapter: Output adapter to apply to model output.

    Returns:
        Wrapped model.
    """
    model_func = model.forward if isinstance(model, torch.nn.Module) else model

    # Preserve original function signature.
    @functools.wraps(model_func)
    def wrapped(*args, **kwargs):
        return output_adapter.apply(model_func(*args, **kwargs), model=model)

    return wrapped
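# Illustrative sketch (hypothetical toy module, not part of this file's API): wrapping
# a model with the output adapter so that its nested outputs are returned as a flat
# sequence of tensors, which is what the wrapped callable hands back to ``dynamo.export``.
def _example_wrap_model_with_output_adapter():
    class _ExampleModule(torch.nn.Module):
        def forward(self, x):
            # Nested / custom-typed outputs are adapted by the wrapper.
            return {"doubled": x * 2, "halved": x / 2}

    flatten_step = DynamoFlattenOutputStep()
    wrapped = _wrap_model_with_output_adapter(_ExampleModule(), flatten_step)
    # Calling ``wrapped`` returns a flat sequence of tensors instead of a dict.
    return wrapped(torch.ones(3))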
class DynamoExport(_exporter_legacy.FXGraphExtractor):
    """Generates a FX GraphModule using the torch.dynamo.export API.

    Args:
        aten_graph: If True, exports a graph with ATen operators.
            If False, exports a graph with Python operators.
    """

    def __init__(
        self,
        aten_graph: bool | None = None,
    ):
        super().__init__()
        # Default to an ATen graph when not specified.
        self.aten_graph = aten_graph if aten_graph is not None else True

    def generate_fx(
        self,
        options: _exporter_legacy.ResolvedExportOptions,
        model: torch.nn.Module | Callable,
        model_args: Sequence[Any],
        model_kwargs: Mapping[str, Any],
    ) -> torch.fx.GraphModule:
        # `dynamo.export` does not recognize custom user defined classes as output type.
        # Apply wrapper to adapt the outputs back to `dynamo.export` compatible types,
        # i.e. :class:`torch.Tensor`.
        dynamo_flatten_output_step = DynamoFlattenOutputStep()
        wrapped_model = _wrap_model_with_output_adapter(
            model, dynamo_flatten_output_step
        )
        # Record the output adapter step.
        self.output_adapter.append_step(dynamo_flatten_output_step)

        # Translate callable to FX graph.
        fake_mode = (
            options.fake_context.fake_mode
            if options.fake_context
            else contextlib.nullcontext()
        )
        fx_mode = "symbolic" if options.dynamic_shapes else "fake"
        with fake_mode:  # type: ignore[attr-defined]
            graph_module, graph_guard = torch._dynamo.export(
                wrapped_model,
                tracing_mode=fx_mode,
            )(
                *model_args,
                **model_kwargs,
            )
        del graph_guard  # Unused
        torch._dynamo.reset()

        # Record the input adapter step and run the pre-export passes before the
        # FX graph is exported to an ONNX ModelProto.
        self.input_adapter.append_step(
            io_adapter.FlattenInputWithTreeSpecValidationInputStep()
        )
        updated_model_args = self.input_adapter.apply(
            *model_args, model=model, **model_kwargs
        )

        return self.pre_export_passes(options, model, graph_module, updated_model_args)  # type: ignore[return-value]

    def pre_export_passes(
        self,
        options: _exporter_legacy.ResolvedExportOptions,
        original_model: torch.nn.Module | Callable,
        fx_module: torch.fx.GraphModule,
        fx_module_args: Sequence[Any],
    ):
        return _exporter_legacy.common_pre_export_passes(
            options, original_model, fx_module, fx_module_args
        )
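# Illustrative sketch of the core translation that ``DynamoExport.generate_fx``
# performs (hypothetical standalone helper; the real code path goes through
# ``ResolvedExportOptions``, fake mode, and the registered input/output adapter steps).
def _example_dynamo_export_core(
    model: torch.nn.Module, *model_args: Any
) -> torch.fx.GraphModule:
    flatten_step = DynamoFlattenOutputStep()
    wrapped_model = _wrap_model_with_output_adapter(model, flatten_step)
    # ``torch._dynamo.export`` returns a callable; invoking it with example inputs
    # yields the traced GraphModule together with its guards.
    graph_module, guards = torch._dynamo.export(
        wrapped_model,
        tracing_mode="symbolic",
    )(*model_args)
    del guards  # Unused in this sketch.
    torch._dynamo.reset()
    return graph_module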