# mypy: allow-untyped-defs
"""Dispatcher for AtenLib functions from onnx-script."""
from __future__ import annotations
import logging
import operator
import types
from typing import Any, Callable, Sequence, TYPE_CHECKING
import torch
import torch._ops
import torch.fx
from torch.onnx._internal.fx import (
diagnostics,
registration,
type_utils as fx_type_utils,
)
if TYPE_CHECKING:
import onnxscript # type: ignore[import]
from onnxscript.function_libs.torch_lib import ( # type: ignore[import]
graph_building as onnxscript_graph_building,
)
from torch.onnx import OnnxRegistry
def _find_opschema_matched_symbolic_function_diagnostic_message_formatter(
fn: Callable,
self,
node: torch.fx.Node,
default_and_custom_functions: list[registration.ONNXFunction],
*args,
**kwargs,
) -> str:
"""Format the diagnostic message for the nearest match warning."""
all_function_overload_names = ""
for symbolic_func in default_and_custom_functions:
overload_func = symbolic_func.onnx_function
all_function_overload_names += f"ONNX Node: {overload_func.name}[opset={overload_func.opset};is_custom={symbolic_func.is_custom}]. \n" # noqa: B950
return f"FX Node: {node.target}. \n" f"{all_function_overload_names}"
def _find_operator_overloads_in_onnx_registry_diagnostic_message_formatter(
fn: Callable,
self,
node: torch.fx.Node,
*args,
**kwargs,
) -> str:
"""Format the diagnostic message for the nearest match warning."""
return f"Searching operator overload: '{node.target}' in onnx registry...\n"
class OnnxFunctionDispatcher:
"""A dispatcher that finds the best ONNX Function for ATen/Custom operators.
    It uses the `torch.ops` name to look up candidate functions, falling back to
    the default overload if the exact overload is not registered. The best match
    is then chosen among all function overloads, with an exact match taking
    precedence over the nearest ones.

    Below is a breakdown of how the dispatch mechanism works:

    1. Use the torch.ops name to find the function:
        a. Check if the ATen overload exists in the registry.
        b. If not, check if the default overload exists in the registry.
    2. Find the nearest match among all overloaded functions:
        a. If the types match perfectly, select the function.
        b. Otherwise, select the nearest one with the highest matching score.
           Because dtypes and attributes may be annotated incorrectly, the
           nearest match is used to pick the best function once the aten name
           has been resolved.
    3. Tie-breaker: if multiple functions share the highest matching score,
       prefer custom functions over default ones, and among those prefer the
       one registered last.

    NOTE: A nearest match is not guaranteed to be correct; a warning message is
    logged whenever one is used.
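
    Example (a hedged usage sketch; ``registry``, ``ctx``, ``node``,
    ``onnx_args``, and ``onnx_kwargs`` are assumed to be created elsewhere):

    ```python
    dispatcher = OnnxFunctionDispatcher(onnx_registry=registry, diagnostic_context=ctx)
    # Returns an onnxscript.OnnxFunction or onnxscript.TracedOnnxFunction.
    onnx_function = dispatcher.dispatch(node, onnx_args, onnx_kwargs, ctx)
    ```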
"""
def __init__(
self,
onnx_registry: OnnxRegistry,
diagnostic_context: diagnostics.DiagnosticContext,
):
"""Initialize the ONNX Function dispatcher.
Args:
onnx_registry: The ONNX registry.
diagnostic_context: The diagnostic context to use for reporting errors.
"""
self.onnx_registry = onnx_registry
self.diagnostic_context = diagnostic_context
def dispatch(
self,
node: torch.fx.Node,
onnx_args: Sequence[
fx_type_utils.TensorLike | str | int | float | bool | list | complex | None
],
onnx_kwargs: dict[str, fx_type_utils.Argument],
diagnostic_context: diagnostics.DiagnosticContext,
) -> onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction:
"""Dispatches an ONNX function based on the given FX node, arguments, and keyword arguments.
Args:
node: The TorchFX node to dispatch the function for.
onnx_args: The arguments of the ONNX function.
onnx_kwargs: The keyword arguments of the ONNX function.
diagnostic_context: The diagnostic context to use for reporting errors.
Returns:
Either an `onnxscript.OnnxFunction` or `onnxscript.TracedOnnxFunction` instance based on the dispatch algorithm.
Raises:
RuntimeError: If there are no overloaded functions available for the given FX node.
"""
# If there are no overloaded functions available for the given FX node, raise an
# unsupported error
default_and_custom_functions = self.get_function_overloads(
node, diagnostic_context
)
        # If there are overloaded functions available, find the one that perfectly
        # or most nearly matches the given arguments and keyword arguments
return self._find_the_perfect_or_nearest_match_onnxfunction(
node,
default_and_custom_functions,
onnx_args,
onnx_kwargs,
diagnostic_context,
)
def _filter_or_keep_complex(
self,
node,
default_and_custom_functions: list[registration.ONNXFunction],
diagnostic_context: diagnostics.DiagnosticContext,
) -> list[registration.ONNXFunction]:
"""Filter the complex functions if the input has complex dtype."""
args_with_complex_dtype = [_is_arg_with_complex_dtype(arg) for arg in node.args]
if any(args_with_complex_dtype):
default_and_custom_functions = [
func for func in default_and_custom_functions if func.is_complex
]
# If we can't find the complex function group, raise error.
if not default_and_custom_functions:
op_full_name = self._get_aten_name(
node, diagnostic_context
).qualified_name()
diagnostic = diagnostics.UnsupportedFxNodeDiagnostic(
diagnostics.rules.no_symbolic_function_for_call_function,
diagnostics.levels.ERROR,
f"Cannot find any COMPLEX symbolic function for {op_full_name}, "
f"which should be registered under {node.target}.",
unsupported_fx_node=node,
)
diagnostic_context.log(diagnostic)
raise diagnostics.RuntimeErrorWithDiagnostic(diagnostic)
else:
default_and_custom_functions = [
func for func in default_and_custom_functions if not func.is_complex
]
            # If no non-complex function is found, raise error.
if not default_and_custom_functions:
op_full_name = self._get_aten_name(
node, diagnostic_context
).qualified_name()
diagnostic = diagnostics.UnsupportedFxNodeDiagnostic(
diagnostics.rules.no_symbolic_function_for_call_function,
diagnostics.levels.ERROR,
f"Can ONLY find COMPLEX symbolic function for {op_full_name}, "
f"which should be registered under {node.target}.",
unsupported_fx_node=node,
)
diagnostic_context.log(diagnostic)
raise diagnostics.RuntimeErrorWithDiagnostic(diagnostic)
return default_and_custom_functions
@diagnostics.diagnose_call(
diagnostics.rules.find_opschema_matched_symbolic_function,
        diagnostic_message_formatter=_find_opschema_matched_symbolic_function_diagnostic_message_formatter,
)
def _find_the_perfect_or_nearest_match_onnxfunction(
self,
node: torch.fx.Node, # this is used in diagnostic_message_formatter
default_and_custom_functions: list[registration.ONNXFunction],
onnx_args: Sequence[
fx_type_utils.TensorLike | str | int | float | bool | list | complex | None
],
onnx_kwargs: dict[str, fx_type_utils.Argument],
diagnostic_context: diagnostics.DiagnosticContext,
):
"""Find the perfect/nearest matched OnnxFunction for the given FX node, arguments, and keyword arguments.
Args:
default_and_custom_functions: The list includes overloaded functions, with
custom ones appearing after the default ones.
            onnx_args: Arguments organized following PyTorch's input convention.
            onnx_kwargs: Keyword arguments organized following PyTorch's input convention.
diagnostic_context: The diagnostic context to use for reporting errors.
Returns:
Either an `onnxscript.OnnxFunction` or `onnxscript.TracedOnnxFunction` instance based on the dispatch algorithm.
Raises:
RuntimeError: If there are no overloaded functions available for the given FX node.
"""
overload_match_ranking: dict[registration.ONNXFunction, int | None] = {}
diagnostic = diagnostic_context.inflight_diagnostic()
# Iterate the overloaded functions in reverse order to prioritize the custom ones
# over the default ones, and find the perfect match.
for symbolic_function in reversed(default_and_custom_functions):
function_opschema = _OnnxSchemaChecker(symbolic_function.onnx_function)
# NOTE: 1. If the perfect match is found, return the function
if function_opschema.perfect_match_inputs(
diagnostic, onnx_args, onnx_kwargs
):
return symbolic_function.onnx_function
# Record the match score for the nearest match if it's not the perfect match
overload_match_ranking[symbolic_function] = function_opschema.match_score
        # NOTE: 2. If there is no perfect match, find the nearest match among the candidates.
# If there is no nearest match, raise an error
overload_match_ranking = {
k: v for k, v in overload_match_ranking.items() if v is not None
}
if not overload_match_ranking:
# If there are no overloaded functions available for the given FX node, raise an
# unsupported error
op_full_name = self._get_aten_name(
node, diagnostic_context
).qualified_name()
diagnostic = diagnostics.UnsupportedFxNodeDiagnostic(
diagnostics.rules.no_symbolic_function_for_call_function,
diagnostics.levels.ERROR,
f"Cannot find any perfect/nearest match of symbolic function for {op_full_name},"
f"which should be registered under {node.target}.",
unsupported_fx_node=node,
)
diagnostic_context.log(diagnostic)
raise diagnostics.RuntimeErrorWithDiagnostic(diagnostic)
diagnostic.warning(
"### Exact match is not found!\n"
"Cannot find a perfect match of symbolic overload, "
"a nearest match is found. Please check the ONNX output carefully. \n",
)
diagnostic.level = diagnostics.levels.WARNING
        # NOTE: 3. Tie-breaker: if there are multiple nearest matches, choose the
        # custom one first. If there are multiple custom ones, choose the one that
        # was added last to the list.
symbolic_function_list: list[registration.ONNXFunction] = sorted(
overload_match_ranking,
key=lambda k: (
overload_match_ranking[k],
k.is_custom,
default_and_custom_functions.index(k),
),
reverse=True,
)
return symbolic_function_list[0].onnx_function
def _get_aten_name(
self, node: torch.fx.Node, diagnostic_context: diagnostics.DiagnosticContext
) -> registration.OpName:
"""Get the OpName from the target.
Args:
node: The TorchFX node to get the aten name for.
diagnostic_context: The diagnostic context to use for reporting errors.
Returns:
The internal op name within dataclass: registration.OpName.
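
        For example (illustrative), an ``aten::add.Tensor`` overload maps to
        ``OpName(namespace="aten", op_name="add", overload="Tensor")``.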
"""
if node.target == operator.getitem:
return registration.OpName.from_name_parts(
namespace="aten", op_name="getitem"
)
if isinstance(node.target, torch._ops.OpOverloadPacket):
# aten::sym_size is the only OverloadPacket that we support.
# schema: aten::sym_size(Tensor self, int dim) -> Tensor
if node.target != torch.ops.aten.sym_size:
diagnostic = diagnostics.UnsupportedFxNodeDiagnostic(
diagnostics.rules.no_symbolic_function_for_call_function,
diagnostics.levels.ERROR,
f"Unsupported OverloadPacket: {node.target}, aten.sym_size is the only allowed OverloadPacket!",
unsupported_fx_node=node,
)
diagnostic_context.log(diagnostic)
raise diagnostics.RuntimeErrorWithDiagnostic(diagnostic)
            # TODO(titaiwang): aten::sym_size has an overload, but the fx graph uses
            # the OverloadPacket for some reason.
# https://github.com/pytorch/pytorch/issues/97201
aten_op_default = node.target.default
return registration.OpName.from_op_overload(op_overload=aten_op_default) # type: ignore[no-any-return]
if isinstance(node.target, types.BuiltinFunctionType):
            # Make sure the builtin op consumes only int/float/SymInt/SymFloat arguments.
for node_arg in node.args:
if (not isinstance(node_arg, (torch.fx.Node, int, float))) or (
isinstance(node_arg, torch.fx.Node)
and not fx_type_utils.is_torch_symbolic_type(node_arg.meta["val"])
):
diagnostic = diagnostics.UnsupportedFxNodeDiagnostic(
diagnostics.rules.no_symbolic_function_for_call_function,
diagnostics.levels.ERROR,
f"Unsupported node arg: {node_arg} (type {type(node_arg)}) with builtin function: {node.target},"
" only int/float/SymInt/SymFloat is supported with built-in ops!",
unsupported_fx_node=node,
)
diagnostic_context.log(diagnostic)
raise diagnostics.RuntimeErrorWithDiagnostic(diagnostic)
return registration.OpName.from_builtin_function(node.target)
if isinstance(node.target, torch._ops.OpOverload):
return registration.OpName.from_op_overload(op_overload=node.target)
# Unexpected target, raise error.
diagnostic = diagnostics.UnsupportedFxNodeDiagnostic(
diagnostics.rules.no_symbolic_function_for_call_function,
diagnostics.levels.ERROR,
f"Unknown call_function target: {node.target}",
unsupported_fx_node=node,
)
diagnostic_context.log(diagnostic)
raise diagnostics.RuntimeErrorWithDiagnostic(diagnostic)
@diagnostics.diagnose_call(
diagnostics.rules.find_operator_overloads_in_onnx_registry,
        diagnostic_message_formatter=_find_operator_overloads_in_onnx_registry_diagnostic_message_formatter,
)
def get_function_overloads(
self,
node: torch.fx.Node,
diagnostic_context: diagnostics.DiagnosticContext,
) -> list[registration.ONNXFunction]:
"""Get the function overloads from the registry.
Args:
node: The node to get the function overloads for.
diagnostic_context: The diagnostic context to use for reporting errors.
Returns:
The list contains ONNXFunctions, starting with the default ones and
followed by any custom ones.
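
        For example (an illustrative sketch of the lookup order; ``registry`` is
        assumed to be an ``OnnxRegistry``), an unregistered exact overload falls
        back to the default overload:

        ```python
        registry.get_op_functions(namespace="aten", op_name="add", overload="Tensor")
        registry.get_op_functions(namespace="aten", op_name="add", overload=None)
        ```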
"""
internal_opname: registration.OpName = self._get_aten_name(
node=node, diagnostic_context=diagnostic_context
)
        # If the ATen/Custom operator is not registered, the group will be None,
        # and the error is raised in the next step.
function_group: list[registration.ONNXFunction] | None = None
function_group = self.onnx_registry.get_op_functions(
namespace=internal_opname.namespace,
op_name=internal_opname.op_name,
overload=internal_opname.overload,
)
# NOTE: Fall back to default overload if the ONNX registry doesn't have the overload.
if function_group is None:
function_group = self.onnx_registry.get_op_functions(
namespace=internal_opname.namespace,
op_name=internal_opname.op_name,
overload=None,
)
if function_group is not None:
diagnostic = diagnostic_context.inflight_diagnostic()
diagnostic.warning(
"### The operator overload is not found in onnx registry!\n"
"Cannot find the operator overload in onnx registry, but "
"the default overload is found. Please check the ONNX output carefully. \n",
)
diagnostic.level = diagnostics.levels.WARNING
if function_group is not None:
# NOTE: If the input has complex dtype, we will only dispatch to the complex functions.
function_group = self._filter_or_keep_complex(
node, function_group, diagnostic_context
)
return function_group # type: ignore[return-value]
op_full_name = internal_opname.qualified_name()
diagnostic = diagnostics.UnsupportedFxNodeDiagnostic(
diagnostics.rules.no_symbolic_function_for_call_function,
diagnostics.levels.ERROR,
f"Cannot find symbolic function for {op_full_name}, "
f"which should be registered under {node.target}.",
unsupported_fx_node=node,
)
diagnostic_context.log(diagnostic)
raise diagnostics.RuntimeErrorWithDiagnostic(diagnostic)
class _OnnxSchemaChecker:
"""
The OnnxSchemaChecker class is a checker for ONNX OpSchema and param schema.
It provides methods to check for input compatibility based on the OpSchema. It also
provides a matching score to indicate how well the OpSchema matches the input and
    kwargs types. A function is evaluated as a perfect match, eligible for
    nearest match, or no match.

    Here are some common examples, by category:
1. [NOTE: Perfect match]: The number of inputs and attributes are exactly the same as
the OpSchema. The types of inputs and attributes are exactly the same as the
OpSchema.
```python
inputs = (Tensor[2, 3], Tensor[2, 3])
attributes = {"alpha": 1.0}
@torch_op("aten::op")
def aten_op(self: TReal, other: TReal, alpha: float = 1) -> TReal: ...
```
Result: Perfect match.
2. [NOTE: Optional input]: The dispatcher recognizes optional inputs. However,
the input can't be ignored. None must be provided.
```python
inputs = (Tensor([2, 3]), None)
attributes = {}
aten_op(X: TTensor, Y: Optional[INT64]):
...
```
Result: Perfect match.
Real example: `aten::convolution`.
    3. [NOTE: Different attributes]: If an attribute is provided with a value, it
        must match an attribute in the function signature.
```python
inputs = (Tensor([2, 3]),)
attributes = {"a":1, "b":2}
aten_op(X: TTensor, a: int):
...
```
Result: No match.
Real example: `aten::div` vs `aten::div.Tensor_mode`.
    4. [NOTE: Default attributes]: A default attribute value is filled into the
        inputs/attributes when the caller omits it.
```python
inputs = (Tensor([2, 3]),)
attributes = {}
aten_op(X: TTensor, a: int = 3):
...
```
Result: Perfect match.
Real example: `aten::clone`
    5. [NOTE: Ignore attribute with None value]: Attributes with a None value are
        ignored during matching.
```python
inputs = (Tensor([2, 3]),)
attributes = {"a": None}
aten_op(X: TTensor):
...
```
Result: Perfect match.
```python
inputs = (Tensor([2, 3]),)
attributes = {"a": None}
aten_op(X: TTensor, a: int = 3):
...
```
Result: Nearest match eligible.
Real example: `aten::div` vs `aten::div.Tensor_mode`.
Attributes:
onnxfunction: The OnnxFunction.
param_schema: The parameter schema defined in the OnnxFunction.
op_schema: The ONNX OpSchema.
type_constraints: The type constraints defined in the OpSchema.
attributes: The attributes defined in the OpSchema.
        _matching_score: The matching score of the OnnxSchemaChecker.
"""
def __init__(
self,
onnxfunction: onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction,
):
"""Initialize the OnnxSchemaChecker .
Args:
onnxfunction: The OnnxFunction.
"""
self.onnxfunction = onnxfunction
self.param_schema = self.onnxfunction.param_schemas()
op_schema = self.onnxfunction.op_schema
# Both `OnnxFunction` and `TracedOnnxFunction` never return None for `op_schema`.
# However their base class would. Hence return type is annotated as Optional[OpSchema].
assert op_schema is not None
self.op_schema = op_schema
self.type_constraints = {
# "T": {"tensor(int64)"}
constraint.type_param_str: set(constraint.allowed_type_strs)
for constraint in self.op_schema.type_constraints
}
self.attributes = self.op_schema.attributes
self._matching_score: int | None = None
@property
def match_score(self) -> int | None:
"""The matching score of the OnnxSchemaChecker .
If this remains None, it means the matching score has not been calculated,
and it's not a nearest match candidate.
Returns:
            The matching score of the OnnxSchemaChecker.
"""
return self._matching_score
def perfect_match_inputs(
self,
diagnostic: diagnostics.Diagnostic,
args: Sequence[
fx_type_utils.TensorLike | str | int | float | bool | list | complex | None
],
kwargs: dict[str, fx_type_utils.Argument],
) -> bool:
"""Check if the inputs perfectly match the OpSchema requirements.
The definition of perfect match is that the input types are all in the type
constraints and the number of inputs matches the number of inputs in the
OpSchema.
Checking steps:
1. The function signature matches the inputs number, and attribute names.
2. The input/attribute types are all in the type constraints.
        A function must at least pass the first step to be eligible for nearest
        matching.
Args:
diagnostic: The diagnostic to use for logging detailed info.
            args: The input arguments organized following PyTorch's input convention.
            kwargs: The input keyword arguments organized following PyTorch's input convention.
Returns:
True if the inputs match the requirements, False otherwise.
"""
# NOTE: OnnxFunction does not have the same function signature as the original
# PyTorch operator. We need to separate the input/attributes from the arguments.
(
function_inputs,
function_attributes,
) = self._separate_input_attributes_from_arguments(
self.param_schema,
args,
kwargs,
fill_defaults=True, # fill defaults for optional arguments to match
)
with diagnostic.log_section(logging.INFO, "Checking perfect match..."):
diagnostic.info(
"%s",
diagnostics.LazyString(diagnostics.format_argument, self.onnxfunction),
)
            # NOTE: 1. Check whether the input count and attribute names match the
            # OpSchema. If not, the function is eligible for neither a perfect
            # match nor a nearest match.
            # We use is_perfect_match to postpone the return value to the end
            # of the function, as we want to log all the mismatch info.
is_perfect_match = True
if len(function_inputs) != len(self.op_schema.inputs):
with diagnostic.log_section(
logging.INFO, "Failed: input number mismatch!"
):
diagnostic.info(
"Actual %d vs expected %d",
len(function_inputs),
len(self.op_schema.inputs),
)
diagnostic.info("The function is not a nearest match candidate.")
is_perfect_match = False
if set(function_attributes) != set(self.attributes):
with diagnostic.log_section(
logging.INFO, "Failed: attribute mismatch!"
):
diagnostic.info(
"%s",
diagnostics.LazyString(
lambda: f"Actual {set(function_attributes)} vs expected {set(self.attributes)}",
),
)
diagnostic.info("The function is not a nearest match candidate.")
is_perfect_match = False
# If it's already not a perfect match, we can return False directly. Further
# checking is only for the functions that are eligible for nearest match.
if not is_perfect_match:
return False
# NOTE: 2. The dtypes of inputs and attributes should be in the
# type constraints of the OpSchema. If they are not, we know the function is not
# eligible to be a perfect match, but can be a nearest match candidate.
for schema_input, torch_input in zip(
self.op_schema.inputs, function_inputs
):
torch_input_compatible_types = _find_onnx_data_type(torch_input)
allowed_types = self.type_constraints[schema_input.type_str]
if not allowed_types.intersection(
torch_input_compatible_types
) and not any(
fx_type_utils.is_optional_onnx_dtype_str(onnx_type_str)
for onnx_type_str in allowed_types
):
# If torch_input_compatible_types isn't in allowed_types
# of this input defined in the OpSchema, we know the function
# and the input are not compatible
with diagnostic.log_section(
logging.INFO,
"Failed: input type mismatch for input '%s'!",
schema_input.name,
):
diagnostic.info(
"Actual %s vs\nExpected %s",
torch_input_compatible_types,
allowed_types,
)
is_perfect_match = False
for attribute_name, attribute in function_attributes.items():
if not self._match_onnx_attribute_type(attribute_name, attribute):
# If the attribute type of the OpSchema and the attribute type don't match,
# we know the function and the input are not compatible
with diagnostic.log_section(
logging.INFO,
"Failed: attribute '%s' type mismatch!",
attribute_name,
):
diagnostic.info(
"Actual %s vs\nExpected %s",
type(attribute),
self.attributes[attribute_name].type,
)
is_perfect_match = False
# NOTE: This is still a candidate for nearest match, as it only mismatches attributes on dtype.
self._record_matching_score(function_inputs, function_attributes)
diagnostic.info("match score: %d", self.match_score)
return is_perfect_match
def _match_onnx_attribute_type(
self,
attribute_name: str,
attribute: fx_type_utils.Argument | onnxscript_graph_building.TorchScriptTensor,
is_sequence: bool = False,
) -> bool:
if isinstance(attribute, (int, float, bool, str)):
attribute_onnx_type = fx_type_utils.from_python_type_to_onnx_attribute_type(
type(attribute), is_sequence=is_sequence
)
if attribute_onnx_type != self.attributes[attribute_name].type:
return False
        # A non-empty list/tuple is matched by its first element's type; an empty
        # list falls through to the else branch and is treated as a mismatch,
        # since its element type is unknown.
elif isinstance(attribute, (list, tuple)) and attribute:
return self._match_onnx_attribute_type(
attribute_name, attribute[0], is_sequence=True
)
else:
# NOTE: Unrecognized attribute type
return False
return True
def _record_matching_score(
self,
inputs: Sequence[
fx_type_utils.TensorLike | str | int | float | bool | list | complex | None
],
attributes: dict[str, fx_type_utils.Argument],
):
"""Calculate the inputs matching score of the OpSchema requirements to find the nearest match.
Only the functions which have the same number of inputs and attributes as the
OpSchema are eligible to be a nearest match candidate. Thus, we don't need to
check the length of inputs and attributes here, and only check the types of
inputs and attributes.
        How the matching score is calculated:
            score += 1 for each input/attribute type that is within the type constraints.
            score -= 1 for each attribute whose type mismatches the schema.

        Limitations:
            None/NoneType/[] could result in zero matches and identical scores
            across overloads, which will be recorded in SARIF.

        Args:
            inputs: The input arguments.
            attributes: The input keyword arguments.
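
        Example (an illustrative walkthrough, not taken from a real schema):

        ```python
        # Schema: op(X: T) with T in {"tensor(float)"} and an int attribute "a".
        # inputs = (float_tensor,), attributes = {"a": "not_an_int"}
        # score = +1 (input type matches) - 1 (attribute type mismatches) = 0
        ```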
"""
self._matching_score = 0
        # Functions with a different number of inputs naturally score lower than
        # those with a matching number, since zip() stops at the shorter sequence.
for schema_input, torch_input in zip(self.op_schema.inputs, inputs):
torch_input_compatible_types = _find_onnx_data_type(torch_input)
allowed_types = self.type_constraints[schema_input.type_str]
if allowed_types.intersection(torch_input_compatible_types):
# If torch_input_compatible_types is in allowed_types
# of this input defined in the OpSchema, we know the function
# and the input are compatible
self._matching_score += 1
        # NOTE: A penalty is applied for each attribute whose type mismatches the schema.
for attribute_name, attribute_proto in self.attributes.items():
attribute = attributes[attribute_name]
attribute_onnx_type = fx_type_utils.from_python_type_to_onnx_attribute_type(
type(attribute)
)
if attribute_onnx_type != attribute_proto.type:
# If the attribute type of the OpSchema and the attribute type don't match,
# we know the function and the input are not compatible
self._matching_score -= 1
    # NOTE: Adapted from an onnxscript internal function. The logic is duplicated
    # here because importing a private API directly would be less robust.
def _separate_input_attributes_from_arguments(
self,
param_schemas: Sequence[onnxscript.values.ParamSchema],
args: Sequence[
fx_type_utils.TensorLike | str | int | float | bool | list | complex | None
],
kwargs: dict[str, fx_type_utils.Argument],
fill_defaults: bool = True,
) -> tuple[list[Any], dict[str, Any]]:
"""Separate Python args and kwargs into ONNX inputs and attributes.
        Extra kwargs are ignored if their values are None. For example, if the
OpSchema has an attribute "rounding_mode" and the caller provides
"rounding_mode=None", the attribute "rounding_mode" will not be included
in the returned attributes when the OnnxFunction signature doesn't have
"rounding_mode" as an attribute.
Args:
param_schemas: The parameter schemas of an Op or a OnnxFunction.
args: The Python positional arguments supplied by the caller.
kwargs: The Python keyword arguments supplied by the caller.
fill_defaults: Whether to fill the default values for attributes.
Returns:
A tuple of two elements:
- A list of ONNX inputs.
                - A dictionary of ONNX attribute names and values.
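
        Example (an illustrative sketch with a hypothetical schema
        ``op(self, other, *, alpha: float = 1.0)``):

        ```python
        # args=(x, y), kwargs={"alpha": 2.0} -> ([x, y], {"alpha": 2.0})
        # args=(x, y), kwargs={}             -> ([x, y], {"alpha": 1.0})  # default filled
        ```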
"""
# args, kwargs and param_schemas should be all in order
# user may not specify all inputs or attributes
import onnx
onnx_inputs: list[Any] = []
onnx_attributes: dict[str, Any] = {}
# NOTE: We need to copy kwargs because we will mutate it
copy_kwargs = kwargs.copy()
for i, param in enumerate(param_schemas):
if param.is_variadic_input:
# Exhaust all remaining args
onnx_inputs.extend(args[i:])
args = []
continue
if i < len(args):
if param.is_input:
onnx_inputs.append(args[i])
else:
onnx_attributes[param.name] = args[i]
elif param.name in copy_kwargs:
if param.is_input:
# Move the input from kwargs to inputs
onnx_inputs.append(copy_kwargs[param.name])
copy_kwargs.pop(param.name)
else:
onnx_attributes[param.name] = copy_kwargs[param.name]
elif (
param.is_attribute
and self.attributes[param.name].default_value.type
!= onnx.AttributeProto.UNDEFINED # type: ignore[attr-defined]
):
# User did not provide the attribute
if fill_defaults:
onnx_attributes[param.name] = param.default
# optional input
elif param.is_input:
if fill_defaults:
onnx_inputs.append(None)
        # NOTE: Pick up extra kwargs whose values are not None. None is not
        # expected as an attribute value in torchlib.
for k, v in copy_kwargs.items():
if k not in onnx_attributes and v is not None:
onnx_attributes[k] = v
return onnx_inputs, onnx_attributes
def _is_arg_with_complex_dtype(arg: fx_type_utils.Argument) -> bool:
"""Check if the node has complex dtype recursively."""
if (
isinstance(arg, torch.fx.Node)
and "val" in arg.meta
and isinstance(arg.meta["val"], torch.Tensor)
and torch.is_complex(arg.meta["val"])
):
return True
    elif isinstance(arg, list):
        # A list argument is complex if any of its items has a complex dtype.
        return any(_is_arg_with_complex_dtype(item) for item in arg)
return False
def _find_onnx_data_type(
torch_input: fx_type_utils.TensorLike
| str
| int
| float
| bool
| list
| tuple
| complex
| None,
) -> set[str]:
"""Convert inputs data type from torch acceptable dtype to the compatible onnx dtype string."""
if (
isinstance(torch_input, fx_type_utils.TensorLike)
and torch_input.dtype is not None
):
return fx_type_utils.from_torch_dtype_to_onnx_dtype_str(torch_input.dtype)
if isinstance(torch_input, (int, float, bool, str, complex)):
return fx_type_utils.from_torch_dtype_to_onnx_dtype_str(type(torch_input))
if isinstance(torch_input, (list, tuple)) and torch_input: # [Tensor, Tensor]
the_first_non_none_item = next(
(item for item in torch_input if item is not None), None
)
set_dtype = _find_onnx_data_type(the_first_non_none_item)
        if any(isinstance(item, fx_type_utils.TensorLike) for item in torch_input):
# NOTE: Any Tensor involved in a list would make it a seq(tensor(onnx_type))
return {f"seq({dtype})" for dtype in set_dtype}
else:
# constant list of non-tensor type
return set_dtype
if (
torch_input is None
or (
isinstance(torch_input, fx_type_utils.TensorLike)
and torch_input.dtype is None
)
or (isinstance(torch_input, (list, tuple)) and not torch_input)
):
        # NOTE: None, dtype-less tensors, and empty lists are edge cases; we return
        # an empty set to relax the type check. seq(tensor) also lands here, since
        # it is not supported in torchscript and its dtype is None in this case.
return set()
raise RuntimeError(f"Unknown input type from input: {torch_input}")