#!/usr/bin/env python3 from __future__ import annotations import argparse import os import yaml from torchgen.code_template import CodeTemplate from torchgen.selective_build.selector import SelectiveBuilder # Safely load fast C Yaml loader/dumper if they are available try: from yaml import CSafeLoader as Loader except ImportError: from yaml import SafeLoader as Loader # type: ignore[assignment, misc] if_condition_template_str = """if (kernel_tag_sv.compare("$kernel_tag_name") == 0) { return $dtype_checks; }""" if_condition_template = CodeTemplate(if_condition_template_str) selected_kernel_dtypes_h_template_str = """ #include #include #include namespace at { inline constexpr bool should_include_kernel_dtype( const char *kernel_tag_str, at::ScalarType scalar_type ) { c10::string_view kernel_tag_sv C10_UNUSED = c10::string_view(kernel_tag_str); $body return false; } } """ selected_kernel_dtypes_h_template = CodeTemplate(selected_kernel_dtypes_h_template_str) selected_mobile_ops_preamble = """#pragma once /** * Generated by gen_selected_mobile_ops_header.py */ """ def extract_root_operators(selective_builder: SelectiveBuilder) -> set[str]: ops = [] for op_name, op in selective_builder.operators.items(): if op.is_root_operator: ops.append(op_name) return set(ops) def get_selected_kernel_dtypes_code( selective_builder: SelectiveBuilder, ) -> str: # See https://www.internalfb.com/intern/paste/P153411698/ for an example of the # generated code in case all kernel dtypes are selected and in case some kernel # dtypes are selected (i.e. both cases). # body = "return true;" if ( selective_builder.include_all_operators is False and selective_builder.include_all_non_op_selectives is False ): body_parts = [] for kernel_tag, dtypes in selective_builder.kernel_metadata.items(): conditions = ["scalar_type == at::ScalarType::" + x for x in dtypes] body_parts.append( if_condition_template.substitute( kernel_tag_name=kernel_tag, dtype_checks=" || ".join(conditions), ), ) body = " else ".join(body_parts) header_contents = selected_kernel_dtypes_h_template.substitute(body=body) return header_contents # Write the file selected_mobile_ops.h with optionally: # 1. The selected root operators # 2. The selected kernel dtypes def write_selected_mobile_ops( output_file_path: str, selective_builder: SelectiveBuilder, ) -> None: root_ops = extract_root_operators(selective_builder) custom_classes = selective_builder.custom_classes build_features = selective_builder.build_features with open(output_file_path, "wb") as out_file: body_parts = [selected_mobile_ops_preamble] # This condition checks if we are in selective build. # if these lists are not defined the corresponding selective build macros trivially return the item in question was selected if not selective_builder.include_all_operators: body_parts.append( "#define TORCH_OPERATOR_WHITELIST " + (";".join(sorted(root_ops))) + ";\n\n" ) # This condition checks if we are in tracing based selective build if selective_builder.include_all_non_op_selectives is False: body_parts.append( "#define TORCH_CUSTOM_CLASS_ALLOWLIST " + (";".join(sorted(custom_classes))) + ";\n\n" ) body_parts.append( "#define TORCH_BUILD_FEATURE_ALLOWLIST " + (";".join(sorted(build_features))) + ";\n\n" ) body_parts.append(get_selected_kernel_dtypes_code(selective_builder)) header_contents = "".join(body_parts) out_file.write(header_contents.encode("utf-8")) # root_ops: a set of selected root operators for selective build # Write the file selected_mobile_ops.h with optionally: # 1. The selected root operators from root_ops # 2. All kernel dtypes def write_selected_mobile_ops_with_all_dtypes( output_file_path: str, root_ops: set[str], ) -> None: with open(output_file_path, "wb") as out_file: body_parts = [selected_mobile_ops_preamble] body_parts.append( "#define TORCH_OPERATOR_WHITELIST " + (";".join(sorted(root_ops))) + ";\n\n" ) selective_builder = SelectiveBuilder.get_nop_selector() body_parts.append(get_selected_kernel_dtypes_code(selective_builder)) header_contents = "".join(body_parts) out_file.write(header_contents.encode("utf-8")) def main() -> None: parser = argparse.ArgumentParser( description="Generate selected_mobile_ops.h for selective build." ) parser.add_argument( "-p", "--yaml-file-path", "--yaml_file_path", type=str, required=True, help="Path to the yaml file with a list of operators used by the model.", ) parser.add_argument( "-o", "--output-file-path", "--output_file_path", type=str, required=True, help="Path to destination" "folder where selected_mobile_ops.h will be written.", ) parsed_args = parser.parse_args() model_file_name = parsed_args.yaml_file_path print("Loading yaml file: ", model_file_name) loaded_model = {} with open(model_file_name, "rb") as model_file: loaded_model = yaml.load(model_file, Loader=Loader) root_operators_set = set(loaded_model) print("Writing header file selected_mobile_ops.h: ", parsed_args.output_file_path) write_selected_mobile_ops_with_all_dtypes( os.path.join(parsed_args.output_file_path, "selected_mobile_ops.h"), root_operators_set, ) if __name__ == "__main__": main()