• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Module for handling ATen to ONNX functions registration."""
2
3from __future__ import annotations
4
5import dataclasses
6from typing import TYPE_CHECKING
7
8
9# We can only import onnx from this module in a type-checking context to ensure that
10# 'import torch.onnx' continues to work without having 'onnx' installed. We fully
11# 'import onnx' inside of dynamo_export (by way of _assert_dependencies).
12if TYPE_CHECKING:
13    import types
14
15    import onnxscript  # type: ignore[import]
16
17    import torch._ops
18
19
@dataclasses.dataclass(frozen=True)
class ONNXFunction:
    """Immutable record pairing a torchlib onnx-script function with its operator.

    Attributes:
        onnx_function: The onnx-script function (an ``OnnxFunction`` or a
            ``TracedOnnxFunction``) taken from torchlib.
        op_full_name: Fully qualified operator name, formatted as
            ``'<namespace>::<op_name>.<overload>'``.
        is_custom: Whether this is a custom function. Defaults to ``False``.
        is_complex: Whether this function handles complex-valued inputs.
            Defaults to ``False``.
    """

    # Declaration order fixes the positional constructor signature:
    # ONNXFunction(onnx_function, op_full_name, is_custom, is_complex).
    onnx_function: onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction
    op_full_name: str
    is_custom: bool = dataclasses.field(default=False)
    is_complex: bool = dataclasses.field(default=False)
35
36
@dataclasses.dataclass(frozen=True, eq=True)
class OpName:
    """A class representing an operator name in internal ONNX converter.

    An operator is identified by ``<namespace>::<op_name>.<overload>``. The
    ``from_*`` alternate constructors normalize a missing or empty overload
    to ``"default"``, so they always produce a non-empty ``overload``.
    """

    namespace: str
    op_name: str
    overload: str

    @classmethod
    def from_name_parts(
        cls, namespace: str, op_name: str, overload: str | None = None
    ) -> OpName:
        """Build an OpName from its separate components.

        Args:
            namespace: Operator namespace, e.g. ``"aten"``.
            op_name: Operator name, e.g. ``"add"``.
            overload: Overload name. ``None`` or ``""`` is stored as ``"default"``.

        Returns:
            OpName: The assembled operator name.
        """
        # NOTE: in PyTorch, the overload could be unprovided to indicate the
        # default overload
        if not overload:
            overload = "default"
        return cls(namespace, op_name, overload)

    @classmethod
    def from_qualified_name(cls, qualified_name: str) -> OpName:
        """Parse a name of the form ``<namespace>::<op_name>[.<overload>]``.

        Only the first ``.`` separates the op name from the overload; any later
        dots stay part of the overload. A missing or empty overload (e.g.
        ``"aten::add"`` or ``"aten::add."``) becomes ``"default"``.

        Args:
            qualified_name: e.g. ``"aten::add.Tensor"`` or ``"aten::add"``.

        Returns:
            OpName: The parsed operator name.

        Raises:
            ValueError: If ``qualified_name`` does not contain exactly one ``"::"``.
        """
        parts = qualified_name.split("::")
        if len(parts) != 2:
            raise ValueError(
                "Expected a qualified name of the form "
                f"'<namespace>::<op_name>[.<overload>]', got {qualified_name!r}."
            )
        namespace, opname_overload = parts
        # partition() yields "" for the overload when there is no '.' or a
        # trailing '.'; from_name_parts maps "" to "default" so both entry
        # points normalize the same way.
        op_name, _, overload = opname_overload.partition(".")
        return cls.from_name_parts(namespace, op_name, overload)

    @classmethod
    def from_op_overload(cls, op_overload: torch._ops.OpOverload) -> OpName:
        """Build an OpName from an operator overload.

        Args:
            op_overload: The operator overload; the string returned by its
                ``name()`` method is parsed with ``from_qualified_name``.

        Returns:
            OpName: The parsed operator name.
        """
        return cls.from_qualified_name(op_overload.name())

    @classmethod
    def from_builtin_function(
        cls, builtin_function: types.BuiltinFunctionType
    ) -> OpName:
        """From a builtin function, e.g. operator.add, math.ceil, etc, get the OpName.

        FX graph uses built-in functions to calculate sympy expressions. This
        function is used to get the OpName from a builtin function: the
        function's ``__module__`` becomes the namespace and its ``__name__``
        becomes the op name. Since builtin names normally contain no ``.``,
        the overload is normally ``"default"``.

        Args:
            builtin_function (types.BuiltinFunctionType): operator.add, math.ceil, etc.

        Returns:
            OpName: e.g. ``OpName("_operator", "add", "default")`` for
            ``operator.add`` and ``OpName("math", "ceil", "default")`` for
            ``math.ceil``.
        """
        op = builtin_function.__name__  # add, sub, etc.
        module = builtin_function.__module__  # _operator or math
        return cls.from_qualified_name(module + "::" + op)

    def qualified_name(self) -> str:
        """Return the name as ``<namespace>::<op_name>.<overload>``.

        The overload is always included, even when it is ``"default"``.
        """
        return f"{self.namespace}::{self.op_name}.{self.overload}"
88