1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661
|
# mypy: allow-untyped-defs
from __future__ import annotations
from typing import (
Any,
Callable,
Mapping,
Protocol,
runtime_checkable,
Sequence,
TYPE_CHECKING,
)
import torch
import torch.export as torch_export
from torch.utils import _pytree as pytree
if TYPE_CHECKING:
import inspect
# TODO(bowbao): Add diagnostics for IO adapters.
@runtime_checkable
class InputAdaptStep(Protocol):
    """Structural interface for a single stage of input adaptation.

    Input adaptation is an ordered chain of stages that turns the arguments a
    PyTorch model is called with into the argument layout expected by the
    exported ONNX model. Every stage receives the current positional and
    keyword arguments and hands back a new ``(args, kwargs)`` pair.

    Any exporter component that alters the model input signature expresses
    that alteration as an implementation of this protocol.
    """

    def apply(
        self,
        model_args: Sequence[Any],
        model_kwargs: Mapping[str, Any],
        model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
    ) -> tuple[Sequence[Any], Mapping[str, Any]]:
        """Transform the inputs and return the new ``(args, kwargs)`` pair.

        Args:
            model_args: The positional model inputs.
            model_kwargs: The keyword model inputs.
            model: The PyTorch model, when available.

        Returns:
            The transformed positional and keyword inputs.
        """
        ...
class InputAdapter:
    """Adapts PyTorch model inputs to the exported ONNX model inputs format.

    The adapter owns an ordered list of input adapt steps and runs them one
    after another over the user-supplied arguments.
    """

    def __init__(self, steps: list[InputAdaptStep] | None = None):
        """Create an adapter.

        Args:
            steps: Initial steps, applied in the given order. ``None`` means no
                steps. The list is copied, so later calls to ``append_step`` do
                not modify the list passed in by the caller.
        """
        # `steps or []` kept a reference to a non-empty caller list (so
        # `append_step` mutated it) but silently replaced an empty one. Always
        # copying makes ownership consistent in both cases.
        self._steps: list[InputAdaptStep] = [] if steps is None else list(steps)

    def append_step(self, step: InputAdaptStep) -> None:
        """Appends a step to the input adapt steps.

        Args:
            step: The step to append.
        """
        self._steps.append(step)

    def apply(
        self,
        *model_args,
        model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
        **model_kwargs,
    ) -> Sequence[int | float | bool | str | torch.Tensor | torch.dtype | None]:
        """Converts the PyTorch model inputs to exported ONNX model inputs format.

        Args:
            model_args: The PyTorch model inputs.
            model: The PyTorch model.
            model_kwargs: The PyTorch model keyword inputs.

        Returns:
            A sequence of tensors converted from PyTorch model inputs.

        Raises:
            AssertionError: If keyword arguments remain after every step ran.
        """
        args: Sequence[Any] = model_args
        kwargs: Mapping[str, Any] = model_kwargs
        for step in self._steps:
            args, kwargs = step.apply(args, kwargs, model=model)
        # Only positional inputs are returned, so any keyword argument left
        # here would be dropped silently.
        assert not kwargs, "Keyword arguments remain after applying all input adapt steps."
        return args
@runtime_checkable
class OutputAdaptStep(Protocol):
    """Structural interface for a single stage of output adaptation.

    Output adaptation is an ordered chain of stages that turns what a PyTorch
    model returns into the output layout produced by the exported ONNX model.
    Every stage receives the current outputs and hands back the transformed
    outputs.

    Any exporter component that alters the model output signature expresses
    that alteration as an implementation of this protocol.
    """

    def apply(
        self,
        model_outputs: Any,
        model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
    ) -> Any:
        """Transform the outputs and return the result.

        Args:
            model_outputs: The model outputs.
            model: The PyTorch model, when available.

        Returns:
            The transformed outputs.
        """
        ...
