File: onnxfunction_dispatcher.py

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (874 lines) | stat: -rw-r--r-- 37,770 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
# mypy: allow-untyped-defs
"""Dispatcher for AtenLib functions from onnx-script."""

from __future__ import annotations

import logging
import operator
import types
from typing import Any, Callable, Sequence, TYPE_CHECKING

import torch
import torch._ops
import torch.fx
from torch.onnx._internal.fx import (
    diagnostics,
    registration,
    type_utils as fx_type_utils,
)


if TYPE_CHECKING:
    import onnxscript  # type: ignore[import]
    from onnxscript.function_libs.torch_lib import (  # type: ignore[import]
        graph_building as onnxscript_graph_building,
    )

    from torch.onnx import OnnxRegistry


def _find_opschema_matched_symbolic_function_disagnostic_message_formatter(
    fn: Callable,
    self,
    node: torch.fx.Node,
    default_and_custom_functions: list[registration.ONNXFunction],
    *args,
    **kwargs,
) -> str:
    """Build the diagnostic text listing every candidate ONNX overload of a node.

    The parameter list mirrors the decorated matching method, preceded by the
    decorated callable itself.

    Args:
        fn: The decorated callable (unused).
        self: The object the decorated method is called on (unused).
        node: The FX node being dispatched; only its ``target`` is printed.
        default_and_custom_functions: The candidate overloads to enumerate.
        *args: Remaining positional arguments (unused).
        **kwargs: Remaining keyword arguments (unused).

    Returns:
        A header line naming the FX target, followed by one line per candidate
        giving its name, opset and whether it is a custom function.
    """
    overload_lines = [
        f"ONNX Node: {candidate.onnx_function.name}[opset={candidate.onnx_function.opset};is_custom={candidate.is_custom}]. \n"  # noqa: B950
        for candidate in default_and_custom_functions
    ]
    header = f"FX Node: {node.target}. \n"
    return "".join([header, *overload_lines])


def _find_operator_overloads_in_onnx_registry_disagnostic_message_formatter(
    fn: Callable,
    self,
    node: torch.fx.Node,
    *args,
    **kwargs,
) -> str:
    """Build the diagnostic text announcing a registry lookup for a node's target.

    Args:
        fn: The decorated callable (unused).
        self: The object the decorated method is called on (unused).
        node: The FX node whose ``target`` is being searched for.
        *args: Remaining positional arguments (unused).
        **kwargs: Remaining keyword arguments (unused).

    Returns:
        A single line naming the operator overload being searched.
    """
    template = "Searching operator overload: '{}' in onnx registry...\n"
    return template.format(node.target)


