# mypy: allow-untyped-defs
import inspect
from typing import Dict, List, Tuple, Union

from torch import _C
from torch.onnx import _constants
from torch.onnx._internal import registration


class _TorchSchema:
    """Lightweight description of an operator schema.

    An instance is built either from a TorchScript ``_C.FunctionSchema``
    (name, overload name, argument names and return names are copied from it)
    or from a bare operator name string, in which case every list starts empty
    and is filled in by the caller (``all_symbolics_schemas`` does this).

    Equality and hashing currently consider only ``name``; overloads are not
    distinguished (see the TODOs below).
    """

    def __init__(self, schema: Union[_C.FunctionSchema, str]) -> None:
        if isinstance(schema, _C.FunctionSchema):
            self.name: str = schema.name
            self.overload_name: str = schema.overload_name
            self.arguments: List[str] = [arg.name for arg in schema.arguments]
            self.optional_arguments: List[str] = []
            self.returns: List[str] = [ret.name for ret in schema.returns]
            self.opsets: List[int] = []
        else:
            self.name = schema
            self.overload_name = ""
            self.arguments = []
            self.optional_arguments = []
            self.returns = []
            self.opsets = []

    def __str__(self) -> str:
        # Only emit the "." separator when there is an overload name, so a
        # schema without one renders as "aten::foo(...)" rather than
        # "aten::foo.(...)".
        qualified_name = (
            f"{self.name}.{self.overload_name}" if self.overload_name else self.name
        )
        args = ", ".join(self.arguments)
        rets = ", ".join(self.returns)
        opsets = ", ".join(str(opset) for opset in self.opsets)
        return f"{qualified_name}({args}) -> ({rets}) in opsets {opsets}"

    def __hash__(self):
        # TODO(thiagocrepaldi): handle overload_name?
        return hash(self.name)

    def __eq__(self, other) -> bool:
        if not isinstance(other, _TorchSchema):
            # Let Python try the reflected comparison / fall back to identity.
            return NotImplemented
        # TODO(thiagocrepaldi): handle overload_name?
        return self.name == other.name

    def is_aten(self) -> bool:
        """Return True if the operator lives in the ``aten::`` namespace."""
        return self.name.startswith("aten::")

    def is_backward(self) -> bool:
        """Return True if the operator name contains ``"backward"``."""
        return "backward" in self.name


def _symbolic_arguments(func) -> Tuple[List[str], List[str]]:
    """Split the parameters of a symbolic function into required and optional.

    Parameters named ``g`` and ``_outputs`` are skipped.

    Returns:
        A ``(required, optional)`` pair. ``required`` holds the bare names of
        parameters that have no default and are not variadic, matching the
        name-only format that ``_TorchSchema`` uses for ``_C.FunctionSchema``
        arguments. ``optional`` holds ``str(inspect.Parameter)`` for the rest
        (e.g. ``"dim=None"``, ``"*args"``), so the default or variadic marker
        stays visible.
    """
    required: List[str] = []
    optional: List[str] = []
    variadic_kinds = (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    )
    for name, parameter in inspect.signature(func).parameters.items():
        if name in {"_outputs", "g"}:
            continue
        # A parameter is required only when it has no default; *args/**kwargs
        # also have no default but accept zero values, so they are optional.
        if parameter.default is parameter.empty and parameter.kind not in variadic_kinds:
            required.append(name)
        else:
            optional.append(str(parameter))
    return required, optional


def _symbolic_argument_count(func) -> List[str]:
    """Return the names of the required arguments of a symbolic function.

    Despite the name, this returns a list of names rather than a count.
    """
    required, _ = _symbolic_arguments(func)
    return required


def all_forward_schemas() -> Dict[str, _TorchSchema]:
    """Returns schemas for all TorchScript forward ops.

    Ops whose name contains ``"backward"`` are excluded. The result is keyed by
    operator name, so when several overloads share a name only the last one
    returned by ``_C._jit_get_all_schemas()`` is kept.
    """
    torch_schemas = [_TorchSchema(s) for s in _C._jit_get_all_schemas()]
    return {schema.name: schema for schema in torch_schemas if not schema.is_backward()}


def all_symbolics_schemas() -> Dict[str, _TorchSchema]:
    """Returns schemas for all onnx supported ops.

    For every function name in the symbolic registry, a ``_TorchSchema`` is
    built whose ``arguments`` / ``optional_arguments`` come from the signature
    of the symbolic function and whose ``opsets`` list the opset range it is
    reported for.
    """
    symbolics_schemas: Dict[str, _TorchSchema] = {}

    for name in registration.registry.all_functions():
        func_group = registration.registry.get_function_group(name)
        assert func_group is not None
        symbolics_schema = _TorchSchema(name)

        func = func_group.get(_constants.ONNX_MAX_OPSET)
        if func is not None:
            opsets = list(
                range(func_group.get_min_supported(), _constants.ONNX_MAX_OPSET + 1)
            )
        else:
            # Only support opset < 9
            func = func_group.get(7)
            # NOTE(review): if no legacy implementation is found either, report
            # no opsets instead of crashing in inspect.signature(None).
            opsets = list(range(7, _constants.ONNX_BASE_OPSET)) if func is not None else []

        if func is not None:
            required, optional = _symbolic_arguments(func)
            symbolics_schema.arguments = required
            symbolics_schema.optional_arguments = optional
        symbolics_schema.opsets = opsets

        symbolics_schemas[name] = symbolics_schema

    return symbolics_schemas