class OutputAdapter:
    """Adapts PyTorch model outputs to the exported ONNX model outputs format.

    The adapter owns an ordered list of output adapt steps and runs them one
    after another over the model outputs.
    """

    def __init__(self, steps: list[OutputAdaptStep] | None = None):
        """Create an adapter.

        Args:
            steps: Initial steps, applied in the given order. ``None`` means no
                steps. The list is copied, so later calls to ``append_step`` do
                not modify the list passed in by the caller.
        """
        # `steps or []` kept a reference to a non-empty caller list (so
        # `append_step` mutated it) but silently replaced an empty one. Always
        # copying makes ownership consistent in both cases.
        self._steps: list[OutputAdaptStep] = [] if steps is None else list(steps)

    def append_step(self, step: OutputAdaptStep) -> None:
        """Appends a step to the output format steps.

        Args:
            step: The step to append.
        """
        self._steps.append(step)

    def apply(
        self,
        model_outputs: Any,
        model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
    ) -> Sequence[torch.Tensor | int | float | bool | str]:
        """Converts the PyTorch model outputs to exported ONNX model outputs format.

        Args:
            model_outputs: The PyTorch model outputs.
            model: The PyTorch model.

        Returns:
            PyTorch model outputs in exported ONNX model outputs format.
        """
        # Each step consumes the previous step's result.
        for step in self._steps:
            model_outputs = step.apply(model_outputs, model=model)
        return model_outputs
# TODO: make_fx loses stack info https://github.com/pytorch/pytorch/issues/90276
# TODO(XuehaiPan): Dynamo does not support `dummy_leaf = object()` as a sentinel value in the frame.
class _DummyLeaf: # use a class instead.
pass
def _replace_list_with_tuple(spec: pytree.TreeSpec) -> pytree.TreeSpec:
    """Return a copy of ``spec`` in which every ``list`` node is a ``tuple`` node.

    A placeholder tree is rebuilt from ``spec``, its lists are converted to
    tuples recursively, and the structure of the converted tree is returned.

    Args:
        spec: The `TreeSpec` to convert.

    Returns:
        The `TreeSpec` of the converted placeholder tree.
    """

    def is_plain_list(node: Any) -> bool:
        # Exact type match: subclasses of `list` are deliberately not matched.
        return type(node) is list

    def to_tuple(node: Any) -> Any:
        if not is_plain_list(node):
            return node
        # Stop at nested lists so they are converted by the recursive call.
        return pytree.tree_map(to_tuple, tuple(node), is_leaf=is_plain_list)

    placeholder = _DummyLeaf()
    skeleton = pytree.tree_unflatten([placeholder] * spec.num_leaves, spec)
    return pytree.tree_structure(
        pytree.tree_map(to_tuple, skeleton, is_leaf=is_plain_list)
    )