class OnnxFunctionDispatcher:
    """A dispatcher that finds the best ONNX Function for ATen/Custom operators.

    It uses the `torch.ops` name to find the function. If not found, it falls back to default.
    Otherwise, the best match is found among all function overloads. An exact match has
    higher precedence over the closest ones.

    Below is a breakdown on how the dispatch mechanism works:

    1. Use the torch.ops name to find the function:
        a. Check if the ATen overload exists in the registry.
        b. If not, check if the default overload exists in the registry.

    2. Find the nearest match among all overloaded functions:
        a. If the types match perfectly, select the function.
        b. Otherwise, find the nearest one with the highest matching score. Because of
            the potential wrongly annotated dtypes and attributes matching, we use
            nearest match to find the best function once the aten name is targeted.

    3. Tie-breaker: If several nearest matches share the highest matching score, a
        custom function is preferred over a default one; among those, the one that
        appears last in the overload list is selected.

    NOTE: The nearest match `doesn't guarantee` a correct match, and a warning message is logged.
    """

    def __init__(
        self,
        onnx_registry: OnnxRegistry,
        diagnostic_context: diagnostics.DiagnosticContext,
    ):
        """Store the collaborators used by the dispatcher.

        Args:
            onnx_registry: The ONNX registry queried for function overloads.
            diagnostic_context: The diagnostic context kept on the instance.
        """
        # The two attributes are independent; nothing else happens at construction.
        self.diagnostic_context = diagnostic_context
        self.onnx_registry = onnx_registry

    def dispatch(
        self,
        node: torch.fx.Node,
        onnx_args: Sequence[
            fx_type_utils.TensorLike | str | int | float | bool | list | complex | None
        ],
        onnx_kwargs: dict[str, fx_type_utils.Argument],
        diagnostic_context: diagnostics.DiagnosticContext,
    ) -> onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction:
        """Pick the ONNX function to use for an FX node and its arguments.

        The work is done in two stages: first the overloads registered for the
        node's target are collected, then the perfect (or, failing that, nearest)
        match for the supplied arguments is selected among them.

        Args:
            node: The TorchFX node to dispatch the function for.
            onnx_args: The arguments of the ONNX function.
            onnx_kwargs: The keyword arguments of the ONNX function.
            diagnostic_context: The diagnostic context to use for reporting errors.

        Returns:
            Either an `onnxscript.OnnxFunction` or `onnxscript.TracedOnnxFunction`
            instance based on the dispatch algorithm.

        Raises:
            RuntimeError: If there are no overloaded functions available for the
                given FX node.
        """
        # Stage 1: the lookup itself raises an unsupported error when the registry
        # has nothing for this node.
        candidate_functions = self.get_function_overloads(node, diagnostic_context)

        # Stage 2: choose the perfect match, or else the nearest one, for the
        # given arguments and keyword arguments.
        return self._find_the_perfect_or_nearest_match_onnxfunction(
            node, candidate_functions, onnx_args, onnx_kwargs, diagnostic_context
        )

    def _filter_or_keep_complex(
        self,
        node,
        default_and_custom_functions: list[registration.ONNXFunction],
        diagnostic_context: diagnostics.DiagnosticContext,
    ) -> list[registration.ONNXFunction]:
        """Keep only the overloads whose complex-ness agrees with the node's inputs.

        When any positional argument of ``node`` is flagged by
        ``_is_arg_with_complex_dtype``, only functions with ``is_complex`` set are
        kept; otherwise only the functions without it are kept.

        Args:
            node: The FX node whose ``args`` are inspected.
            default_and_custom_functions: The overloads to filter.
            diagnostic_context: The diagnostic context to use for reporting errors.

        Returns:
            The overloads that survive the filter (never empty).

        Raises:
            diagnostics.RuntimeErrorWithDiagnostic: If no overload survives.
        """
        complex_flags = list(map(_is_arg_with_complex_dtype, node.args))
        wants_complex = any(complex_flags)

        if wants_complex:
            kept_functions = [
                candidate
                for candidate in default_and_custom_functions
                if candidate.is_complex
            ]
        else:
            kept_functions = [
                candidate
                for candidate in default_and_custom_functions
                if not candidate.is_complex
            ]
        if kept_functions:
            return kept_functions

        # Nothing survived: either a complex input has no complex overload, or a
        # non-complex input only has complex overloads. Report which one it is.
        op_full_name = self._get_aten_name(node, diagnostic_context).qualified_name()
        if wants_complex:
            message = (
                f"Cannot find any COMPLEX symbolic function for {op_full_name}, "
                f"which should be registered under {node.target}."
            )
        else:
            message = (
                f"Can ONLY find COMPLEX symbolic function for {op_full_name}, "
                f"which should be registered under {node.target}."
            )
        diagnostic = diagnostics.UnsupportedFxNodeDiagnostic(
            diagnostics.rules.no_symbolic_function_for_call_function,
            diagnostics.levels.ERROR,
            message,
            unsupported_fx_node=node,
        )
        diagnostic_context.log(diagnostic)
        raise diagnostics.RuntimeErrorWithDiagnostic(diagnostic)

    @diagnostics.diagnose_call(
        diagnostics.rules.find_opschema_matched_symbolic_function,
        diagnostic_message_formatter=_find_opschema_matched_symbolic_function_disagnostic_message_formatter,
    )
    def _find_the_perfect_or_nearest_match_onnxfunction(
        self,
        node: torch.fx.Node,  # this is used in diagnostic_message_formatter
        default_and_custom_functions: list[registration.ONNXFunction],
        onnx_args: Sequence[
            fx_type_utils.TensorLike | str | int | float | bool | list | complex | None
        ],
        onnx_kwargs: dict[str, fx_type_utils.Argument],
        diagnostic_context: diagnostics.DiagnosticContext,
    ):
        """Find the perfect/nearest matched OnnxFunction for the given FX node, arguments, and keyword arguments.

        Args:
            node: The FX node being dispatched. It is used by the diagnostic message
                formatter and in the error raised when nothing matches.
            default_and_custom_functions: The list includes overloaded functions, with
                custom ones appearing after the default ones.
            onnx_args: Arguments organized in PyTorch inputs way.
            onnx_kwargs: Keyword arguments organized in PyTorch inputs way.
            diagnostic_context: The diagnostic context to use for reporting errors.

        Returns:
            The ``onnx_function`` of the first perfectly matching overload, or of the
            best-ranked nearest match when no overload matches perfectly.

        Raises:
            diagnostics.RuntimeErrorWithDiagnostic: If no overload matches perfectly
                and none is eligible as a nearest match.
        """
        overload_match_ranking: dict[registration.ONNXFunction, int | None] = {}
        diagnostic = diagnostic_context.inflight_diagnostic()

        # Iterate the overloaded functions in reverse order to prioritize the custom ones
        # over the default ones, and find the perfect match.
        for symbolic_function in reversed(default_and_custom_functions):
            function_opschema = _OnnxSchemaChecker(symbolic_function.onnx_function)

            # NOTE: 1. If the perfect match is found, return the function
            if function_opschema.perfect_match_inputs(
                diagnostic, onnx_args, onnx_kwargs
            ):
                return symbolic_function.onnx_function
            # Record the match score for the nearest match if it's not the perfect match
            overload_match_ranking[symbolic_function] = function_opschema.match_score

        # NOTE: 2. If there is no perfect match, find the nearest match among the
        # nearest match candidates. A score of None marks an overload that is not a
        # candidate at all, so those are dropped first.
        overload_match_ranking = {
            k: v for k, v in overload_match_ranking.items() if v is not None
        }
        if not overload_match_ranking:
            # No overload is even a nearest-match candidate: raise an unsupported error.
            op_full_name = self._get_aten_name(
                node, diagnostic_context
            ).qualified_name()
            diagnostic = diagnostics.UnsupportedFxNodeDiagnostic(
                diagnostics.rules.no_symbolic_function_for_call_function,
                diagnostics.levels.ERROR,
                # The separating space after the comma was missing in the message.
                f"Cannot find any perfect/nearest match of symbolic function for {op_full_name}, "
                f"which should be registered under {node.target}.",
                unsupported_fx_node=node,
            )
            diagnostic_context.log(diagnostic)
            raise diagnostics.RuntimeErrorWithDiagnostic(diagnostic)

        diagnostic.warning(
            "### Exact match is not found!\n"
            "Cannot find a perfect match of symbolic overload, "
            "a nearest match is found. Please check the ONNX output carefully. \n",
        )
        diagnostic.level = diagnostics.levels.WARNING
        # NOTE: 3. Tie breaker: if there are multiple nearest matches, we will choose the one
        # that is custom first. If there are multiple custom ones, we will choose the one
        # that is added lastly in the list.
        # Only the top-ranked entry is needed, so take the maximum instead of sorting;
        # the list position in the key makes every key distinct.
        nearest_match: registration.ONNXFunction = max(
            overload_match_ranking,
            key=lambda k: (
                overload_match_ranking[k],
                k.is_custom,
                default_and_custom_functions.index(k),
            ),
        )
        return nearest_match.onnx_function

    def _get_aten_name(
        self, node: torch.fx.Node, diagnostic_context: diagnostics.DiagnosticContext
    ) -> registration.OpName:
        """Translate the target of an FX node into a `registration.OpName`.

        Supported targets are `operator.getitem`, the `aten.sym_size` overload
        packet, builtin functions whose arguments are int/float or symbolic nodes,
        and `torch._ops.OpOverload` instances.

        Args:
            node: The TorchFX node to get the aten name for.
            diagnostic_context: The diagnostic context to use for reporting errors.

        Returns:
            The internal op name within dataclass: registration.OpName.

        Raises:
            diagnostics.RuntimeErrorWithDiagnostic: If the target (or, for builtin
                functions, one of the arguments) is not supported.
        """

        def unsupported(message: str):
            # Build and log the diagnostic, then hand the error back so that every
            # call site can `raise` it itself.
            diagnostic = diagnostics.UnsupportedFxNodeDiagnostic(
                diagnostics.rules.no_symbolic_function_for_call_function,
                diagnostics.levels.ERROR,
                message,
                unsupported_fx_node=node,
            )
            diagnostic_context.log(diagnostic)
            return diagnostics.RuntimeErrorWithDiagnostic(diagnostic)

        target = node.target

        if target == operator.getitem:
            return registration.OpName.from_name_parts(
                namespace="aten", op_name="getitem"
            )

        if isinstance(target, torch._ops.OpOverloadPacket):
            # aten::sym_size is the only OverloadPacket that we support.
            # schema: aten::sym_size(Tensor self, int dim) -> Tensor
            if target != torch.ops.aten.sym_size:
                raise unsupported(
                    f"Unsupported OverloadPacket: {target}, aten.sym_size is the only allowed OverloadPacket!"
                )
            # TODO(titaiwang): aten::sym_size has overload, but fx graph is using
            # overloadpacket for some reasons.
            # https://github.com/pytorch/pytorch/issues/97201
            return registration.OpName.from_op_overload(op_overload=target.default)  # type: ignore[no-any-return]

        if isinstance(target, types.BuiltinFunctionType):
            # Make sure it's symint/symfloat consuming builtin ops.
            for node_arg in node.args:
                if isinstance(node_arg, torch.fx.Node):
                    # A node argument must carry a symbolic value.
                    arg_supported = fx_type_utils.is_torch_symbolic_type(
                        node_arg.meta["val"]
                    )
                else:
                    # Anything else must be a plain int/float.
                    arg_supported = isinstance(node_arg, (int, float))
                if not arg_supported:
                    raise unsupported(
                        f"Unsupported node arg: {node_arg} (type {type(node_arg)}) with builtin function: {target},"
                        " only int/float/SymInt/SymFloat is supported with built-in ops!"
                    )
            return registration.OpName.from_builtin_function(target)

        if isinstance(target, torch._ops.OpOverload):
            return registration.OpName.from_op_overload(op_overload=target)

        # Unexpected target, raise error.
        raise unsupported(f"Unknown call_function target: {target}")

    @diagnostics.diagnose_call(
        diagnostics.rules.find_operator_overloads_in_onnx_registry,
        diagnostic_message_formatter=_find_operator_overloads_in_onnx_registry_disagnostic_message_formatter,
    )
    def get_function_overloads(
        self,
        node: torch.fx.Node,
        diagnostic_context: diagnostics.DiagnosticContext,
    ) -> list[registration.ONNXFunction]:
        """Get the function overloads from the registry.

        The exact overload of the node's target is looked up first. When the
        registry has no entry for it, the default overload (``overload=None``) is
        tried and a warning is attached to the in-flight diagnostic. The resulting
        group is then filtered by `_filter_or_keep_complex`.

        Args:
            node: The node to get the function overloads for.
            diagnostic_context: The diagnostic context to use for reporting errors.

        Returns:
            The list contains ONNXFunctions, starting with the default ones and
            followed by any custom ones.

        Raises:
            diagnostics.RuntimeErrorWithDiagnostic: If neither the exact nor the
                default overload is registered, or if the complex filter leaves
                no function.
        """

        internal_opname: registration.OpName = self._get_aten_name(
            node=node, diagnostic_context=diagnostic_context
        )

        # If the ATen/Custom operators are not registered, the group will be None.
        # And non-registered ATen/Custom operators will trigger error in the next step.
        function_group: list[registration.ONNXFunction] | None = (
            self.onnx_registry.get_op_functions(
                namespace=internal_opname.namespace,
                op_name=internal_opname.op_name,
                overload=internal_opname.overload,
            )
        )

        # NOTE: Fall back to default overload if the ONNX registry doesn't have the overload.
        if function_group is None:
            function_group = self.onnx_registry.get_op_functions(
                namespace=internal_opname.namespace,
                op_name=internal_opname.op_name,
                overload=None,
            )
            if function_group is not None:
                # The fallback succeeded: flag it, as the default overload may not
                # behave exactly like the requested one.
                diagnostic = diagnostic_context.inflight_diagnostic()
                diagnostic.warning(
                    "### The operator overload is not found in onnx registry!\n"
                    "Cannot find the operator overload in onnx registry, but "
                    "the default overload is found. Please check the ONNX output carefully. \n",
                )
                diagnostic.level = diagnostics.levels.WARNING

        if function_group is not None:
            # NOTE: If the input has complex dtype, we will only dispatch to the complex functions.
            function_group = self._filter_or_keep_complex(
                node, function_group, diagnostic_context
            )
            return function_group  # type: ignore[return-value]

        # Neither the exact nor the default overload is registered.
        op_full_name = internal_opname.qualified_name()
        diagnostic = diagnostics.UnsupportedFxNodeDiagnostic(
            diagnostics.rules.no_symbolic_function_for_call_function,
            diagnostics.levels.ERROR,
            f"Cannot find symbolic function for {op_full_name}, "
            f"which should be registered under {node.target}.",
            unsupported_fx_node=node,
        )
        diagnostic_context.log(diagnostic)
        raise diagnostics.RuntimeErrorWithDiagnostic(diagnostic)


