• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# mypy: allow-untyped-defs
2import inspect
3from typing import Dict, List, Union
4
5from torch import _C
6from torch.onnx import _constants
7from torch.onnx._internal import registration
8
9
10class _TorchSchema:
11    def __init__(self, schema: Union[_C.FunctionSchema, str]) -> None:
12        if isinstance(schema, _C.FunctionSchema):
13            self.name: str = schema.name
14            self.overload_name: str = schema.overload_name
15            self.arguments: List[str] = [arg.name for arg in schema.arguments]
16            self.optional_arguments: List[str] = []
17            self.returns: List[str] = [ret.name for ret in schema.returns]
18            self.opsets: List[int] = []
19        else:
20            self.name = schema
21            self.overload_name = ""
22            self.arguments = []
23            self.optional_arguments = []
24            self.returns = []
25            self.opsets = []
26
27    def __str__(self) -> str:
28        s = (
29            f"{self.name}.{self.overload_name}("
30            + ", ".join(self.arguments)
31            + ") -> ("
32            + ", ".join(self.returns)
33            + ")"
34            + " in opsets "
35            + ", ".join(str(opset) for opset in self.opsets)
36        )
37        return s
38
39    def __hash__(self):
40        # TODO(thiagocrepaldi): handle overload_name?
41        return hash(self.name)
42
43    def __eq__(self, other) -> bool:
44        if not isinstance(other, _TorchSchema):
45            return False
46        # TODO(thiagocrepaldi): handle overload_name?
47        return self.name == other.name
48
49    def is_aten(self) -> bool:
50        return self.name.startswith("aten::")
51
52    def is_backward(self) -> bool:
53        return "backward" in self.name
54
55
56def _symbolic_argument_count(func):
57    params = []
58    signature = inspect.signature(func)
59    optional_params = []
60    for name, parameter in signature.parameters.items():
61        if name in {"_outputs", "g"}:
62            continue
63        if parameter.default is parameter.empty:
64            optional_params.append(parameter)
65        else:
66            params.append(str(parameter))
67    return params
68
69
def all_forward_schemas() -> Dict[str, _TorchSchema]:
    """Returns schemas for all TorchScript forward ops, keyed by operator name.

    Every schema reported by ``_C._jit_get_all_schemas()`` is wrapped in a
    ``_TorchSchema``. Schemas whose name contains ``"backward"`` are dropped.
    Overloads share a name, so when several schemas have the same name the
    one listed last wins.
    """
    forward_schemas: Dict[str, _TorchSchema] = {}
    for raw_schema in _C._jit_get_all_schemas():
        schema = _TorchSchema(raw_schema)
        if schema.is_backward():
            continue
        forward_schemas[schema.name] = schema
    return forward_schemas
74
75
def all_symbolics_schemas() -> Dict[str, _TorchSchema]:
    """Returns schemas for all onnx supported ops, keyed by registered name.

    For each name in the ONNX symbolic registry, the symbolic resolved for
    ``_constants.ONNX_MAX_OPSET`` is inspected. Its arguments, as reported by
    ``_symbolic_argument_count``, are recorded. ``opsets`` spans from the
    group's minimum supported opset through ``ONNX_MAX_OPSET`` inclusive.

    If nothing resolves at the max opset, the group's opset-7 symbolic is used
    instead, and ``opsets`` covers ``7`` up to (excluding)
    ``_constants.ONNX_BASE_OPSET``. If that lookup also finds nothing, the
    schema is still returned, with empty ``arguments`` and ``opsets``.
    """
    symbolics_schemas: Dict[str, _TorchSchema] = {}

    for name in registration.registry.all_functions():
        func_group = registration.registry.get_function_group(name)
        # Presumably every name listed by all_functions() has a group; this
        # also narrows the Optional type for the checker.
        assert func_group is not None
        symbolics_schema = _TorchSchema(name)
        func = func_group.get(_constants.ONNX_MAX_OPSET)
        if func is not None:
            symbolics_schema.arguments = _symbolic_argument_count(func)
            symbolics_schema.opsets = list(
                range(func_group.get_min_supported(), _constants.ONNX_MAX_OPSET + 1)
            )
        else:
            # Only support opset < 9: fall back to the opset-7 symbolic.
            func = func_group.get(7)
            # get() may return None here too. Previously that None reached
            # inspect.signature() and raised TypeError, aborting the listing.
            if func is not None:
                symbolics_schema.arguments = _symbolic_argument_count(func)
                symbolics_schema.opsets = list(range(7, _constants.ONNX_BASE_OPSET))

        symbolics_schemas[name] = symbolics_schema

    return symbolics_schemas
99