def _open_top_level_sequence_if_single_element(
spec: pytree.TreeSpec,
) -> pytree.TreeSpec:
if spec.type in (tuple, list) and spec.num_children == 1:
return spec.children_specs[0]
return spec
def _assert_identical_pytree_spec(
spec1: pytree.TreeSpec, spec2: pytree.TreeSpec, error_message: str
) -> None:
"""Assert the two `TreeSpec` objects are identical.
Args:
spec1: The first `TreeSpec` object.
spec2: The second `TreeSpec` object.
error_message: The error message to raise if the two `TreeSpec` objects are not
identical.
Raises:
ValueError: If the two `TreeSpec` objects are not identical.
"""
# TODO(bowbao): Turn this check into diagnostic. Consider warning instead of error.
pass_if_any_checks: Sequence[Callable[[], bool]] = [
lambda: spec1 == spec2,
# FIXME: Bug in `dynamo.export`. Sometimes outputs returned in 'list' instead of 'tuple'.
lambda: _replace_list_with_tuple(spec1) == _replace_list_with_tuple(spec2),
# FIXME: Bug in `dynamo.export`. Sometimes single function return is wrapped in list.
lambda: _open_top_level_sequence_if_single_element(spec1) == spec2,
lambda: spec1 == _open_top_level_sequence_if_single_element(spec2),
]
if not any(check() for check in pass_if_any_checks):
raise ValueError(f"{error_message}\nExpect {spec1}.\nActual {spec2}.")
class BindInputStep(InputAdaptStep):
    """Input adapt step that binds the call arguments to the model signature."""

    def __init__(self, model_signature: inspect.Signature):
        """Store the signature that later calls are bound against.

        Args:
            model_signature: The signature of the model.
        """
        self._model_signature = model_signature

    def apply(
        self,
        model_args: Sequence[Any],
        model_kwargs: Mapping[str, Any],
        model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
    ) -> tuple[Sequence[Any], Mapping[str, Any]]:
        """Bind args and kwargs to the model signature and fill in defaults.

        The expectation is that, once bound, every input can be passed
        positionally. When that is not the case an error is raised.

        Args:
            model_args: The model args.
            model_kwargs: The model kwargs.
            model: The PyTorch model.

        Returns:
            A tuple of the model args and kwargs. The args are always empty; the
            kwargs are the bound arguments keyed by parameter name.

        Raises:
            ValueError: If there are keyword-only arguments left after binding args and
                kwargs to model signature.
        """
        binding = self._model_signature.bind(*model_args, **model_kwargs)
        binding.apply_defaults()

        # Keyword-only arguments are not handled. Whatever `binding.kwargs`
        # still holds after bind & apply_defaults cannot be passed
        # positionally, so a non-empty value is rejected.
        if not binding.kwargs:
            return (), binding.arguments
        raise ValueError("Keyword-only arguments are not supported.")
class MergeKwargsIntoArgsInputStep(InputAdaptStep):
    """Input adapt step that folds keyword inputs into the positional inputs."""

    def apply(
        self,
        model_args: Sequence[Any],
        model_kwargs: Mapping[str, Any],
        model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
    ) -> tuple[Sequence[Any], Mapping[str, Any]]:
        """Append the kwargs values, in mapping order, after the args.

        Args:
            model_args: The model args.
            model_kwargs: The model kwargs.
            model: The PyTorch model.

        Returns:
            A tuple of the model args and kwargs. kwargs is always empty.
        """
        merged_args = (*model_args, *model_kwargs.values())
        return merged_args, {}
class LiftParametersAndBuffersIntoArgsInputStep(InputAdaptStep):
    """Input adapt step that appends stored tensors to the positional inputs."""

    def __init__(self, inputs: tuple[torch.Tensor, ...]) -> None:
        """Store the tensors to append on every call.

        Args:
            inputs: The tensors appended after the user-supplied args.
        """
        self.inputs = inputs

    def apply(
        self,
        model_args: Sequence[Any],
        model_kwargs: Mapping[str, Any],
        model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
    ) -> tuple[Sequence[Any], Mapping[str, Any]]:
        """Append the stored inputs after the model args.

        Args:
            model_args: The model args.
            model_kwargs: The model kwargs.
            model: The PyTorch model.

        Returns:
            A tuple of the model args + appended inputs and kwargs. The kwargs
            are passed through unchanged.
        """
        extended_args = tuple(model_args) + tuple(self.inputs)
        return extended_args, model_kwargs
class ConvertComplexToRealRepresentationInputStep(InputAdaptStep):
    """Input adapt step that replaces complex tensors by their real representation.

    ONNX does not support complex dtype tensors. Complex tensors are therefore
    converted to real representation tensors, i.e. float dtype tensors with an
    extra dimension holding the real and imaginary parts of each complex number.
    """

    def apply(
        self,
        model_args: Sequence[Any],
        model_kwargs: Mapping[str, Any],
        model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
    ) -> tuple[Sequence[Any], Mapping[str, Any]]:
        """Convert complex tensors among the args to float tensors.

        Args:
            model_args: The model args.
            model_kwargs: The model kwargs.
            model: The PyTorch model.

        Returns:
            A tuple of the model args and kwargs. Only top-level args are
            inspected; the kwargs are passed through unchanged.
        """

        def as_real(arg: Any) -> Any:
            if isinstance(arg, torch.Tensor) and arg.is_complex():
                return torch.view_as_real(arg.resolve_conj())
            return arg

        return tuple(map(as_real, model_args)), model_kwargs