class _OnnxSchemaChecker:
    """
    The OnnxSchemaChecker class is a checker for ONNX OpSchema and param schema.

    It provides methods to check for input compatibility based on the OpSchema. It also
    provides a matching score to indicate how well the OpSchema matches the input and
    kwargs types. A function will be evaluated as perfect match, nearest match eligible,
    or no match.

    Here are some common examples in categories:

    1. [NOTE: Perfect match]: The number of inputs and attributes are exactly the same as
        the OpSchema. The types of inputs and attributes are exactly the same as the
        OpSchema.

        ```python
        inputs = (Tensor[2, 3], Tensor[2, 3])
        attributes = {"alpha": 1.0}


        @torch_op("aten::op")
        def aten_op(self: TReal, other: TReal, alpha: float = 1) -> TReal: ...
        ```
        Result: Perfect match.

    2. [NOTE: Optional input]: The dispatcher recognizes optional inputs. However,
        the input can't be ignored. None must be provided.

        ```python
        inputs = (Tensor([2, 3]), None)
        attributes = {}

        aten_op(X: TTensor, Y: Optional[INT64]):
            ...
        ```
        Result: Perfect match.
        Real example: `aten::convolution`.

    3. [NOTE: Different attributes]: If an attribute is provided with value, it's
        a must to match the attribute in function signature.
        ```python
        inputs = (Tensor([2, 3]),)
        attributes = {"a":1, "b":2}

        aten_op(X: TTensor, a: int):
            ...
        ```
        Result: No match.
        Real example: `aten::div` vs `aten::div.Tensor_mode`.

    4. [NOTE: Default attributes]: Default attribute will fill in the value into
        inputs/attributes.
        ```python
        inputs = (Tensor([2, 3]),)
        attributes = {}

        aten_op(X: TTensor, a: int = 3):
            ...
        ```
        Result: Perfect match.
        Real example: `aten::clone`

    5. [NOTE: Ignore attribute with None value]: The attributes with None value
        will be ignored in matching.
        ```python
        inputs = (Tensor([2, 3]),)
        attributes = {"a": None}

        aten_op(X: TTensor):
            ...
        ```
        Result: Perfect match.

        ```python
        inputs = (Tensor([2, 3]),)
        attributes = {"a": None}

        aten_op(X: TTensor, a: int = 3):
            ...
        ```
        Result: Nearest match eligible.

        Real example: `aten::div` vs `aten::div.Tensor_mode`.

    Attributes:
        onnxfunction: The OnnxFunction.
        param_schema: The parameter schema defined in the OnnxFunction.
        op_schema: The ONNX OpSchema.
        type_constraints: The type constraints defined in the OpSchema.
        attributes: The attributes defined in the OpSchema.
        _matching_score: The matching score of the OnnxSchemaChecker.

    """

    def __init__(
        self,
        onnxfunction: onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction,
    ):
        """Initialize the OnnxSchemaChecker .

        Args:
            onnxfunction: The OnnxFunction.
        """
        self.onnxfunction = onnxfunction
        self.param_schema = self.onnxfunction.param_schemas()
        op_schema = self.onnxfunction.op_schema
        # Both `OnnxFunction` and `TracedOnnxFunction` never return None for `op_schema`.
        # However their base class would. Hence return type is annotated as Optional[OpSchema].
        assert op_schema is not None
        self.op_schema = op_schema
        self.type_constraints = {
            # "T": {"tensor(int64)"}
            constraint.type_param_str: set(constraint.allowed_type_strs)
            for constraint in self.op_schema.type_constraints
        }
        self.attributes = self.op_schema.attributes
        self._matching_score: int | None = None

    @property
    def match_score(self) -> int | None:
        """The matching score of the OnnxSchemaChecker .

        If this remains None, it means the matching score has not been calculated,
        and it's not a nearest match candidate.

        Returns:
            The matching score of the OnnxSchemaChecker .
        """
        return self._matching_score

    def perfect_match_inputs(
        self,
        diagnostic: diagnostics.Diagnostic,
        args: Sequence[
            fx_type_utils.TensorLike | str | int | float | bool | list | complex | None
        ],
        kwargs: dict[str, fx_type_utils.Argument],
    ) -> bool:
        """Check if the inputs perfectly match the OpSchema requirements.

        The definition of perfect match is that the input types are all in the type
        constraints and the number of inputs matches the number of inputs in the
        OpSchema.

        Checking steps:
        1. The function signature matches the inputs number, and attribute names.
        2. The input/attribute types are all in the type constraints.

        A function should at least pass the first step to be eligible for the
        nearest matching. When step 1 fails, the method returns before a matching
        score is recorded; otherwise a score is recorded whether or not step 2
        succeeds.

        Every mismatch found is reported through ``diagnostic`` before returning.

        Args:
            diagnostic: The diagnostic to use for logging detailed info.
            args: The input arguments organized in PyTorch inputs way.
            kwargs: The input keyword arguments organized in PyTorch inputs way.

        Returns:
            True if the inputs match the requirements, False otherwise.
        """

        # NOTE: OnnxFunction does not have the same function signature as the original
        # PyTorch operator. We need to separate the input/attributes from the arguments.
        (
            function_inputs,
            function_attributes,
        ) = self._separate_input_attributes_from_arguments(
            self.param_schema,
            args,
            kwargs,
            fill_defaults=True,  # fill defaults for optional arguments to match
        )
        with diagnostic.log_section(logging.INFO, "Checking perfect match..."):
            # Log which function is being checked; the formatting is deferred
            # through LazyString.
            diagnostic.info(
                "%s",
                diagnostics.LazyString(diagnostics.format_argument, self.onnxfunction),
            )
            # NOTE: 1. Check if the input number and attribute names match the
            # OpSchema. If it's not, we know the function is not eligible to be a perfect
            # match, nor a nearest match.
            # We use is_perfect_match to postpone the return value to the end
            # of the function, as we want to log all the mismatch info.
            is_perfect_match = True
            if len(function_inputs) != len(self.op_schema.inputs):
                with diagnostic.log_section(
                    logging.INFO, "Failed: input number mismatch!"
                ):
                    diagnostic.info(
                        "Actual %d vs expected %d",
                        len(function_inputs),
                        len(self.op_schema.inputs),
                    )
                diagnostic.info("The function is not a nearest match candidate.")
                is_perfect_match = False

            # Attribute names are compared as sets: order is irrelevant, but every
            # provided name must exist in the OpSchema and vice versa.
            if set(function_attributes) != set(self.attributes):
                with diagnostic.log_section(
                    logging.INFO, "Failed: attribute mismatch!"
                ):
                    diagnostic.info(
                        "%s",
                        diagnostics.LazyString(
                            lambda: f"Actual {set(function_attributes)} vs expected {set(self.attributes)}",
                        ),
                    )
                diagnostic.info("The function is not a nearest match candidate.")
                is_perfect_match = False

            # If it's already not a perfect match, we can return False directly. Further
            # checking is only for the functions that are eligible for nearest match.
            # No matching score is recorded on this path.
            if not is_perfect_match:
                return False

            # NOTE: 2. The dtypes of inputs and attributes should be in the
            # type constraints of the OpSchema. If they are not, we know the function is not
            # eligible to be a perfect match, but can be a nearest match candidate.
            for schema_input, torch_input in zip(
                self.op_schema.inputs, function_inputs
            ):
                torch_input_compatible_types = _find_onnx_data_type(torch_input)
                allowed_types = self.type_constraints[schema_input.type_str]
                # An input is only flagged when none of its compatible types is
                # allowed AND none of the allowed types is an optional one.
                if not allowed_types.intersection(
                    torch_input_compatible_types
                ) and not any(
                    fx_type_utils.is_optional_onnx_dtype_str(onnx_type_str)
                    for onnx_type_str in allowed_types
                ):
                    # If torch_input_compatible_types isn't in allowed_types
                    # of this input defined in the OpSchema, we know the function
                    # and the input are not compatible
                    with diagnostic.log_section(
                        logging.INFO,
                        "Failed: input type mismatch for input '%s'!",
                        schema_input.name,
                    ):
                        diagnostic.info(
                            "Actual %s vs\nExpected %s",
                            torch_input_compatible_types,
                            allowed_types,
                        )
                    is_perfect_match = False

            for attribute_name, attribute in function_attributes.items():
                if not self._match_onnx_attribute_type(attribute_name, attribute):
                    # If the attribute type of the OpSchema and the attribute type don't match,
                    # we know the function and the input are not compatible
                    with diagnostic.log_section(
                        logging.INFO,
                        "Failed: attribute '%s' type mismatch!",
                        attribute_name,
                    ):
                        diagnostic.info(
                            "Actual %s vs\nExpected %s",
                            type(attribute),
                            self.attributes[attribute_name].type,
                        )
                    is_perfect_match = False

            # NOTE: Even when a type mismatch was found above, this function is still a
            # candidate for nearest match, because only input/attribute types (not
            # their number or names) can differ at this point. Record its score.
            self._record_matching_score(function_inputs, function_attributes)
            diagnostic.info("match score: %d", self.match_score)
            return is_perfect_match

    def _match_onnx_attribute_type(
        self,
        attribute_name: str,
        attribute: fx_type_utils.Argument | onnxscript_graph_building.TorchScriptTensor,
        is_sequence: bool = False,
    ) -> bool:
        if isinstance(attribute, (int, float, bool, str)):
            attribute_onnx_type = fx_type_utils.from_python_type_to_onnx_attribute_type(
                type(attribute), is_sequence=is_sequence
            )
            if attribute_onnx_type != self.attributes[attribute_name].type:
                return False
        # If the attribute is an empty list, we don't know the type of the list
        # so it's a mismatch
        elif isinstance(attribute, (list, tuple)) and attribute:
            return self._match_onnx_attribute_type(
                attribute_name, attribute[0], is_sequence=True
            )
        else:
            # NOTE: Unrecognized attribute type
            return False
        return True

    def _record_matching_score(
        self,
        inputs: Sequence[
            fx_type_utils.TensorLike | str | int | float | bool | list | complex | None
        ],
        attributes: dict[str, fx_type_utils.Argument],
    ):
        """Calculate the inputs matching score of the OpSchema requirements to find the nearest match.

        Only the functions which have the same number of inputs and attributes as the
        OpSchema are eligible to be a nearest match candidate. Thus, we don't need to
        check the length of inputs and attributes here, and only check the types of
        inputs and attributes.

        How the matchsing score is calculated:
            score += 1 if one input/attribute type is in the type constraints.

        Limitations:
            None/NoeType/[] could result in zero matches, and the same score of overloads,
            which will be recorded in SARIF.

        Args:
            inputs: The input arguments.
            attributes: The input keyword arguments.

        Returns:
            True if the inputs match the requirements, False otherwise.
        """
        self._matching_score = 0
        # If they have different length of arguments, the score would be lower to those
        # functions which have the same length of arguments.
        for schema_input, torch_input in zip(self.op_schema.inputs, inputs):
            torch_input_compatible_types = _find_onnx_data_type(torch_input)
            allowed_types = self.type_constraints[schema_input.type_str]
            if allowed_types.intersection(torch_input_compatible_types):
                # If torch_input_compatible_types is in allowed_types
                # of this input defined in the OpSchema, we know the function
                # and the input are compatible
                self._matching_score += 1
        # NOTE: The penalty is applied to those functions which have different attributes.
        for attribute_name, attribute_proto in self.attributes.items():
            attribute = attributes[attribute_name]
            attribute_onnx_type = fx_type_utils.from_python_type_to_onnx_attribute_type(
                type(attribute)
            )
            if attribute_onnx_type != attribute_proto.type:
                # If the attribute type of the OpSchema and the attribute type don't match,
                # we know the function and the input are not compatible
                self._matching_score -= 1

    # NOTE: Referenced from onnxscript internal function.
    # Importing this function makes the code less robust, as it is not a public API.

    def _separate_input_attributes_from_arguments(
        self,
        param_schemas: Sequence[onnxscript.values.ParamSchema],
        args: Sequence[
            fx_type_utils.TensorLike | str | int | float | bool | list | complex | None
        ],
        kwargs: dict[str, fx_type_utils.Argument],
        fill_defaults: bool = True,
    ) -> tuple[list[Any], dict[str, Any]]:
        """Split Python args and kwargs into ONNX inputs and ONNX attributes.

        Parameters are visited in schema order. A variadic input swallows all
        remaining positional args. Any other parameter takes its value from the
        positional args first and from ``kwargs`` second. When neither supplies
        it and ``fill_defaults`` is set, an attribute whose schema entry has a
        defined default receives ``param.default`` and a missing input receives
        ``None``.

        Keyword arguments left over afterwards are added to the attributes
        unless their value is None or the name is already set. For example, a
        caller passing ``rounding_mode=None`` to a function whose signature has
        no ``rounding_mode`` attribute gets no ``rounding_mode`` entry back.

        Args:
            param_schemas: The parameter schemas of an Op or a OnnxFunction.
            args: The Python positional arguments supplied by the caller.
            kwargs: The Python keyword arguments supplied by the caller. It is
                copied, not mutated.
            fill_defaults: Whether to fill in defaults for attributes and None
                for inputs that the caller did not provide.

        Returns:
            A tuple of two elements:
            - A list of ONNX inputs.
            - A dictionary of ONNX attribute names and values.
        """
        # args, kwargs and param_schemas should be all in order
        # user may not specify all inputs or attributes

        import onnx

        onnx_inputs: list[Any] = []
        onnx_attributes: dict[str, Any] = {}
        # Work on a copy: consumed inputs are popped from it below.
        remaining_kwargs = kwargs.copy()

        for position, schema in enumerate(param_schemas):
            if schema.is_variadic_input:
                # Exhaust all remaining args
                onnx_inputs.extend(args[position:])
                args = []
                continue

            if position < len(args):
                # Supplied positionally.
                if schema.is_input:
                    onnx_inputs.append(args[position])
                else:
                    onnx_attributes[schema.name] = args[position]
                continue

            if schema.name in remaining_kwargs:
                # Supplied by keyword.
                if schema.is_input:
                    # Move the input from kwargs to inputs
                    onnx_inputs.append(remaining_kwargs.pop(schema.name))
                else:
                    onnx_attributes[schema.name] = remaining_kwargs[schema.name]
                continue

            # Not supplied at all: decide whether a default applies.
            has_schema_default = (
                schema.is_attribute
                and self.attributes[schema.name].default_value.type
                != onnx.AttributeProto.UNDEFINED  # type: ignore[attr-defined]
            )
            if has_schema_default:
                # User did not provide the attribute
                if fill_defaults:
                    onnx_attributes[schema.name] = schema.default
            elif schema.is_input and fill_defaults:
                # optional input
                onnx_inputs.append(None)

        # NOTE: Pick up extra kwargs if it's not None. None is not expected
        # as an attribute value in torchlib.
        for key, value in remaining_kwargs.items():
            if value is not None:
                # setdefault keeps a value already recorded under this name.
                onnx_attributes.setdefault(key, value)
        return onnx_inputs, onnx_attributes


