# mypy: allow-untyped-defs
from __future__ import annotations

import dataclasses

from torch.onnx._internal.fx import _pass, diagnostics, registration


@dataclasses.dataclass
class UnsupportedFxNodesAnalysisResult(_pass.AnalysisResult):
    """Result of an unsupported-FX-node analysis."""

    # Maps an FX node op kind to the unsupported targets recorded for it.
    # `UnsupportedFxNodesAnalysis.analyze` only ever records "call_function".
    # The inner dict is used as an insertion-ordered set: its keys are
    # `str(node.target)` and its values are always None.
    unsupported_op_to_target_mapping: dict[str, dict[str, None]]


class UnsupportedFxNodesAnalysis(_pass.Analysis):
    """An analysis that detects unsupported FX nodes in the graph."""

    def _lint(
        self,
        analysis_result: UnsupportedFxNodesAnalysisResult,
        diagnostic_level: diagnostics.infra.Level,
    ) -> None:
        """Lint the graph and emit diagnostics if unsupported FX nodes are found.

        Nothing is emitted when `analysis_result` records no unsupported nodes.

        Args:
            analysis_result: The result whose unsupported op/target mapping is
                reported.
            diagnostic_level: The level attached to the emitted diagnostic.

        Raises:
            RuntimeErrorWithDiagnostic: Raised by
                ``diagnostic_context.log_and_raise_if_error`` — presumably when
                `diagnostic_level` is ``ERROR`` (confirm against the
                diagnostic context implementation).
        """
        unsupported = analysis_result.unsupported_op_to_target_mapping
        if not unsupported:
            return

        # Turn each ordered-set dict into a plain list of target names for the
        # diagnostic message (dicts iterate over their keys in insertion order).
        normalized_op_targets_map = {
            op: list(targets) for op, targets in unsupported.items()
        }

        rule = diagnostics.rules.unsupported_fx_node_analysis
        diagnostic = diagnostics.Diagnostic(
            rule,
            level=diagnostic_level,
            message=rule.format_message(normalized_op_targets_map),
        )
        self.diagnostic_context.log_and_raise_if_error(diagnostic)

    def analyze(
        self, diagnostic_level: diagnostics.infra.Level
    ) -> UnsupportedFxNodesAnalysisResult:
        """Analyze the graph, emit diagnostics and return a result that contains unsupported FX nodes.

        Only ``call_function`` nodes are inspected. A node counts as
        unsupported when the ONNX registry has neither its specific overload
        nor the default (``overload=None``) registration.

        Args:
            diagnostic_level: The diagnostic level to use when emitting diagnostics.

        Returns:
            An analysis result that contains unsupported FX nodes.

        Raises:
            RuntimeErrorWithDiagnostic: If diagnostics are emitted and the diagnostic
                level is `ERROR`.
        """
        dispatcher = self.onnxfunction_dispatcher
        # Loop-invariant; look it up once instead of once per node.
        registry = dispatcher.onnx_registry

        # op kind -> {str(target): None}; the inner dict is an ordered set so
        # each unsupported target is reported once, in first-seen order.
        op_to_target_mapping: dict[str, dict[str, None]] = {}
        for node in self.module.graph.nodes:
            if node.op != "call_function":
                continue

            # NOTE: OPSchema matcher is not in this analysis scope.
            internal_opname: registration.OpName = dispatcher._get_aten_name(
                node=node, diagnostic_context=self.diagnostic_context
            )
            # NOTE: Fall back to the default overload if the ONNX registry doesn't
            # have the specific overload. `or` short-circuits, so the fallback
            # lookup only runs when the specific overload is missing.
            is_supported = registry.is_registered_op(
                namespace=internal_opname.namespace,
                op_name=internal_opname.op_name,
                overload=internal_opname.overload,
            ) or registry.is_registered_op(
                namespace=internal_opname.namespace,
                op_name=internal_opname.op_name,
                overload=None,
            )
            if not is_supported:
                op_to_target_mapping.setdefault(node.op, {}).setdefault(
                    str(node.target), None
                )

        analysis_result = UnsupportedFxNodesAnalysisResult(op_to_target_mapping)
        self._lint(analysis_result, diagnostic_level)
        return analysis_result