class RemoveNoneInputStep(InputAdaptStep):
    """Input adapt step that drops `None` entries from the arguments.

    This step assumes ``model_kwargs`` is empty. It also assumes ``model_args``
    is flattened, i.e. `None` inside nested collections is not detected.
    """

    def apply(
        self,
        model_args: Sequence[Any],
        model_kwargs: Mapping[str, Any],
        model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
    ) -> tuple[Sequence[Any], Mapping[str, Any]]:
        """Remove `None` from arguments.

        Args:
            model_args: The model args.
            model_kwargs: The model kwargs.
            model: The PyTorch model.

        Returns:
            A tuple of the model args without `None` entries and empty kwargs.

        Raises:
            AssertionError: If `model_kwargs` is not empty.
        """
        assert not model_kwargs
        kept_args = [arg for arg in model_args if arg is not None]
        return tuple(kept_args), {}
class RemoveNonTensorInputStep(InputAdaptStep):
    """Input adapt step that drops Python constant arguments.

    Dynamo does not support non-tensor input arguments
    (https://github.com/pytorch/pytorch/issues/99534). It does put such an input
    into the graph as an empty node, but nothing consumes that node. The
    concrete value is instead embedded into the graph as a constant arg of a
    target node. Meta suggests in this case that one should rewrite the model
    code to make the value a tensor if it is supposed to change at runtime. The
    feasibility of that suggestion might need further investigation.

    For example,

        def func(x, b=1.0):
            y = x + b
            z = y.relu()
            return (y, z)

        x = torch.randn(1, 1, 2, dtype=torch.float32)
        gm_fun, _ = dynamo.export(func, x, b=8.0, aten_graph=True, tracing_mode="real")

        # class GraphModule(torch.nn.Module):
        #     def forward(self, x, b):
        #         arg0: f32[1, 1, 2], arg1, = fx_pytree.tree_flatten_spec(([x, b], {}), self._in_spec)
        #         # File: path/to/pytorch/test_constant_input.py:5, code: y = x + b
        #         add_tensor: f32[1, 1, 2] = torch.ops.aten.add.Tensor(arg0, 8.0);  arg0 = None

        #         # File: path/to/pytorch/test_constant_input.py:6, code: z = y.relu()
        #         relu_default: f32[1, 1, 2] = torch.ops.aten.relu.default(add_tensor)
        #         return pytree.tree_unflatten([add_tensor, relu_default], self._out_spec)

    The empty torch.fx.Node input leads to a mismatched number of inputs
    compared with PyTorch, as it is ignored in the ONNX graph. Thus, the useless
    input is deleted here.
    """

    def apply(
        self,
        model_args: Sequence[Any],
        model_kwargs: Mapping[str, Any],
        model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
    ) -> tuple[Sequence[Any], Mapping[str, Any]]:
        """Remove `int`, `float`, `bool` and `str` values from the arguments.

        Args:
            model_args: The model args.
            model_kwargs: The model kwargs.
            model: The PyTorch model.

        Returns:
            A tuple of the remaining model args and empty kwargs.

        Raises:
            AssertionError: If `model_kwargs` is not empty.
        """
        assert not model_kwargs
        constant_types = (int, float, bool, str)
        kept_args = []
        for arg in model_args:
            if isinstance(arg, constant_types):
                continue
            kept_args.append(arg)
        return tuple(kept_args), {}
