• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# mypy: allow-untyped-defs
2"""Dispatcher for AtenLib functions from onnx-script."""
3
4from __future__ import annotations
5
6import logging
7import operator
8import types
9from typing import Any, Callable, Sequence, TYPE_CHECKING
10
11import torch
12import torch._ops
13import torch.fx
14from torch.onnx._internal.fx import (
15    diagnostics,
16    registration,
17    type_utils as fx_type_utils,
18)
19
20
21if TYPE_CHECKING:
22    import onnxscript  # type: ignore[import]
23    from onnxscript.function_libs.torch_lib import (  # type: ignore[import]
24        graph_building as onnxscript_graph_building,
25    )
26
27    from torch.onnx import OnnxRegistry
28
29
30def _find_opschema_matched_symbolic_function_disagnostic_message_formatter(
31    fn: Callable,
32    self,
33    node: torch.fx.Node,
34    default_and_custom_functions: list[registration.ONNXFunction],
35    *args,
36    **kwargs,
37) -> str:
38    """Format the diagnostic message for the nearest match warning."""
39    all_function_overload_names = ""
40    for symbolic_func in default_and_custom_functions:
41        overload_func = symbolic_func.onnx_function
42        all_function_overload_names += f"ONNX Node: {overload_func.name}[opset={overload_func.opset};is_custom={symbolic_func.is_custom}]. \n"  # noqa: B950
43    return f"FX Node: {node.target}. \n" f"{all_function_overload_names}"
44
45
46def _find_operator_overloads_in_onnx_registry_disagnostic_message_formatter(
47    fn: Callable,
48    self,
49    node: torch.fx.Node,
50    *args,
51    **kwargs,
52) -> str:
53    """Format the diagnostic message for the nearest match warning."""
54    return f"Searching operator overload: '{node.target}' in onnx registry...\n"
55
56
class OnnxFunctionDispatcher:
    """A dispatcher that finds the best ONNX Function for ATen/Custom operators.

    It uses the `torch.ops` name to find the registered functions. If the exact
    overload is not registered, it falls back to the default overload. The best
    match is then picked among all function overloads. An exact match has higher
    precedence over the closest ones.

    Below is a breakdown on how the dispatch mechanism works:

    1. Use the torch.ops name to find the function:
        a. Check if the ATen overload exists in the registry.
        b. If not, check if the default overload exists in the registry.

    2. Find the nearest match among all overloaded functions:
        a. If the types match perfectly, select the function.
        b. Otherwise, find the nearest one with the highest matching score. Because of
            the potential wrongly annotated dtypes and attributes matching, we use
            nearest match to find the best function once the aten name is targeted.

    3. Tie-breaker: If several nearest matches share the highest matching score,
        custom functions are preferred over default ones, and among those the one
        appearing last in the overload list is selected.

    NOTE: The nearest match `doesn't guarantee` a correct match, and a warning message is logged.
    """

    def __init__(
        self,
        onnx_registry: OnnxRegistry,
        diagnostic_context: diagnostics.DiagnosticContext,
    ):
        """Initialize the ONNX Function dispatcher.

        Args:
            onnx_registry: The ONNX registry.
            diagnostic_context: The diagnostic context to use for reporting errors.
        """
        self.onnx_registry = onnx_registry
        self.diagnostic_context = diagnostic_context

    @staticmethod
    def _log_unsupported_node_error(
        node: torch.fx.Node,
        message: str,
        diagnostic_context: diagnostics.DiagnosticContext,
    ) -> diagnostics.RuntimeErrorWithDiagnostic:
        """Log an ERROR-level unsupported-node diagnostic and build its exception.

        The exception is returned instead of raised so that every call site keeps
        an explicit `raise` statement.

        Args:
            node: The FX node that cannot be dispatched.
            message: The human readable error message.
            diagnostic_context: The diagnostic context the diagnostic is logged to.

        Returns:
            The exception wrapping the logged diagnostic.
        """
        diagnostic = diagnostics.UnsupportedFxNodeDiagnostic(
            diagnostics.rules.no_symbolic_function_for_call_function,
            diagnostics.levels.ERROR,
            message,
            unsupported_fx_node=node,
        )
        diagnostic_context.log(diagnostic)
        return diagnostics.RuntimeErrorWithDiagnostic(diagnostic)

    def dispatch(
        self,
        node: torch.fx.Node,
        onnx_args: Sequence[
            fx_type_utils.TensorLike | str | int | float | bool | list | complex | None
        ],
        onnx_kwargs: dict[str, fx_type_utils.Argument],
        diagnostic_context: diagnostics.DiagnosticContext,
    ) -> onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction:
        """Dispatches an ONNX function based on the given FX node, arguments, and keyword arguments.

        Args:
            node: The TorchFX node to dispatch the function for.
            onnx_args: The arguments of the ONNX function.
            onnx_kwargs: The keyword arguments of the ONNX function.
            diagnostic_context: The diagnostic context to use for reporting errors.

        Returns:
            Either an `onnxscript.OnnxFunction` or `onnxscript.TracedOnnxFunction` instance based on the dispatch algorithm.

        Raises:
            RuntimeError: If there are no overloaded functions available for the given FX node.
        """
        # Raises an unsupported error when no overloaded function is registered
        # for the given FX node.
        default_and_custom_functions = self.get_function_overloads(
            node, diagnostic_context
        )

        # Among the registered overloads, pick the one that perfectly or most
        # nearly matches the given arguments and keyword arguments.
        return self._find_the_perfect_or_nearest_match_onnxfunction(
            node,
            default_and_custom_functions,
            onnx_args,
            onnx_kwargs,
            diagnostic_context,
        )

    def _filter_or_keep_complex(
        self,
        node,
        default_and_custom_functions: list[registration.ONNXFunction],
        diagnostic_context: diagnostics.DiagnosticContext,
    ) -> list[registration.ONNXFunction]:
        """Keep only the functions whose `is_complex` flag fits the node arguments.

        When any positional argument of `node` has a complex dtype, only complex
        functions are kept. Otherwise only non-complex functions are kept.

        Args:
            node: The FX node whose `args` are inspected.
            default_and_custom_functions: The candidate functions to filter.
            diagnostic_context: The diagnostic context to use for reporting errors.

        Returns:
            The non-empty filtered list of functions.

        Raises:
            RuntimeError: If no function is left after filtering.
        """
        expects_complex = any(_is_arg_with_complex_dtype(arg) for arg in node.args)
        matching_functions = [
            func
            for func in default_and_custom_functions
            if bool(func.is_complex) == expects_complex
        ]
        if matching_functions:
            return matching_functions

        # Nothing is left: either the complex group (complex inputs) or the
        # non-complex group (real inputs) is missing from the registry.
        op_full_name = self._get_aten_name(node, diagnostic_context).qualified_name()
        if expects_complex:
            message = (
                f"Cannot find any COMPLEX symbolic function for {op_full_name}, "
                f"which should be registered under {node.target}."
            )
        else:
            message = (
                f"Can ONLY find COMPLEX symbolic function for {op_full_name}, "
                f"which should be registered under {node.target}."
            )
        raise self._log_unsupported_node_error(node, message, diagnostic_context)

    @diagnostics.diagnose_call(
        diagnostics.rules.find_opschema_matched_symbolic_function,
        diagnostic_message_formatter=_find_opschema_matched_symbolic_function_disagnostic_message_formatter,
    )
    def _find_the_perfect_or_nearest_match_onnxfunction(
        self,
        node: torch.fx.Node,  # this is used in diagnostic_message_formatter
        default_and_custom_functions: list[registration.ONNXFunction],
        onnx_args: Sequence[
            fx_type_utils.TensorLike | str | int | float | bool | list | complex | None
        ],
        onnx_kwargs: dict[str, fx_type_utils.Argument],
        diagnostic_context: diagnostics.DiagnosticContext,
    ) -> onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction:
        """Find the perfect/nearest matched OnnxFunction for the given FX node, arguments, and keyword arguments.

        Args:
            node: The TorchFX node being dispatched. Also used by the diagnostic
                message formatter and in error messages.
            default_and_custom_functions: The list includes overloaded functions, with
                custom ones appearing after the default ones.
            onnx_args: Arguments organized in PyTorch inputs way.
            onnx_kwargs: Keyword arguments organized in PyTorch inputs way.
            diagnostic_context: The diagnostic context to use for reporting errors.

        Returns:
            Either an `onnxscript.OnnxFunction` or `onnxscript.TracedOnnxFunction` instance based on the dispatch algorithm.

        Raises:
            RuntimeError: If none of the overloaded functions is a perfect or a
                nearest match for the given FX node.
        """
        overload_match_ranking: dict[registration.ONNXFunction, int | None] = {}
        diagnostic = diagnostic_context.inflight_diagnostic()

        # Iterate the overloaded functions in reverse order to prioritize the custom ones
        # over the default ones, and find the perfect match.
        for symbolic_function in reversed(default_and_custom_functions):
            function_opschema = _OnnxSchemaChecker(symbolic_function.onnx_function)

            # NOTE: 1. If the perfect match is found, return the function
            if function_opschema.perfect_match_inputs(
                diagnostic, onnx_args, onnx_kwargs
            ):
                return symbolic_function.onnx_function
            # Record the match score for the nearest match if it's not the perfect match
            overload_match_ranking[symbolic_function] = function_opschema.match_score

        # NOTE: 2. If there is no perfect match, find the nearest match among the
        # nearest match candidates. A score of None means the function is not a
        # candidate at all.
        nearest_match_candidates = {
            func: score
            for func, score in overload_match_ranking.items()
            if score is not None
        }
        if not nearest_match_candidates:
            # No overload is even eligible for a nearest match: unsupported node.
            op_full_name = self._get_aten_name(
                node, diagnostic_context
            ).qualified_name()
            raise self._log_unsupported_node_error(
                node,
                f"Cannot find any perfect/nearest match of symbolic function for {op_full_name}, "
                f"which should be registered under {node.target}.",
                diagnostic_context,
            )

        diagnostic.warning(
            "### Exact match is not found!\n"
            "Cannot find a perfect match of symbolic overload, "
            "a nearest match is found. Please check the ONNX output carefully. \n",
        )
        diagnostic.level = diagnostics.levels.WARNING
        # NOTE: 3. Tie breaker: if there are multiple nearest matches, we will choose the one
        # that is custom first. If there are multiple custom ones, we will choose the one
        # that is added lastly in the list.
        best_match = max(
            nearest_match_candidates,
            key=lambda func: (
                nearest_match_candidates[func],
                func.is_custom,
                default_and_custom_functions.index(func),
            ),
        )
        return best_match.onnx_function

    def _get_aten_name(
        self, node: torch.fx.Node, diagnostic_context: diagnostics.DiagnosticContext
    ) -> registration.OpName:
        """Get the OpName from the target.

        Args:
            node: The TorchFX node to get the aten name for.
            diagnostic_context: The diagnostic context to use for reporting errors.

        Returns:
            The internal op name within dataclass: registration.OpName.

        Raises:
            RuntimeError: If the node target is an unsupported OverloadPacket, a
                builtin function with unsupported arguments, or an unknown target.
        """
        target = node.target

        if target == operator.getitem:
            return registration.OpName.from_name_parts(
                namespace="aten", op_name="getitem"
            )

        if isinstance(target, torch._ops.OpOverloadPacket):
            # aten::sym_size is the only OverloadPacket that we support.
            # schema: aten::sym_size(Tensor self, int dim) -> Tensor
            if target != torch.ops.aten.sym_size:
                raise self._log_unsupported_node_error(
                    node,
                    f"Unsupported OverloadPacket: {node.target}, aten.sym_size is the only allowed OverloadPacket!",
                    diagnostic_context,
                )
            # TODO(titaiwang): aten::sym_size has overload, but fx graph is using
            # overloadpacket for some reasons.
            # https://github.com/pytorch/pytorch/issues/97201
            aten_op_default = target.default
            return registration.OpName.from_op_overload(op_overload=aten_op_default)  # type: ignore[no-any-return]

        if isinstance(target, types.BuiltinFunctionType):
            # Make sure it's symint/symfloat consuming builtin ops.
            for node_arg in node.args:
                if isinstance(node_arg, torch.fx.Node):
                    # A node argument must carry a symbolic value in its meta.
                    is_supported = fx_type_utils.is_torch_symbolic_type(
                        node_arg.meta["val"]
                    )
                else:
                    is_supported = isinstance(node_arg, (int, float))
                if not is_supported:
                    raise self._log_unsupported_node_error(
                        node,
                        f"Unsupported node arg: {node_arg} (type {type(node_arg)}) with builtin function: {node.target},"
                        " only int/float/SymInt/SymFloat is supported with built-in ops!",
                        diagnostic_context,
                    )
            return registration.OpName.from_builtin_function(target)

        if isinstance(target, torch._ops.OpOverload):
            return registration.OpName.from_op_overload(op_overload=target)

        # Unexpected target, raise error.
        raise self._log_unsupported_node_error(
            node,
            f"Unknown call_function target: {node.target}",
            diagnostic_context,
        )

    @diagnostics.diagnose_call(
        diagnostics.rules.find_operator_overloads_in_onnx_registry,
        diagnostic_message_formatter=_find_operator_overloads_in_onnx_registry_disagnostic_message_formatter,
    )
    def get_function_overloads(
        self,
        node: torch.fx.Node,
        diagnostic_context: diagnostics.DiagnosticContext,
    ) -> list[registration.ONNXFunction]:
        """Get the function overloads from the registry.

        Args:
            node: The node to get the function overloads for.
            diagnostic_context: The diagnostic context to use for reporting errors.

        Returns:
            The list contains ONNXFunctions, starting with the default ones and
            followed by any custom ones.

        Raises:
            RuntimeError: If neither the exact overload nor the default overload is
                registered, or if no function is left after the complex filtering.
        """
        internal_opname: registration.OpName = self._get_aten_name(
            node=node, diagnostic_context=diagnostic_context
        )

        # If the ATen/Custom operators are not registered, the group will be None.
        function_group: list[registration.ONNXFunction] | None = (
            self.onnx_registry.get_op_functions(
                namespace=internal_opname.namespace,
                op_name=internal_opname.op_name,
                overload=internal_opname.overload,
            )
        )

        # NOTE: Fall back to default overload if the ONNX registry doesn't have the overload.
        if function_group is None:
            function_group = self.onnx_registry.get_op_functions(
                namespace=internal_opname.namespace,
                op_name=internal_opname.op_name,
                overload=None,
            )
            if function_group is not None:
                diagnostic = diagnostic_context.inflight_diagnostic()
                diagnostic.warning(
                    "### The operator overload is not found in onnx registry!\n"
                    "Cannot find the operator overload in onnx registry, but "
                    "the default overload is found. Please check the ONNX output carefully. \n",
                )
                diagnostic.level = diagnostics.levels.WARNING

        if function_group is not None:
            # NOTE: If the input has complex dtype, we will only dispatch to the complex functions.
            return self._filter_or_keep_complex(
                node, function_group, diagnostic_context
            )

        # Neither the exact nor the default overload is registered.
        op_full_name = internal_opname.qualified_name()
        raise self._log_unsupported_node_error(
            node,
            f"Cannot find symbolic function for {op_full_name}, "
            f"which should be registered under {node.target}.",
            diagnostic_context,
        )
