1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86
|
# mypy: allow-untyped-defs
from __future__ import annotations
import dataclasses
from torch.onnx._internal.fx import _pass, diagnostics, registration
@dataclasses.dataclass
class UnsupportedFxNodesAnalysisResult(_pass.AnalysisResult):
    """Outcome of an unsupported-FX-node analysis over a graph module."""

    # Outer key: the FX node op kind (the analysis in this file only records
    # "call_function" nodes). Inner dict: stringified node targets that have
    # no ONNX registration. The inner values are always None, so the inner
    # dict effectively serves as an insertion-ordered set of target names.
    unsupported_op_to_target_mapping: dict[str, dict[str, None]]
class UnsupportedFxNodesAnalysis(_pass.Analysis):
    """Analysis pass that finds FX nodes with no matching ONNX registration."""

    def _lint(
        self,
        analysis_result: UnsupportedFxNodesAnalysisResult,
        diagnostic_level: diagnostics.infra.Level,
    ):
        """Report unsupported FX nodes through the diagnostic context.

        Does nothing when the analysis result holds no unsupported nodes.

        Args:
            analysis_result: The result produced by `analyze`.
            diagnostic_level: The level assigned to the emitted diagnostic.
        """
        unsupported = analysis_result.unsupported_op_to_target_mapping
        if not unsupported:
            # Every inspected node is supported; there is nothing to report.
            return

        # The message formatter is given plain lists of target names rather
        # than the None-valued dicts used internally for de-duplication.
        targets_by_op = {op: list(targets) for op, targets in unsupported.items()}

        rule = diagnostics.rules.unsupported_fx_node_analysis
        report = diagnostics.Diagnostic(
            rule,
            level=diagnostic_level,
            message=rule.format_message(targets_by_op),
        )
        self.diagnostic_context.log_and_raise_if_error(report)

    def analyze(
        self, diagnostic_level: diagnostics.infra.Level
    ) -> UnsupportedFxNodesAnalysisResult:
        """Collect the unsupported FX nodes of the graph and emit diagnostics for them.

        Only `call_function` nodes are inspected. A node counts as unsupported
        when neither its specific overload nor the default overload is
        registered in the ONNX registry.

        Args:
            diagnostic_level: The diagnostic level to use when emitting diagnostics.

        Returns:
            An analysis result that contains unsupported FX nodes.

        Raises:
            RuntimeErrorWithDiagnostic: If diagnostics are emitted and the diagnostic
                level is `ERROR`.
        """
        unsupported: dict[str, dict[str, None]] = {}

        for node in self.module.graph.nodes:
            if node.op != "call_function":
                continue

            # NOTE: OPSchema matcher is not in this analysis scope.
            dispatcher = self.onnxfunction_dispatcher
            aten_name: registration.OpName = dispatcher._get_aten_name(
                node=node, diagnostic_context=self.diagnostic_context
            )

            # Look up the exact overload first, then the default overload
            # (overload=None) as a fallback. Both lookups are always performed,
            # in this order; a list (not a generator) keeps them eager.
            registered = [
                self.onnxfunction_dispatcher.onnx_registry.is_registered_op(
                    namespace=aten_name.namespace,
                    op_name=aten_name.op_name,
                    overload=candidate_overload,
                )
                for candidate_overload in (aten_name.overload, None)
            ]
            if any(registered):
                continue

            # Record the target under its op kind; setdefault keeps the first
            # insertion, so repeated targets appear once in discovery order.
            targets_for_op = unsupported.setdefault(node.op, {})
            targets_for_op.setdefault(str(node.target), None)

        result = UnsupportedFxNodesAnalysisResult(unsupported)
        self._lint(result, diagnostic_level)
        return result
|