"""
This script generates CSV tables listing the ATen operators that are
supported, and not yet supported, by `torch.onnx.export`. The generated
output is included by docs/source/onnx_supported_aten_list.rst.
"""

import os

from torch.onnx import _onnx_supported_ops


# Constants
BUILD_DIR = "build/onnx"
SUPPORTED_OPS_CSV_FILE = "auto_gen_supported_op_list.csv"
UNSUPPORTED_OPS_CSV_FILE = "auto_gen_unsupported_op_list.csv"
# Header line shared by both generated CSV files.
_CSV_HEADER = "Operator,opset_version(s)\n"


def _sort_key(namespaced_opname):
    """Return a sort key for a ``namespace::opname`` string.

    The name is split on ``"::"`` and the parts are reversed, so sorting
    with this key orders entries by operator name first and namespace
    second (e.g. ``"aten::add"`` -> ``("add", "aten")``).
    """
    return tuple(reversed(namespaced_opname.split("::")))


def _get_op_lists():
    """Partition the forward schemas into supported and unsupported operators.

    Returns:
        A 2-tuple ``(supported, unsupported)`` of sorted lists of
        ``(opname, description)`` tuples. ``supported`` is ordered by
        ``_sort_key`` (operator name, then namespace); ``unsupported`` uses
        the plain tuple ordering (full namespaced name).
    """
    all_schemas = _onnx_supported_ops.all_forward_schemas()
    symbolic_schemas = _onnx_supported_ops.all_symbolics_schemas()
    # Sets de-duplicate names that collapse to the same entry once the
    # trailing underscore is stripped below.
    supported_result = set()
    not_supported_result = set()
    for schema_name in all_schemas:
        # Drop a single trailing underscore before the lookup
        # (presumably to fold in-place variants such as ``add_`` into
        # ``add`` -- confirm against _onnx_supported_ops).
        opname = schema_name[:-1] if schema_name.endswith("_") else schema_name
        if opname in symbolic_schemas:
            # Supported op: report the first listed opset.
            # NOTE(review): assumes ``opsets`` is non-empty and ordered
            # ascending, so that element 0 is the earliest opset.
            opsets = symbolic_schemas[opname].opsets
            supported_result.add((opname, f"Since opset {opsets[0]}"))
        else:
            # Unsupported op
            not_supported_result.add((opname, "Not yet supported"))
    return (
        sorted(supported_result, key=lambda x: _sort_key(x[0])),
        sorted(not_supported_result),
    )


def _render_csv(rows):
    """Render ``(name, opset_version)`` pairs as the text of one CSV table.

    The header line comes first, followed by one line per row with both
    fields double-quoted and the operator name wrapped in double backticks
    (reST inline-literal markup).
    """
    body = "".join(
        f'"``{name}``","{opset_version}"\n' for name, opset_version in rows
    )
    return _CSV_HEADER + body


def _write_csv(filename, rows):
    """Write ``rows`` as a CSV table to ``filename`` inside ``BUILD_DIR``."""
    # Explicit encoding keeps the generated file independent of the
    # locale of the machine building the docs.
    with open(os.path.join(BUILD_DIR, filename), "w", encoding="utf-8") as f:
        f.write(_render_csv(rows))


def main():
    """Create ``BUILD_DIR`` if needed and write both operator tables into it."""
    os.makedirs(BUILD_DIR, exist_ok=True)

    supported, unsupported = _get_op_lists()

    _write_csv(SUPPORTED_OPS_CSV_FILE, supported)
    _write_csv(UNSUPPORTED_OPS_CSV_FILE, unsupported)


if __name__ == "__main__":
    main()