• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
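"""
Generate CSV tables of ATen operators for the ONNX export documentation.

Two tables are written:
- the ATen operators supported by `torch.onnx.export`, each with the opset
  version it has been supported since, and
- the ATen operators that are not yet supported.

The output is included by docs/source/onnx_supported_aten_list.rst.
"""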
6
7import os
8
9from torch.onnx import _onnx_supported_ops
10
11
# Directory that receives the generated tables. It is a relative path, so it
# resolves against the current working directory at run time.
BUILD_DIR = "build/onnx"

# File names (created inside BUILD_DIR) for the two generated CSV tables.
SUPPORTED_OPS_CSV_FILE = "auto_gen_supported_op_list.csv"
UNSUPPORTED_OPS_CSV_FILE = "auto_gen_unsupported_op_list.csv"
16
17
18def _sort_key(namespaced_opname):
19    return tuple(reversed(namespaced_opname.split("::")))
20
21
def _get_op_lists():
    """Partition the forward ATen schemas into supported and unsupported ops.

    An operator counts as supported when its name, with at most one trailing
    underscore stripped, is a key of
    ``_onnx_supported_ops.all_symbolics_schemas()``.

    Returns:
        A ``(supported, unsupported)`` pair of lists of ``(opname, note)``
        tuples. For supported ops ``note`` is ``"Since opset N"``; for
        unsupported ops it is ``"Not yet supported"``. Duplicates are removed,
        and both lists are ordered with ``_sort_key`` (op name first, then
        namespace), so the two generated tables share the same ordering.
    """
    all_schemas = _onnx_supported_ops.all_forward_schemas()
    symbolic_schemas = _onnx_supported_ops.all_symbolics_schemas()
    supported_result = set()
    not_supported_result = set()
    for schema_name in all_schemas:
        # Strip a single trailing underscore (PyTorch's in-place naming,
        # e.g. "aten::add_" -> "aten::add") before looking up the symbolic.
        # Only one underscore is removed, so "aten::__and__" -> "aten::__and_".
        opname = schema_name[:-1] if schema_name.endswith("_") else schema_name
        if opname in symbolic_schemas:
            opsets = symbolic_schemas[opname].opsets
            # Report the lowest registered opset, independent of list order.
            supported_result.add((opname, f"Since opset {min(opsets)}"))
        else:
            not_supported_result.add((opname, "Not yet supported"))

    def by_op_name(row):
        # Rows are (opname, note); order by op name, then namespace.
        return _sort_key(row[0])

    return (
        sorted(supported_result, key=by_op_name),
        # Same key as the supported list so both tables are ordered alike.
        sorted(not_supported_result, key=by_op_name),
    )
41
42
def _write_csv_table(path, rows):
    """Write ``(operator, opset_note)`` rows to *path* as a two-column CSV.

    A header line is written first. Each operator name is wrapped in double
    backticks (an inline literal in reStructuredText), and every data field
    is double-quoted.
    """
    # Explicit encoding so the output does not depend on the platform locale.
    with open(path, "w", encoding="utf-8") as f:
        f.write("Operator,opset_version(s)\n")
        f.writelines(
            f'"``{name}``","{opset_version}"\n' for name, opset_version in rows
        )


def main():
    """Generate the supported and unsupported ATen operator CSV tables.

    Both files are written into ``BUILD_DIR``, which is created if missing.
    """
    os.makedirs(BUILD_DIR, exist_ok=True)

    supported, unsupported = _get_op_lists()

    _write_csv_table(os.path.join(BUILD_DIR, SUPPORTED_OPS_CSV_FILE), supported)
    _write_csv_table(
        os.path.join(BUILD_DIR, UNSUPPORTED_OPS_CSV_FILE), unsupported
    )
57
58
if __name__ == "__main__":
    # Generate the tables only when executed directly, not when imported.
    main()
61