#!/usr/bin/env fbpython
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import argparse
import os
import sys
from typing import Any, List

import yaml

from torchgen.code_template import CodeTemplate


ops_and_dtypes_template_str = """((exec_aten::string_view(operator_name).compare("$operator_name") == 0)\n && ($dtype_checks))"""
ops_and_dtypes_template = CodeTemplate(ops_and_dtypes_template_str)

selected_kernel_dtypes_h_template_str = """#pragma once
/**
 * Generated by executorch/codegen/tools/gen_selected_op_variants.py
 */

inline constexpr bool should_include_kernel_dtype(
  const char *operator_name,
  exec_aten::ScalarType scalar_type
) {
  return $body;
}
"""
selected_kernel_dtypes_h_template = CodeTemplate(selected_kernel_dtypes_h_template_str)

# enum from: https://github.com/pytorch/executorch/blob/main/runtime/core/portable_type/scalar_type.h
dtype_enum_to_type = {
    "0": "Byte",
    "1": "Char",
    "2": "Short",
    "3": "Int",
    "4": "Long",
    "5": "Half",
    "6": "Float",
    "7": "Double",
    "8": "ComplexHalf",
    "9": "ComplexFloat",
    "10": "ComplexDouble",
    "11": "Bool",
    "12": "QInt8",
    "13": "QUInt8",
    "14": "QInt32",
    "15": "BFloat16",
    "16": "QUInt4x2",
    "17": "QUInt2x4",
    "18": "Bits1x8",
    "19": "Bits2x4",
    "20": "Bits4x2",
    "21": "Bits8",
    "22": "Bits16",
}


def _dtype_conditions(kernel_metadata_list: List[str]) -> List[str]:
    """Return the C++ dtype predicates for one operator's kernel metadata.

    Each entry of ``kernel_metadata_list`` is a kernel key such as
    ``v1/6;0,1|6;0,1`` (Float, 0, 1): the text after the first "/" is split
    on "|" into per-tensor descriptors, and the first ";"-separated field of
    each descriptor is a ScalarType enum value (a key of
    ``dtype_enum_to_type``).

    Returns:
        A list of ``scalar_type == exec_aten::ScalarType::<Name>`` strings,
        sorted by dtype name. ``["true"]`` (no dtype restriction) is returned
        when the list is empty or contains a "default" / "/"-less entry.

    Raises:
        ValueError: if a descriptor names an enum value missing from
            ``dtype_enum_to_type``.
    """
    dtype_enums = set()
    for kernel_metadata in kernel_metadata_list:
        if kernel_metadata == "default" or "/" not in kernel_metadata:
            # A default / unversioned entry lifts dtype filtering for the whole
            # operator, no matter where it appears in the list.
            return ["true"]
        tensor_meta = kernel_metadata.split("/")[1]
        dtype_enums.update(meta.split(";")[0] for meta in tensor_meta.split("|"))

    if not dtype_enums:
        return ["true"]

    unknown = sorted(e for e in dtype_enums if e not in dtype_enum_to_type)
    if unknown:
        raise ValueError(
            f"Unknown ScalarType enum value(s) {unknown} in kernel metadata "
            f"{kernel_metadata_list!r}"
        )
    return [
        "scalar_type == exec_aten::ScalarType::" + dtype
        for dtype in sorted(dtype_enum_to_type[e] for e in dtype_enums)
    ]


def write_selected_op_variants(yaml_file_path: str, output_dir: str) -> None:
    """Generate ``selected_op_variants.h`` from a selected_operators.yaml file.

    Reads the ``et_kernel_metadata`` mapping (operator name -> list of kernel
    keys) and emits a header defining ``should_include_kernel_dtype``, which
    returns true for each (operator, dtype) pair found in the metadata.

    Args:
        yaml_file_path: Path to the selected_operators.yaml file.
        output_dir: Directory in which ``selected_op_variants.h`` is written.

    Raises:
        TypeError: if ``et_kernel_metadata`` is present but is not a mapping.
        ValueError: if a kernel key names an unknown ScalarType enum value.
    """
    with open(yaml_file_path, "r", encoding="utf-8") as selected_operators_file:
        # An empty YAML document loads as None; treat it as "no metadata".
        selected_operators_dict = yaml.safe_load(selected_operators_file) or {}

    et_kernel_metadata = selected_operators_dict.get("et_kernel_metadata") or {}
    if not isinstance(et_kernel_metadata, dict):
        raise TypeError(
            "et_kernel_metadata in "
            f"{yaml_file_path} must be a mapping, got {type(et_kernel_metadata).__name__}"
        )

    body_parts = [
        ops_and_dtypes_template.substitute(
            # Drop the "aten::" namespace (every occurrence) from the op name.
            operator_name=operator_name.replace("aten::", ""),
            dtype_checks=" || ".join(_dtype_conditions(kernel_metadata_list)),
        )
        for operator_name, kernel_metadata_list in et_kernel_metadata.items()
    ]
    # Without any selected operators, fall back to allowing every dtype instead
    # of emitting "return ;", which is not valid C++.
    body = "\n || ".join(body_parts) if body_parts else "true"

    header_contents = selected_kernel_dtypes_h_template.substitute(body=body)
    selected_op_variants_path = os.path.join(output_dir, "selected_op_variants.h")
    # Binary mode + explicit UTF-8 avoids platform newline translation.
    with open(selected_op_variants_path, "wb") as out_file:
        out_file.write(header_contents.encode("utf-8"))


def main(argv: List[Any]) -> None:
    """Parse command-line arguments and generate selected_op_variants.h."""
    parser = argparse.ArgumentParser(
        description="Generate selected_op_variants.h from selected_operators.yaml"
    )
    parser.add_argument(
        "--yaml-file-path",
        "--yaml_file_path",
        help="Path to the selected_operators.yaml file",
        required=True,
    )
    parser.add_argument(
        "--output-dir",
        "--output_dir",
        help="The directory to write the generated selected_op_variants.h header to",
        required=True,
    )
    options = parser.parse_args(argv)
    write_selected_op_variants(options.yaml_file_path, options.output_dir)


if __name__ == "__main__":
    main(sys.argv[1:])