"""Module for handling ATen to ONNX functions registration."""

from __future__ import annotations

import dataclasses
from typing import TYPE_CHECKING


# We can only import onnx from this module in a type-checking context to ensure that
# 'import torch.onnx' continues to work without having 'onnx' installed. We fully
# 'import onnx' inside of dynamo_export (by way of _assert_dependencies).
if TYPE_CHECKING:
    import types

    import onnxscript  # type: ignore[import]

    import torch._ops


@dataclasses.dataclass(frozen=True, eq=True)
class ONNXFunction:
    """A wrapper of an onnx-script function.

    Attributes:
        onnx_function: The onnx-script function from torchlib.
        op_full_name: The qualified name of the function, in the form
            '<namespace>::<op_name>.<overload>'.
        is_custom: Whether the function is a custom function.
        is_complex: Whether the function handles complex valued inputs.

    Instances are immutable (``frozen=True``) and compare by field values;
    frozen + eq dataclasses also get a generated ``__hash__``.
    """

    onnx_function: onnxscript.OnnxFunction | onnxscript.TracedOnnxFunction
    op_full_name: str
    is_custom: bool = False
    is_complex: bool = False


@dataclasses.dataclass(frozen=True, eq=True)
class OpName:
    """A class representing an operator name in the internal ONNX converter.

    Attributes:
        namespace: The operator namespace, e.g. ``aten``.
        op_name: The operator name without namespace or overload, e.g. ``add``.
        overload: The overload name. The ``from_*`` constructors normalize a
            missing or empty overload to ``"default"``.
    """

    namespace: str
    op_name: str
    overload: str

    @classmethod
    def from_name_parts(
        cls, namespace: str, op_name: str, overload: str | None = None
    ) -> OpName:
        """Build an OpName from its individual parts.

        Args:
            namespace: The operator namespace, e.g. ``aten``.
            op_name: The bare operator name, e.g. ``add``.
            overload: The overload name. ``None`` or ``""`` maps to ``"default"``.

        Returns:
            The corresponding OpName.
        """
        # NOTE: in PyTorch, the overload could be unprovided to indicate the
        # default overload
        if not overload:
            overload = "default"
        return cls(namespace, op_name, overload)

    @classmethod
    def from_qualified_name(cls, qualified_name: str) -> OpName:
        """Parse a name of the form ``<namespace>::<op_name>[.<overload>]``.

        Only the first ``.`` after ``::`` separates the op name from the
        overload. A missing or empty overload becomes ``"default"``.

        Args:
            qualified_name: The fully qualified operator name.

        Returns:
            The corresponding OpName.

        Raises:
            ValueError: If ``qualified_name`` does not contain exactly one ``::``.
        """
        parts = qualified_name.split("::")
        if len(parts) != 2:
            raise ValueError(
                "Expected a qualified name of the form "
                f"'<namespace>::<op_name>[.<overload>]', got {qualified_name!r}."
            )
        namespace, opname_overload = parts
        op_name, _, overload = opname_overload.partition(".")
        # Delegate so that both a missing overload ("aten::add") and an empty
        # one ("aten::add.") are normalized to "default" in one place.
        return cls.from_name_parts(namespace, op_name, overload)

    @classmethod
    def from_op_overload(cls, op_overload: torch._ops.OpOverload) -> OpName:
        """Build an OpName by parsing ``op_overload.name()`` as a qualified name."""
        return cls.from_qualified_name(op_overload.name())

    @classmethod
    def from_builtin_function(
        cls, builtin_function: types.BuiltinFunctionType
    ) -> OpName:
        """From a builtin function, e.g. operator.add, math.ceil, etc, get the OpName.

        The FX graph uses built-in functions to calculate sympy expressions. This
        function is used to get the OpName from such a builtin function.

        Args:
            builtin_function (types.BuiltinFunctionType): operator.add, math.ceil, etc.

        Returns:
            OpName: An OpName whose namespace is the function's ``__module__``,
            whose op_name is its ``__name__``, and whose overload is ``"default"``.
        """
        op = builtin_function.__name__  # add, sub, etc.
        module = builtin_function.__module__  # e.g. _operator or math
        return cls.from_qualified_name(module + "::" + op)

    def qualified_name(self) -> str:
        """Return ``<namespace>::<op_name>.<overload>``; the overload is always included."""
        return f"{self.namespace}::{self.op_name}.{self.overload}"