def _is_arg_with_complex_dtype(arg: fx_type_utils.Argument) -> bool:
    """Check if the node has complex dtype recursively."""
    if (
        isinstance(arg, torch.fx.Node)
        and "val" in arg.meta
        and isinstance(arg.meta["val"], torch.Tensor)
        and torch.is_complex(arg.meta["val"])
    ):
        return True
    elif isinstance(arg, list):
        for item in arg:
            return _is_arg_with_complex_dtype(item)
    return False


def _find_onnx_data_type(
    torch_input: fx_type_utils.TensorLike
    | str
    | int
    | float
    | bool
    | list
    | tuple
    | complex
    | None,
) -> set[str]:
    """Map an input value to the set of ONNX dtype strings it is compatible with.

    Args:
        torch_input: A tensor-like object, a Python scalar, a list/tuple of
            those, or None.

    Returns:
        The compatible ONNX dtype strings. For a non-empty list/tuple the
        types come from its first non-None item and are wrapped as
        ``seq(...)`` when any item is tensor-like. An empty set is returned
        for None, for a tensor-like object without a dtype, and for an empty
        list/tuple.

    Raises:
        RuntimeError: If the input is none of the supported kinds.
    """
    is_tensor_like = isinstance(torch_input, fx_type_utils.TensorLike)
    if is_tensor_like and torch_input.dtype is not None:
        return fx_type_utils.from_torch_dtype_to_onnx_dtype_str(torch_input.dtype)

    if isinstance(torch_input, (int, float, bool, str, complex)):
        return fx_type_utils.from_torch_dtype_to_onnx_dtype_str(type(torch_input))

    if isinstance(torch_input, (list, tuple)):
        if not torch_input:
            # Empty sequence: element type unknown, so the type check is relaxed.
            return set()
        # [Tensor, Tensor]
        first_concrete_item = next(
            (element for element in torch_input if element is not None), None
        )
        element_types = _find_onnx_data_type(first_concrete_item)
        contains_tensor = any(
            isinstance(element, fx_type_utils.TensorLike) for element in torch_input
        )
        if contains_tensor:
            # NOTE: Any Tensor involved in a list would make it a seq(tensor(onnx_type))
            return {f"seq({dtype})" for dtype in element_types}
        # constant list of non-tensor type
        return element_types

    if torch_input is None or is_tensor_like:
        # NOTE: None and tensor-likes without a dtype (the only tensor-likes that
        # reach this point) are edge cases; any type is allowed to relax the check.
        # seq(tensor) also goes to here, as it is not supported in torchscript, and it would be None in this case.
        return set()

    raise RuntimeError(f"Unknown input type from input: {torch_input}")