# mypy: allow-untyped-defs
import copy
import dataclasses
import functools
import io
import json
import logging
import os
import re
import sys
import types
import warnings
import weakref
import zipfile
from collections import OrderedDict
from contextlib import contextmanager
from functools import lru_cache
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from unittest.mock import patch

import torch
import torch.fx
import torch.utils._pytree as pytree
from torch._dispatch.python import enable_python_dispatcher
from torch._guards import compile_context
from torch._utils_internal import log_export_usage
from torch.export._tree_utils import reorder_kwargs
from torch.export.graph_signature import (
    ArgumentSpec,
    ConstantArgument,
    ExportGraphSignature,
    InputKind,
    InputSpec,
    OutputKind,
    OutputSpec,
    SymBoolArgument,
    SymFloatArgument,
    SymIntArgument,
    TensorArgument,
)
from torch.fx import traceback as fx_traceback
from torch.fx._compatibility import compatibility
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo

from .utils import _materialize_cpp_cia_ops
from .wrappers import _wrap_submodules


log = logging.getLogger(__name__)


@dataclasses.dataclass
class ExportDynamoConfig:
    """
    Manage Export-specific configurations of Dynamo.
    """

    allow_rnn: bool = True


# We only want to print this once to avoid flooding logs in workflows where
# capture_pre_autograd_graph is called multiple times.
@lru_cache
def capture_pre_autograd_graph_warning():
    from torch._inductor import config

    log.warning("+============================+")
    log.warning("|     !!!   WARNING   !!!    |")
    log.warning("+============================+")
    log.warning(
        "capture_pre_autograd_graph() is deprecated and doesn't provide any "
        "functionality guarantees going forward."
    )
    log.warning("Please switch to torch.export.export_for_training instead.")
    if config.is_fbcode():
        log.warning(
            "In unit tests, capture_pre_autograd_graph() will fall back to "
            "torch.export.export_for_training."
        )


@lru_cache
def print_export_warning():
    log.warning("Using torch.export.export_for_training(..., strict=True)")


def gm_using_training_ir(graph_module: torch.fx.GraphModule) -> bool:
    """
    Returns True if the graph module is detected to use training IR.

    This function checks for two specific conditions within the nodes of the graph module:
    1. The presence of the ``torch.ops.aten.batch_norm.default`` operation, which
       indicates the use of training IR.
    2. The presence of deprecated IR tags in node meta, or of batch norm ops
       produced by the deprecated IR.

    The function raises a RuntimeError if both conditions are met, indicating a
    conflict in the IR.
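
    Example (a minimal sketch; ``MyBatchNormModel`` is a hypothetical module
    containing a ``torch.nn.BatchNorm2d`` layer)::

        gm = torch.export.export_for_training(
            MyBatchNormModel(), (torch.randn(2, 3, 8, 8),)
        ).module()
        uses_training_ir = gm_using_training_ir(gm)  # expected to be True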
"""
# TODO: clean up this code after training IR migration.
# T199018392
has_training_ir_batch_norm = False
has_deprecated_ir_tag = getattr(graph_module, "capture_pre_autograd_graph_tag", False)
for node in graph_module.graph.nodes:
if node.op == "call_function":
if node.target == torch.ops.aten.batch_norm.default:
has_training_ir_batch_norm = True
if node.meta.get("capture_pre_autograd_graph_tag", False):
has_deprecated_ir_tag = True
if node.target in [
torch.ops.aten._native_batch_norm_legit.default,
torch.ops.aten.cudnn_batch_norm.default,
torch.ops.aten.miopen_batch_norm.default,
]:
has_deprecated_ir_tag = True
if has_deprecated_ir_tag and has_training_ir_batch_norm:
raise RuntimeError("Conflicting IR detected.")
return has_training_ir_batch_norm or not has_deprecated_ir_tag


@compatibility(is_backward_compatible=False)
def capture_pre_autograd_graph(
    f: torch.nn.Module,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any, ...]]] = None,
) -> torch.nn.Module:
    """
    A helper function that is intended to trace a module before any pre-autograd
    decomposition is run. The produced module will be "non-functional" and
    composed of aten operators. This API will eventually be deleted in favor of
    the more general torch.export API.

    Args:
        f: nn.Module to be traced.

        args: example positional inputs.

        kwargs: optional example keyword inputs.

        dynamic_shapes: Should either be:
            1) a dict from argument names of ``f`` to their dynamic shape specifications,
            2) a tuple that specifies dynamic shape specifications for each input in original order.
            If you are specifying dynamism on keyword args, you will need to pass them in the order
            that is defined in the original function signature.

            The dynamic shape of a tensor argument can be specified as either
            (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
            not required to include static dimension indices in this dict, but when they are,
            they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
            where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
            are denoted by None. Arguments that are dicts or tuples / lists of tensors are
            recursively specified by using mappings or sequences of contained specifications.

    Returns:
        An nn.Module containing the traced method.
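
    Example (a minimal sketch; ``M`` and its inputs are hypothetical)::

        import torch

        class M(torch.nn.Module):
            def forward(self, x):
                return torch.relu(x)

        # Mark dimension 0 of ``x`` as dynamic; all other dimensions stay static.
        batch = torch.export.Dim("batch")
        gm = capture_pre_autograd_graph(
            M(), (torch.randn(4, 3),), dynamic_shapes={"x": {0: batch}}
        )
        out = gm(torch.randn(8, 3))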
"""
    from torch._export.non_strict_utils import make_constraints
    from torch._subclasses.functional_tensor import FunctionalTensor
    from torch._utils_internal import capture_pre_autograd_graph_using_training_ir
    from torch.export._trace import (
        _extract_fake_inputs,
        DEFAULT_EXPORT_DYNAMO_CONFIG,
        _ignore_backend_decomps,
    )
    from torch.export._unlift import _create_stateful_graph_module
    from torch.export.dynamic_shapes import _combine_args

    capture_pre_autograd_graph_warning()

    if sys.platform == "win32":
        raise RuntimeError("capture_pre_autograd_graph not yet supported on Windows")

    assert isinstance(f, torch.nn.Module), "Expected an nn.Module instance."

    if kwargs is None:
        kwargs = {}
    if capture_pre_autograd_graph_using_training_ir():
        print_export_warning()
        module = torch.export.export_for_training(
            f, args, kwargs, dynamic_shapes=dynamic_shapes, strict=True
        ).module()
    else:
        log_export_usage(event="export.private_api", flags={"capture_pre_autograd_graph"})

        # Do not decompose dropout for exported models, because in eval mode the dropout
        # op disappears from the graph, which makes it difficult to switch to train mode.
        # See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832.
        # We force-create native_batch_norm because the materialization logic below
        # only applies to CIA ops.
        maybe_aliasing_or_mutating_ops = [torch.ops.aten.native_batch_norm.default]

        _materialize_cpp_cia_ops()
        # Collect every ATen overload tagged as maybe aliasing or mutating so that
        # each one gets decomposed into its functional counterpart during tracing.
        for op in torch.ops.aten:
            op_obj = getattr(torch.ops.aten, op)
            for overload in op_obj.overloads():
                op_overload = getattr(op_obj, overload)
                if torch.Tag.maybe_aliasing_or_mutating in op_overload.tags:
                    maybe_aliasing_or_mutating_ops.append(op_overload)

        decomp_table = {
            op: op.decompose
            for op in maybe_aliasing_or_mutating_ops
            if op != torch.ops.aten.dropout.default
        }

        with torch._dynamo.config.patch(
            dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)
        ), _ignore_backend_decomps():
            m = torch._dynamo.export(
                f,
                dynamic_shapes=dynamic_shapes,
                assume_static_by_default=True,
                tracing_mode="symbolic",
                decomposition_table=decomp_table,
                pre_dispatch=True,
                aten_graph=True,
                _log_export_usage=False,
            )(
                *args,
                **kwargs,
            )[0]

            _, _, fake_mode = _extract_fake_inputs(m, args, kwargs)

            # Keep only range constraints on intermediate symbolic values
            # (names like "i0", "f0").
            m.meta["inline_constraints"] = {
                k: v
                for k, v in fake_mode.shape_env.var_to_range.items()
                if re.match(r"^[if]\d+$", str(k))
            }

            if isinstance(f, torch.nn.Module):
                from torch.export._trace import _restore_state_dict

                _restore_state_dict(f, m)

            flat_args, _ = pytree.tree_flatten((args, kwargs or {}))
            combined_args = _combine_args(f, args, kwargs)
            range_constraints = make_constraints(
                fake_mode,
                m,
                combined_args,
                dynamic_shapes,
                0,
            )

            module = _create_stateful_graph_module(
                m,
                range_constraints=range_constraints,
            )

        # Tag the module and every node so that downstream consumers can tell this
        # graph was produced by the deprecated capture_pre_autograd_graph IR.
        setattr(module, "capture_pre_autograd_graph_tag", True)  # noqa: B010
        for node in module.graph.nodes:
            node.meta["capture_pre_autograd_graph_tag"] = True

    error_message = """
        Calling train() or eval() is not supported for exported models.
        Alternatively, you may override these methods to do custom user behavior as follows:

            def _my_train(self, mode: bool = True):
                ...

            def _my_eval(self):
                ...

            model.train = types.MethodType(_my_train, model)
            model.eval = types.MethodType(_my_eval, model)
        """

    def _train(self, mode: bool = True):
        raise NotImplementedError(error_message)

    def _eval(self, mode: bool = True):
        raise NotImplementedError(error_message)

    module.train = types.MethodType(_train, module)  # type: ignore[method-assign]
    module.eval = types.MethodType(_eval, module)  # type: ignore[method-assign]

    # Remove Proxy objects from the buffers because they cannot be deepcopied or pickled.
    if hasattr(module, "_buffers"):
        torch._export.utils.remove_proxy_from_state_dict(
            module._buffers, in_place=True
        )
    return module


# We only want to print this once to avoid flooding logs in workflows where
# aot_compile_warning is called multiple times.
@lru_cache
def aot_compile_warning():
    from torch._inductor import config

    log.warning("+============================+")
    log.warning("|     !!!   WARNING   !!!    |")
    log.warning("+============================+")
    log.warning(
        "torch._export.aot_compile()/torch._export.aot_load() is being deprecated. "
        "Please switch to directly calling "
        "torch._inductor.aoti_compile_and_package(torch.export.export())/"
        "torch._inductor.aoti_load_package() instead."
    )


def aot_compile(
    f: Callable,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    dynamic_shapes: Optional[Dict[str, Any]] = None,
    options: Optional[Dict[str, Any]] = None,
    remove_runtime_assertions: bool = False,
    disable_constraint_solver: bool = False,
    same_signature: bool = True,
) -> Union[List[str], str]:
    """
    Note: this function is not stable yet.

    Traces either an nn.Module's forward function or just a callable with PyTorch
    operations inside, generates executable C++ code from the program, and returns
    the path to the generated shared library.

    Args:
        f: the ``nn.Module`` or callable to trace.

        args: example positional inputs.

        kwargs: optional example keyword inputs.

        dynamic_shapes: Should either be:
            1) a dict from argument names of ``f`` to their dynamic shape specifications,
            2) a tuple that specifies dynamic shape specifications for each input in original order.
            If you are specifying dynamism on keyword args, you will need to pass them in the order
            that is defined in the original function signature.

            The dynamic shape of a tensor argument can be specified as either
            (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
            not required to include static dimension indices in this dict, but when they are,
            they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
            where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
            are denoted by None. Arguments that are dicts or tuples / lists of tensors are
            recursively specified by using mappings or sequences of contained specifications.

        options: A dictionary of options to control Inductor.

        disable_constraint_solver: Whether the dim constraint solver must be disabled.

    Returns:
        Path to the generated shared library.
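
    Example (a minimal sketch; the toy module and input are hypothetical)::

        import torch

        class M(torch.nn.Module):
            def forward(self, x):
                return x.sin() + x.cos()

        # Compile M ahead of time into a shared library.
        so_path = torch._export.aot_compile(M(), (torch.randn(8),))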
"""
    from torch._inductor import config
    from torch._inductor.decomposition import select_decomp_table
    from torch.export._trace import _export_to_torch_ir

    aot_compile_warning()

    if config.is_predispatch:
        gm = torch.export._trace._export(
            f, args, kwargs, dynamic_shapes, pre_dispatch=True
        ).module()
    else:
        # We want to export to Torch IR here to utilize the pre_grad passes in
        # Inductor, which run on Torch IR.
        gm = _export_to_torch_ir(
            f,
            args,
            kwargs,
            dynamic_shapes,
            disable_constraint_solver=disable_constraint_solver,
            same_signature=same_signature,
            # Disable this flag, because we can instead rely on the mapping
            # dynamo_flat_name_to_original_fqn that comes from Dynamo.
            restore_fqn=False,
        )

    with torch.no_grad():
        so_path = torch._inductor.aot_compile(gm, args, kwargs, options=options)  # type: ignore[arg-type]

    return so_path


def aot_load(so_path: str, device: str) -> Callable:
    """
    Loads a shared library generated by aot_compile and returns a callable.

    Args:
        so_path: Path to the shared library.

        device: Device on which to run the model, e.g. "cpu", "cuda", "cuda:0", or "xpu".

    Returns:
        A callable that runs the compiled model on the given device.
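
    Example (a minimal sketch; assumes ``so_path`` was produced by a prior
    ``aot_compile`` call on the same machine)::

        fn = torch._export.aot_load(so_path, device="cpu")
        out = fn(torch.randn(8))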
"""
aot_compile_warning()
if device == "cpu":
runner = torch._C._aoti.AOTIModelContainerRunnerCpu(so_path, 1) # type: ignore[call-arg]
elif device == "cuda" or device.startswith("cuda:"):
runner = torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device) # type: ignore[assignment, call-arg]
elif device == "xpu" or device.startswith("xpu:"):
runner = torch._C._aoti.AOTIModelContainerRunnerXpu(so_path, 1, device) # type: ignore[assignment, call-arg]
else:
raise RuntimeError("Unsupported device " + device)
def optimized(*args, **kwargs):
call_spec = runner.get_call_spec() # type: ignore[attr-defined]
in_spec = pytree.treespec_loads(call_spec[0])
out_spec = pytree.treespec_loads(call_spec[1])
flat_inputs = pytree.tree_flatten((args, reorder_kwargs(kwargs, in_spec)))[0]
flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)]
flat_outputs = runner.run(flat_inputs) # type: ignore[attr-defined]
return pytree.tree_unflatten(flat_outputs, out_spec)
return optimized