397
398
399class _OnnxSchemaChecker:
400    """
401    The OnnxSchemaChecker class is a checker for ONNX OpSchema and param schema.
402
403    It provides methods to check for input compatibility based on the OpSchema. It also
404    provides a matching score to indicate how well the OpSchema matches the input and
405    kwargs types. A function will be evaluated as perfect match, nearest match eligible,
406    or no match.
407
408    Here are some common examples in categories:
409
410    1. [NOTE: Perfect match]: The number of inputs and attributes are exactly the same as
411        the OpSchema. The types of inputs and attributes are exactly the same as the
412        OpSchema.
413
414        ```python
415        inputs = (Tensor[2, 3], Tensor[2, 3])
416        attributes = {"alpha": 1.0}
417
418
419        @torch_op("aten::op")
420        def aten_op(self: TReal, other: TReal, alpha: float = 1) -> TReal: ...
421        ```
422        Result: Perfect match.
423
424    2. [NOTE: Optional input]: The dispatcher recognizes optional inputs. However,
425        the input can't be ignored. None must be provided.
426
427        ```python
428        inputs = (Tensor([2, 3]), None)
429        attributes = {}
430
431        aten_op(X: TTensor, Y: Optional[INT64]):
432            ...
433        ```
434        Result: Perfect match.
435        Real example: `aten::convolution`.
436
437    3. [NOTE: Different attributes]: If an attribute is provided with value, it's
438        a must to match the attribute in function signature.
439        ```python
440        inputs = (Tensor([2, 3]),)
441        attributes = {"a":1, "b":2}
442
443        aten_op(X: TTensor, a: int):
444            ...
445        ```
446        Result: No match.
447        Real example: `aten::div` vs `aten::div.Tensor_mode`.
448
449    4. [NOTE: Default attributes]: Default attribute will fill in the value into
450        inputs/attributes.
451        ```python
452        inputs = (Tensor([2, 3]),)
453        attributes = {}
454
455        aten_op(X: TTensor, a: int = 3):
456            ...
457        ```
458        Result: Perfect match.
459        Real example: `aten::clone`
460
461    5. [NOTE: Ignore attribute with None value]: The attributes with None value
462        will be ignored in matching.
463        ```python
464        inputs = (Tensor([2, 3]),)
465        attributes = {"a": None}
466
467        aten_op(X: TTensor):
468            ...
469        ```
470        Result: Perfect match.
471
472        ```python
473        inputs = (Tensor([2, 3]),)
474        attributes = {"a": None}
475
476        aten_op(X: TTensor, a: int = 3):
477            ...
478        ```
479        Result: Nearest match eligible.
480
481        Real example: `aten::div` vs `aten::div.Tensor_mode`.
482
483    Attributes:
484        onnxfunction: The OnnxFunction.
485        param_schema: The parameter schema defined in the OnnxFunction.
486        op_schema: The ONNX OpSchema.
487        type_constraints: The type constraints defined in the OpSchema.
488        attributes: The attributes defined in the OpSchema.
489        _matching_score: The matching score of the OnnxSchemaChecker .
490
491    """
492
493    def __init__(
494        self,
495        onnxfunction: onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction,
496    ):
497        """Initialize the OnnxSchemaChecker .
498
499        Args:
500            onnxfunction: The OnnxFunction.
501        """
502        self.onnxfunction = onnxfunction
503        self.param_schema = self.onnxfunction.param_schemas()
504        op_schema = self.onnxfunction.op_schema
505        # Both `OnnxFunction` and `TracedOnnxFunction` never return None for `op_schema`.
506        # However their base class would. Hence return type is annotated as Optional[OpSchema].
507        assert op_schema is not None
508        self.op_schema = op_schema
509        self.type_constraints = {
510            # "T": {"tensor(int64)"}
511            constraint.type_param_str: set(constraint.allowed_type_strs)
512            for constraint in self.op_schema.type_constraints
513        }
514        self.attributes = self.op_schema.attributes
515        self._matching_score: int | None = None
516
517    @property
518    def match_score(self) -> int | None:
519        """The matching score of the OnnxSchemaChecker .
520
521        If this remains None, it means the matching score has not been calculated,
522        and it's not a nearest match candidate.
523
524        Returns:
525            The matching score of the OnnxSchemaChecker .
526        """
527        return self._matching_score
528
529    def perfect_match_inputs(
530        self,
531        diagnostic: diagnostics.Diagnostic,
532        args: Sequence[
533            fx_type_utils.TensorLike | str | int | float | bool | list | complex | None
534        ],
535        kwargs: dict[str, fx_type_utils.Argument],
536    ) -> bool:
537        """Check if the inputs perfectly match the OpSchema requirements.
538
539        The definition of perfect match is that the input types are all in the type
540        constraints and the number of inputs matches the number of inputs in the
541        OpSchema.
542
543        Checking steps:
544        1. The function signature matches the inputs number, and attribute names.
545        2. The input/attribute types are all in the type constraints.
546
547        A function should at least pass the first step to be eligible for the
548        nearest matching.
549
550        Args:
551            diagnostic: The diagnostic to use for logging detailed info.
552            args: The input arguments organized in PyTorch inputs way.
553            kwargs: The input keyword arguments organized in PyTorch inputs way.
554
555        Returns:
556            True if the inputs match the requirements, False otherwise.
557        """
558
559        # NOTE: OnnxFunction does not have the same function signature as the original
560        # PyTorch operator. We need to separate the input/attributes from the arguments.
561        (
562            function_inputs,
563            function_attributes,
564        ) = self._separate_input_attributes_from_arguments(
565            self.param_schema,
566            args,
567            kwargs,
568            fill_defaults=True,  # fill defaults for optional arguments to match
569        )
570        with diagnostic.log_section(logging.INFO, "Checking perfect match..."):
571            diagnostic.info(
572                "%s",
573                diagnostics.LazyString(diagnostics.format_argument, self.onnxfunction),
574            )
575            # NOTE: 1. Check if the input number and attribute names match the
576            # OpSchema. If it's not, we know the function is not eligible to be a perfect
577            # match, nor a nearest match.
578            # We use is_perfect_match to postpone the return value to the end
579            # of the function, as we want to log all the mismatch info.
580            is_perfect_match = True
581            if len(function_inputs) != len(self.op_schema.inputs):
582                with diagnostic.log_section(
583                    logging.INFO, "Failed: input number mismatch!"
584                ):
585                    diagnostic.info(
586                        "Actual %d vs expected %d",
587                        len(function_inputs),
588                        len(self.op_schema.inputs),
589                    )
590                diagnostic.info("The function is not a nearest match candidate.")
591                is_perfect_match = False
592
593            if set(function_attributes) != set(self.attributes):
594                with diagnostic.log_section(
595                    logging.INFO, "Failed: attribute mismatch!"
596                ):
597                    diagnostic.info(
598                        "%s",
599                        diagnostics.LazyString(
600                            lambda: f"Actual {set(function_attributes)} vs expected {set(self.attributes)}",
601                        ),
602                    )
603                diagnostic.info("The function is not a nearest match candidate.")
604                is_perfect_match = False
605
606            # If it's already not a perfect match, we can return False directly. Further
607            # checking is only for the functions that are eligible for nearest match.
608            if not is_perfect_match:
609                return False
610
611            # NOTE: 2. The dtypes of inputs and attributes should be in the
612            # type constraints of the OpSchema. If they are not, we know the function is not
613            # eligible to be a perfect match, but can be a nearest match candidate.
614            for schema_input, torch_input in zip(
615                self.op_schema.inputs, function_inputs
616            ):
617                torch_input_compatible_types = _find_onnx_data_type(torch_input)
618                allowed_types = self.type_constraints[schema_input.type_str]
619                if not allowed_types.intersection(
620                    torch_input_compatible_types
621                ) and not any(
622                    fx_type_utils.is_optional_onnx_dtype_str(onnx_type_str)
623                    for onnx_type_str in allowed_types
624                ):
625                    # If torch_input_compatible_types isn't in allowed_types
626                    # of this input defined in the OpSchema, we know the function
627                    # and the input are not compatible
628                    with diagnostic.log_section(
629                        logging.INFO,
630                        "Failed: input type mismatch for input '%s'!",
631                        schema_input.name,
632                    ):
633                        diagnostic.info(
634                            "Actual %s vs\nExpected %s",
635                            torch_input_compatible_types,
636                            allowed_types,
637                        )
638                    is_perfect_match = False
639
640            for attribute_name, attribute in function_attributes.items():
641                if not self._match_onnx_attribute_type(attribute_name, attribute):
642                    # If the attribute type of the OpSchema and the attribute type don't match,
643                    # we know the function and the input are not compatible
644                    with diagnostic.log_section(
645                        logging.INFO,
646                        "Failed: attribute '%s' type mismatch!",
647                        attribute_name,
648                    ):
649                        diagnostic.info(
650                            "Actual %s vs\nExpected %s",
651                            type(attribute),
652                            self.attributes[attribute_name].type,
653                        )
654                    is_perfect_match = False
655
656            # NOTE: This is still a candidate for nearest match, as it only mismatches attributes on dtype.
657            self._record_matching_score(function_inputs, function_attributes)
658            diagnostic.info("match score: %d", self.match_score)
659            return is_perfect_match
660
661    def _match_onnx_attribute_type(
662        self,
663        attribute_name: str,
664        attribute: fx_type_utils.Argument | onnxscript_graph_building.TorchScriptTensor,
665        is_sequence: bool = False,
666    ) -> bool:
667        if isinstance(attribute, (int, float, bool, str)):
668            attribute_onnx_type = fx_type_utils.from_python_type_to_onnx_attribute_type(
669                type(attribute), is_sequence=is_sequence
670            )
671            if attribute_onnx_type != self.attributes[attribute_name].type:
672                return False
673        # If the attribute is an empty list, we don't know the type of the list
674        # so it's a mismatch
675        elif isinstance(attribute, (list, tuple)) and attribute:
676            return self._match_onnx_attribute_type(
677                attribute_name, attribute[0], is_sequence=True
678            )
679        else:
680            # NOTE: Unrecognized attribute type
681            return False
682        return True
683
684    def _record_matching_score(
685        self,
686        inputs: Sequence[
687            fx_type_utils.TensorLike | str | int | float | bool | list | complex | None
688        ],
689        attributes: dict[str, fx_type_utils.Argument],
690    ):
691        """Calculate the inputs matching score of the OpSchema requirements to find the nearest match.
692
693        Only the functions which have the same number of inputs and attributes as the
694        OpSchema are eligible to be a nearest match candidate. Thus, we don't need to
695        check the length of inputs and attributes here, and only check the types of
696        inputs and attributes.
697
698        How the matchsing score is calculated:
699            score += 1 if one input/attribute type is in the type constraints.
700
701        Limitations:
702            None/NoeType/[] could result in zero matches, and the same score of overloads,
703            which will be recorded in SARIF.
704
705        Args:
706            inputs: The input arguments.
707            attributes: The input keyword arguments.
708
709        Returns:
710            True if the inputs match the requirements, False otherwise.
711        """
712        self._matching_score = 0
713        # If they have different length of arguments, the score would be lower to those
714        # functions which have the same length of arguments.
715        for schema_input, torch_input in zip(self.op_schema.inputs, inputs):
716            torch_input_compatible_types = _find_onnx_data_type(torch_input)
717            allowed_types = self.type_constraints[schema_input.type_str]
718            if allowed_types.intersection(torch_input_compatible_types):
719                # If torch_input_compatible_types is in allowed_types
720                # of this input defined in the OpSchema, we know the function
721                # and the input are compatible
722                self._matching_score += 1
723        # NOTE: The penalty is applied to those functions which have different attributes.
724        for attribute_name, attribute_proto in self.attributes.items():
725            attribute = attributes[attribute_name]
726            attribute_onnx_type = fx_type_utils.from_python_type_to_onnx_attribute_type(
727                type(attribute)
728            )
729            if attribute_onnx_type != attribute_proto.type:
730                # If the attribute type of the OpSchema and the attribute type don't match,
731                # we know the function and the input are not compatible
732                self._matching_score -= 1
733
734    # NOTE: Referenced from onnxscript internal function.
735    # Importing this function makes the code less robust, as it is not a public API.
736
737    def _separate_input_attributes_from_arguments(
738        self,
739        param_schemas: Sequence[onnxscript.values.ParamSchema],
740        args: Sequence[
741            fx_type_utils.TensorLike | str | int | float | bool | list | complex | None
742        ],
743        kwargs: dict[str, fx_type_utils.Argument],
744        fill_defaults: bool = True,
745    ) -> tuple[list[Any], dict[str, Any]]:
746        """Separate Python args and kwargs into ONNX inputs and attributes.
747
748        Extra_kwargs are ignored if their values are None. For example, if the
749        OpSchema has an attribute "rounding_mode" and the caller provides
750        "rounding_mode=None", the attribute "rounding_mode" will not be included
751        in the returned attributes when the OnnxFunction signature doesn't have
752        "rounding_mode" as an attribute.
753
754        Args:
755            param_schemas: The parameter schemas of an Op or a OnnxFunction.
756            args: The Python positional arguments supplied by the caller.
757            kwargs: The Python keyword arguments supplied by the caller.
758            fill_defaults: Whether to fill the default values for attributes.
759
760        Returns:
761            A tuple of two elements:
762            - A list of ONNX inputs.
763            - An dictionary of ONNX attribute names and values.
764
765        Raises:
766            TypeError: When allow_extra_kwargs is False and there are unknown kwargs.
767            TypeError: When a required input is not provided.
768        """
769        # args, kwargs and param_schemas should be all in order
770        # user may not specify all inputs or attributes
771
772        import onnx
773
774        onnx_inputs: list[Any] = []
775        onnx_attributes: dict[str, Any] = {}
776        # NOTE: We need to copy kwargs because we will mutate it
777        copy_kwargs = kwargs.copy()
778        for i, param in enumerate(param_schemas):
779            if param.is_variadic_input:
780                # Exhaust all remaining args
781                onnx_inputs.extend(args[i:])
782                args = []
783                continue
784            if i < len(args):
785                if param.is_input:
786                    onnx_inputs.append(args[i])
787                else:
788                    onnx_attributes[param.name] = args[i]
789            elif param.name in copy_kwargs:
790                if param.is_input:
791                    # Move the input from kwargs to inputs
792                    onnx_inputs.append(copy_kwargs[param.name])
793                    copy_kwargs.pop(param.name)
794                else:
795                    onnx_attributes[param.name] = copy_kwargs[param.name]
796            elif (
797                param.is_attribute
798                and self.attributes[param.name].default_value.type
799                != onnx.AttributeProto.UNDEFINED  # type: ignore[attr-defined]
800            ):
801                # User did not provide the attribute
802                if fill_defaults:
803                    onnx_attributes[param.name] = param.default
804            # optional input
805            elif param.is_input:
806                if fill_defaults:
807                    onnx_inputs.append(None)
808
809        # NOTE: Pick up extra kwargs if it's not None. None is not expected
810        # as an attribute value in torchlib.
811        for k, v in copy_kwargs.items():
812            if k not in onnx_attributes and v is not None:
813                onnx_attributes[k] = v
814        return onnx_inputs, onnx_attributes
815
816
def _is_arg_with_complex_dtype(arg: fx_type_utils.Argument) -> bool:
    """Check recursively whether an FX argument carries a complex dtype.

    Args:
        arg: An FX argument. Only ``torch.fx.Node`` objects and (possibly
            nested) lists of arguments are inspected.

    Returns:
        True if ``arg`` is a node whose ``meta["val"]`` is a complex
        ``torch.Tensor``, or a list in which any element (at any nesting
        depth) satisfies that condition. False otherwise, including for empty
        lists and nodes without a tensor ``"val"`` entry.
    """
    if (
        isinstance(arg, torch.fx.Node)
        and "val" in arg.meta
        and isinstance(arg.meta["val"], torch.Tensor)
        and torch.is_complex(arg.meta["val"])
    ):
        return True
    if isinstance(arg, list):
        # Every element has to be examined: returning from inside the loop
        # would only ever look at the first one and miss a complex tensor
        # that appears later in the list.
        return any(_is_arg_with_complex_dtype(item) for item in arg)
    return False
