File: unsupported_nodes.py

package info (click to toggle)
pytorch-cuda 2.6.0+dfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (86 lines) | stat: -rw-r--r-- 3,389 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# mypy: allow-untyped-defs
from __future__ import annotations

import dataclasses

from torch.onnx._internal.fx import _pass, diagnostics, registration


@dataclasses.dataclass
class UnsupportedFxNodesAnalysisResult(_pass.AnalysisResult):
    """Result of `UnsupportedFxNodesAnalysis`: the FX node targets that have no ONNX registration."""

    # Maps an FX node op (e.g. "call_function") to the unsupported targets seen
    # for that op. The inner dict is used as an insertion-ordered set: its keys
    # are ``str(node.target)`` and every value is ``None``, so each target is
    # recorded once, in first-seen order.
    unsupported_op_to_target_mapping: dict[str, dict[str, None]]


class UnsupportedFxNodesAnalysis(_pass.Analysis):
    """An analysis that detects unsupported FX nodes in the graph.

    A ``call_function`` node is reported as unsupported when the ONNX registry
    of ``self.onnxfunction_dispatcher`` has neither the node's exact overload
    nor the op's default overload (``overload=None``) registered.
    """

    def _lint(
        self,
        analysis_result: UnsupportedFxNodesAnalysisResult,
        diagnostic_level: diagnostics.infra.Level,
    ):
        """Lint the graph and emit diagnostics if unsupported FX nodes are found.

        Args:
            analysis_result: The result produced by `analyze`.
            diagnostic_level: The level attached to the emitted diagnostic.

        Raises:
            RuntimeErrorWithDiagnostic: Propagated from
                ``self.diagnostic_context.log_and_raise_if_error`` when the
                diagnostic is treated as an error (see `analyze`).
        """
        if not analysis_result.unsupported_op_to_target_mapping:
            # Nothing unsupported was found, so no diagnostic is emitted.
            return

        # Targets are stored as dict keys (an insertion-ordered set); flatten
        # them into lists for the diagnostic message.
        normalized_op_targets_map = {
            op: list(targets)
            for op, targets in analysis_result.unsupported_op_to_target_mapping.items()
        }

        rule = diagnostics.rules.unsupported_fx_node_analysis
        diagnostic = diagnostics.Diagnostic(
            rule,
            level=diagnostic_level,
            message=rule.format_message(normalized_op_targets_map),
        )
        self.diagnostic_context.log_and_raise_if_error(diagnostic)

    def analyze(
        self, diagnostic_level: diagnostics.infra.Level
    ) -> UnsupportedFxNodesAnalysisResult:
        """Analyze the graph, emit diagnostics and return a result that contains unsupported FX nodes.

        Args:
            diagnostic_level: The diagnostic level to use when emitting diagnostics.

        Returns:
            An analysis result that contains unsupported FX nodes.

        Raises:
            RuntimeErrorWithDiagnostic: If diagnostics are emitted and the diagnostic
                level is `ERROR`.
        """
        op_to_target_mapping: dict[str, dict[str, None]] = {}
        for node in self.module.graph.nodes:
            # Only function calls are resolved against the ONNX registry.
            if node.op != "call_function":
                continue

            # NOTE: OPSchema matcher is not in this analysis scope.
            internal_opname: registration.OpName = (
                self.onnxfunction_dispatcher._get_aten_name(
                    node=node, diagnostic_context=self.diagnostic_context
                )
            )
            onnx_registry = self.onnxfunction_dispatcher.onnx_registry
            # ``or`` short-circuits: the fallback query for the default
            # overload only runs when the exact overload is not registered.
            is_supported = onnx_registry.is_registered_op(
                namespace=internal_opname.namespace,
                op_name=internal_opname.op_name,
                overload=internal_opname.overload,
            ) or onnx_registry.is_registered_op(
                # NOTE: Fall back to default overload if the ONNX registry
                # doesn't have the overload.
                namespace=internal_opname.namespace,
                op_name=internal_opname.op_name,
                overload=None,
            )
            if not is_supported:
                # The inner dict acts as an insertion-ordered set of targets.
                op_to_target_mapping.setdefault(node.op, {}).setdefault(
                    str(node.target), None
                )

        analysis_result = UnsupportedFxNodesAnalysisResult(op_to_target_mapping)
        self._lint(analysis_result, diagnostic_level)
        return analysis_result