class FlattenInputWithTreeSpecValidationInputStep(InputAdaptStep):
    """Input adapt step that flattens nested inputs and checks their structure.

    ONNX can't represent collection types (e.g., dictionary, tuple of tuple of
    tensor, etc).

    The `TreeSpec` produced by the first `apply` call is stored on the
    instance. The `TreeSpec` produced by each later `apply` call is validated
    against it.
    """

    # Structure recorded by the first `apply` call; None until then.
    _spec: pytree.TreeSpec | None = None

    def apply(
        self,
        model_args: Sequence[Any],
        model_kwargs: Mapping[str, Any],
        model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
    ) -> tuple[Sequence[Any], Mapping[str, Any]]:
        """Flatten the model args and kwargs and validate their `TreeSpec`.

        Args:
            model_args: The model args.
            model_kwargs: The model kwargs.
            model: The PyTorch model.

        Returns:
            A tuple of the flattened model args and kwargs. The kwargs is empty, because
            they are flattened and merged into the args.

        Raises:
            ValueError: If the `TreeSpec` of the current inputs is not identical
                to the `TreeSpec` of the first inputs passed to this method.
        """
        flat_inputs, current_spec = pytree.tree_flatten((model_args, model_kwargs))
        recorded_spec = self._spec
        if recorded_spec is not None:
            _assert_identical_pytree_spec(
                recorded_spec,
                current_spec,
                error_message="Model inputs incompatible with the format that was exported. ",
            )
        else:
            self._spec = current_spec
        return flat_inputs, {}
class FlattenOutputStep(OutputAdaptStep):
    """Output adapt step that flattens nested outputs into a flat list of leaves.

    ONNX can't represent collection types (e.g., dictionary, tuple of tuple of
    tensor, etc).

    NOTE: Ideally ``FlattenOutputWithTreeSpecValidationOutputStep`` would be used, so
    that the `TreeSpec` can be validated for new model outputs. However, this is not
    possible currently because real PyTorch model outputs are never available during
    export. Only traced outputs may be available, but they are not an accurate
    reflection of the original PyTorch model outputs format as they are typically in
    their own unique format, depending on the tracing strategy.
    """

    def apply(
        self,
        model_outputs: Any,
        model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
    ) -> Sequence[Any]:
        """Flatten the model outputs.

        Args:
            model_outputs: The model outputs to flatten.
            model: The PyTorch model.

        Returns:
            The leaves of the model outputs, as returned by ``pytree.tree_leaves``.
        """
        leaves = pytree.tree_leaves(model_outputs)
        return leaves
class ConvertComplexToRealRepresentationOutputStep(OutputAdaptStep):
    """Output adapt step that replaces complex tensors by their real representation.

    ONNX does not support complex dtype tensors. Complex tensors are therefore
    converted to real representation tensors, i.e. float dtype tensors with an
    extra dimension holding the real and imaginary parts of each complex number.
    """

    def apply(
        self,
        model_outputs: Any,
        model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
    ) -> Any:
        """Convert complex tensors among the outputs to float tensors.

        Args:
            model_outputs: The model outputs; iterated one level deep.
            model: The PyTorch model.

        Returns:
            A list of the model outputs, with complex tensors converted.
        """
        converted = []
        for output in model_outputs:
            if isinstance(output, torch.Tensor) and torch.is_complex(output):
                output = torch.view_as_real(output.resolve_conj())
            converted.append(output)
        return converted
class FlattenOutputWithTreeSpecValidationOutputStep(OutputAdaptStep):
    """Same as ``FlattenOutputStep``, with additional `TreeSpec` validation.

    The `TreeSpec` produced by the first `apply` call is stored on the
    instance. The `TreeSpec` produced by each later `apply` call is validated
    against it.
    """

    # Structure recorded by the first `apply` call; None until then.
    _spec: pytree.TreeSpec | None = None

    def apply(
        self,
        model_outputs: Any,
        model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
    ) -> Sequence[Any]:
        """Flatten the model outputs and validate their `TreeSpec`.

        Args:
            model_outputs: The model outputs to flatten.
            model: The PyTorch model.

        Returns:
            flattened_outputs: The flattened model outputs.

        Raises:
            ValueError: If the `TreeSpec` of the current `model_outputs` is not
                identical to the `TreeSpec` of the first `model_outputs` that
                was passed to this method.
        """
        flat_outputs, current_spec = pytree.tree_flatten(model_outputs)
        recorded_spec = self._spec
        if recorded_spec is not None:
            _assert_identical_pytree_spec(
                recorded_spec,
                current_spec,
                error_message="Model outputs incompatible with the format that was exported. ",
            )
        else:
            self._spec = current_spec
        return flat_outputs