830
831
def _find_onnx_data_type(
    torch_input: fx_type_utils.TensorLike
    | str
    | int
    | float
    | bool
    | list
    | tuple
    | complex
    | None,
) -> set[str]:
    """Map a torch-side argument to the set of ONNX dtype strings it can take.

    Args:
        torch_input: A tensor-like object, a Python scalar, a list/tuple of
            those, or None.

    Returns:
        - For a tensor-like with a dtype, or a Python scalar: the result of
          ``fx_type_utils.from_torch_dtype_to_onnx_dtype_str`` for that dtype
          or Python type.
        - For a non-empty list/tuple: the dtype strings of its first non-None
          element, each wrapped as ``seq(...)`` when any element is
          tensor-like, and unwrapped otherwise.
        - For None, a tensor-like without a dtype, or an empty list/tuple: an
          empty set, which relaxes the type check for these edge cases.

    Raises:
        RuntimeError: If ``torch_input`` is none of the kinds listed above.
    """
    is_tensor_like = isinstance(torch_input, fx_type_utils.TensorLike)
    if is_tensor_like and torch_input.dtype is not None:
        return fx_type_utils.from_torch_dtype_to_onnx_dtype_str(torch_input.dtype)

    if isinstance(torch_input, (int, float, bool, str, complex)):
        return fx_type_utils.from_torch_dtype_to_onnx_dtype_str(type(torch_input))

    is_sequence = isinstance(torch_input, (list, tuple))
    if is_sequence and torch_input:  # [Tensor, Tensor]
        # The first non-None element stands in for the whole sequence.
        representative = next(
            (element for element in torch_input if element is not None), None
        )
        element_dtypes = _find_onnx_data_type(representative)
        contains_tensor = any(
            isinstance(element, fx_type_utils.TensorLike) for element in torch_input
        )
        if not contains_tensor:
            # constant list of non-tensor type
            return element_dtypes
        # NOTE: Any Tensor involved in a list would make it a seq(tensor(onnx_type))
        return {f"seq({dtype})" for dtype in element_dtypes}

    if (
        torch_input is None
        or (is_tensor_like and torch_input.dtype is None)
        or (is_sequence and not torch_input)
    ):
        # NOTE: None, No dtype, and empty list are edge cases, we allow it to be any type to relax the type check
        # seq(tensor) also goes to here, as it is not supported in torchscript, and it would be None in this case.
        return set()

    raise RuntimeError(f"Unknown input type from input: {torch_input}")
875