• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# mypy: allow-untyped-defs
2from __future__ import annotations
3
4import dataclasses
5
6from torch.onnx._internal.fx import _pass, diagnostics, registration
7
8
@dataclasses.dataclass()
class UnsupportedFxNodesAnalysisResult(_pass.AnalysisResult):
    """Outcome of ``UnsupportedFxNodesAnalysis.analyze``.

    Attributes:
        unsupported_op_to_target_mapping: For each FX node ``op`` kind, the
            stringified targets that had no ONNX registration. The inner dict
            is used as an insertion-ordered set: only its keys matter and
            every value is ``None``.
    """

    # outer key: FX node op kind; inner keys: str(node.target); inner values: None
    unsupported_op_to_target_mapping: dict[str, dict[str, None]]
12
13
class UnsupportedFxNodesAnalysis(_pass.Analysis):
    """An analysis that detects unsupported FX nodes in the graph.

    A ``call_function`` node is reported as unsupported when the ONNX registry
    has neither a registration for its exact ATen overload nor one for the
    default (``overload=None``) variant of the same op.
    """

    def _lint(
        self,
        analysis_result: UnsupportedFxNodesAnalysisResult,
        diagnostic_level: diagnostics.infra.Level,
    ):
        """Lint the graph and emit diagnostics if unsupported FX nodes are found.

        Args:
            analysis_result: The result produced by :meth:`analyze`.
            diagnostic_level: The level assigned to the emitted diagnostic.

        Raises:
            RuntimeErrorWithDiagnostic: Presumably raised by
                ``log_and_raise_if_error`` when ``diagnostic_level`` is
                ``ERROR`` (as documented on :meth:`analyze`) -- TODO confirm.
        """
        if not analysis_result.unsupported_op_to_target_mapping:
            # Nothing unsupported: no diagnostic is emitted at all.
            return

        # The inner dicts are insertion-ordered sets (values are all None);
        # turn them into plain lists of targets for message formatting.
        normalized_op_targets_map = {
            op: list(targets)
            for op, targets in analysis_result.unsupported_op_to_target_mapping.items()
        }

        rule = diagnostics.rules.unsupported_fx_node_analysis
        diagnostic = diagnostics.Diagnostic(
            rule,
            level=diagnostic_level,
            message=rule.format_message(normalized_op_targets_map),
        )
        self.diagnostic_context.log_and_raise_if_error(diagnostic)

    def _has_onnx_registration(self, internal_opname: registration.OpName) -> bool:
        """Return whether the ONNX registry has an entry usable for ``internal_opname``.

        The exact overload is checked first; the default overload
        (``overload=None``) is only queried when the exact one is missing,
        avoiding a redundant registry lookup for supported nodes.
        """
        registry = self.onnxfunction_dispatcher.onnx_registry
        if registry.is_registered_op(
            namespace=internal_opname.namespace,
            op_name=internal_opname.op_name,
            overload=internal_opname.overload,
        ):
            return True
        # NOTE: Fall back to default overload if the ONNX registry doesn't have the overload.
        return bool(
            registry.is_registered_op(
                namespace=internal_opname.namespace,
                op_name=internal_opname.op_name,
                overload=None,
            )
        )

    def analyze(
        self, diagnostic_level: diagnostics.infra.Level
    ) -> UnsupportedFxNodesAnalysisResult:
        """Analyze the graph, emit diagnostics and return a result that contains unsupported FX nodes.

        Args:
            diagnostic_level: The diagnostic level to use when emitting diagnostics.

        Returns:
            An analysis result that contains unsupported FX nodes.

        Raises:
            RuntimeErrorWithDiagnostic: If diagnostics are emitted and the diagnostic
                level is `ERROR`.
        """
        op_to_target_mapping: dict[str, dict[str, None]] = {}
        for node in self.module.graph.nodes:
            # Only ``call_function`` nodes are checked against the ONNX registry.
            if node.op != "call_function":
                continue
            # NOTE: OPSchema matcher is not in this analysis scope.
            internal_opname: registration.OpName = (
                self.onnxfunction_dispatcher._get_aten_name(
                    node=node, diagnostic_context=self.diagnostic_context
                )
            )
            if not self._has_onnx_registration(internal_opname):
                # Inner dict is used as an ordered set: records each target once.
                op_to_target_mapping.setdefault(node.op, {}).setdefault(
                    str(node.target), None
                )

        analysis_result = UnsupportedFxNodesAnalysisResult(op_to_target_mapping)
        self._lint(analysis_result, diagnostic_level)
        return analysis_result
87