class PrependParamsBuffersConstantAotAutogradInputStep(InputAdaptStep):
    """Prepend model parameters, buffers and constants to the user input.

    :func:`torch.export.export` lifts model parameters, buffers and constants as model input, thus, they
    must be added to the user input before the model is executed.
    """

    def apply(
        self,
        model_args: Sequence[Any],
        model_kwargs: Mapping[str, Any],
        model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
    ) -> tuple[Sequence[Any], Mapping[str, Any]]:
        """Prepend parameters, buffers and lifted tensor constants to the args.

        Args:
            model_args: The model args.
            model_kwargs: The model kwargs.
            model: The PyTorch model with embedded parameters and buffers. Its
                ``graph_signature``, ``state_dict`` and ``constants`` are read.

        Returns:
            A tuple of the model args and kwargs. The kwargs are always empty;
            any kwargs values are appended after the args.
        """
        params = [
            model.state_dict[name]  # type: ignore[union-attr,index]
            for name in model.graph_signature.parameters  # type: ignore[union-attr]
        ]

        # Non-persistent buffers are looked up in `constants`; all other
        # buffers are looked up in `state_dict`.
        non_persistent = set(model.graph_signature.non_persistent_buffers)  # type: ignore[arg-type, union-attr]
        buffers = [
            model.constants[name]  # type: ignore[index, union-attr]
            if name in non_persistent
            else model.state_dict[name]  # type: ignore[union-attr,index]
            for name in model.graph_signature.buffers  # type: ignore[union-attr]
        ]

        constant_tensors = [
            model.constants[fqn]  # type: ignore[union-attr,index]
            for fqn in model.graph_signature.lifted_tensor_constants  # type: ignore[union-attr]
        ]

        # NOTE: calling convention is first params, then buffers, then lifted
        # tensor constants, then args as user supplied them.
        # See: torch/_functorch/aot_autograd.py#L1034
        combined_args = (*params, *buffers, *constant_tensors, *model_args)
        if not model_kwargs:
            return combined_args, {}
        return MergeKwargsIntoArgsInputStep().apply(
            combined_args, model_kwargs, model=model
        )
class PrependParamsAndBuffersAotAutogradOutputStep(OutputAdaptStep):
    """Prepend model's mutated buffers to the user output.

    :func:`torch.export.export` lifts model's mutated buffers as outputs, thus, they
    must be added to the user output after the model is executed.
    """

    def apply(
        self,
        model_outputs: Any,
        model: torch.nn.Module | Callable | torch_export.ExportedProgram | None = None,
    ) -> Sequence[Any]:
        """Prepend the mutated buffers to the model outputs.

        Args:
            model_outputs: The model outputs.
            model: The PyTorch model with mutated buffers. Must be a
                ``torch_export.ExportedProgram``.

        Returns:
            A tuple of the mutated buffers followed by the model outputs.

        Raises:
            AssertionError: If ``model`` is not a ``torch_export.ExportedProgram``.
        """
        assert isinstance(
            model, torch_export.ExportedProgram
        ), "'model' must be torch_export.ExportedProgram"

        mutated_buffers = []
        for name in model.graph_signature.buffers_to_mutate.values():
            # Prefer `state_dict`; fall back to `constants` for names it lacks.
            if name in model.state_dict:
                mutated_buffers.append(model.state_dict[name])
            else:
                mutated_buffers.append(model.constants[name])

        # NOTE: calling convention is first mutated buffers, then outputs args as model returned them.
        return (*mutated_buffers, *model_outputs)
|