#!/usr/bin/env python3

from __future__ import annotations

import argparse
import json
import os
import sys
from functools import reduce
from typing import Any

import yaml
from tools.lite_interpreter.gen_selected_mobile_ops_header import (
    write_selected_mobile_ops,
)

from torchgen.selective_build.selector import (
    combine_selective_builders,
    SelectiveBuilder,
)


def extract_all_operators(selective_builder: SelectiveBuilder) -> set[str]:
    """Return the names of every operator recorded in ``selective_builder``."""
    return set(selective_builder.operators.keys())


def extract_training_operators(selective_builder: SelectiveBuilder) -> set[str]:
    """Return the names of operators flagged ``is_used_for_training``."""
    return {
        op_name
        for op_name, op in selective_builder.operators.items()
        if op.is_used_for_training
    }


def throw_if_any_op_includes_overloads(selective_builder: SelectiveBuilder) -> None:
    """Raise if any operator has ``include_all_overloads`` set.

    ``main`` calls this only when ``--allow-include-all-overloads`` was NOT
    passed, so the error message tells the user that the flag is missing.

    Raises:
        Exception: listing the offending operator names, comma separated.
    """
    ops = [
        op_name
        for op_name, op in selective_builder.operators.items()
        if op.include_all_overloads
    ]
    if ops:
        raise Exception(  # noqa: TRY002
            (
                "Operators that include all overloads are "
                + "not allowed since --allow-include-all-overloads "
                + "was not specified: {}"
            ).format(", ".join(ops))
        )


def gen_supported_mobile_models(model_dicts: list[Any], output_dir: str) -> None:
    """Write ``SupportedMobileModelsRegistration.cpp`` into ``output_dir``.

    The generated C++ file registers the MD5 hashes collected from the
    ``debug_info`` entry of each model dict (only for entries whose
    ``is_new_style_rule`` is truthy).

    Args:
        model_dicts: Parsed model YAML documents.
        output_dir: Directory that receives the generated source file.
    """
    # Doubled braces are literal braces for str.format(); the single-braced
    # {supported_hashes_template} is the only substitution point.
    # NOTE(review): the explicit <std::string> element type is required for the
    # braced list of string literals (and for the empty list) to produce a set
    # of std::string — confirm against SupportedMobileModels.h.
    supported_mobile_models_source = """/*
 * Generated by gen_oplist.py
 */
#include "fb/supported_mobile_models/SupportedMobileModels.h"


struct SupportedMobileModelCheckerRegistry {{
  SupportedMobileModelCheckerRegistry() {{
    auto& ref = facebook::pytorch::supported_model::SupportedMobileModelChecker::singleton();
    ref.set_supported_md5_hashes(std::unordered_set<std::string>{{
      {supported_hashes_template}
    }});
  }}
}};

// This is a global object, initializing which causes the registration to happen.
SupportedMobileModelCheckerRegistry register_model_versions;
"""

    # Generate SupportedMobileModelsRegistration.cpp
    md5_hashes: set[Any] = set()
    for model_dict in model_dicts:
        if "debug_info" in model_dict:
            # debug_info is stored as a list whose first element is a JSON string.
            debug_info = json.loads(model_dict["debug_info"][0])
            if debug_info["is_new_style_rule"]:
                for asset_info in debug_info["asset_info"].values():
                    # set.update() adds each element of its argument; this
                    # presumably expects "md5_hash" to be a list of hash
                    # strings — TODO confirm against the YAML producer.
                    md5_hashes.update(asset_info["md5_hash"])

    # Sort so the generated file is byte-for-byte reproducible across runs
    # (set iteration order for strings varies with hash randomization).
    supported_hashes = "".join(f'"{md5}",\n' for md5 in sorted(md5_hashes))

    with open(
        os.path.join(output_dir, "SupportedMobileModelsRegistration.cpp"), "wb"
    ) as out_file:
        source = supported_mobile_models_source.format(
            supported_hashes_template=supported_hashes
        )
        out_file.write(source.encode("utf-8"))


def main(argv: list[Any]) -> None:
    """This binary generates 3 files:

    1. selected_mobile_ops.h: Primary operators used by templated selective build and Kernel Function
       dtypes captured by tracing
    2. selected_operators.yaml: Selected root and non-root operators (either via tracing or static analysis)
    3. SupportedMobileModelsRegistration.cpp: Registration of the MD5 hashes of supported mobile models
    """
    parser = argparse.ArgumentParser(description="Generate operator lists")
    parser.add_argument(
        "--output-dir",
        "--output_dir",
        help=(
            "The directory to store the output yaml files (selected_mobile_ops.h, "
            + "selected_kernel_dtypes.h, selected_operators.yaml)"
        ),
        required=True,
    )
    parser.add_argument(
        "--model-file-list-path",
        "--model_file_list_path",
        help=(
            "Path to a file that contains the locations of individual "
            + "model YAML files that contain the set of used operators. This "
            + "file path must have a leading @-symbol, which will be stripped "
            + "out before processing."
        ),
        required=True,
    )
    parser.add_argument(
        "--allow-include-all-overloads",
        "--allow_include_all_overloads",
        help=(
            "Flag to allow operators that include all overloads. "
            + "If not set, operators registered without using the traced style will "
            + "break the build."
        ),
        action="store_true",
        default=False,
        required=False,
    )
    options = parser.parse_args(argv)

    model_dicts = []
    if os.path.isfile(options.model_file_list_path):
        # The argument is itself a single model YAML file.
        print("Processing model file: ", options.model_file_list_path)
        with open(options.model_file_list_path) as model_file:
            model_dicts.append(yaml.safe_load(model_file))
    else:
        # The argument is "@<path>", where <path> lists model YAML files
        # separated by whitespace.
        print("Processing model directory: ", options.model_file_list_path)
        assert options.model_file_list_path.startswith(
            "@"
        ), "--model-file-list-path must start with '@' when it is not an existing file"
        model_file_list_path = options.model_file_list_path[1:]
        with open(model_file_list_path) as model_list_file:
            model_file_names = model_list_file.read().split()
            for model_file_name in model_file_names:
                with open(model_file_name, "rb") as model_file:
                    model_dicts.append(yaml.safe_load(model_file))

    selective_builders = [SelectiveBuilder.from_yaml_dict(m) for m in model_dicts]

    # While we have the model_dicts generate the supported mobile models api
    gen_supported_mobile_models(model_dicts, options.output_dir)

    # We may have 0 selective builders since there may not be any viable
    # pt_operator_library rule marked as a dep for the pt_operator_registry rule.
    # This is potentially an error, and we should probably raise an assertion
    # failure here. However, this needs to be investigated further.
    selective_builder = SelectiveBuilder.from_yaml_dict({})
    if selective_builders:
        selective_builder = reduce(
            combine_selective_builders,
            selective_builders,
        )

    if not options.allow_include_all_overloads:
        throw_if_any_op_includes_overloads(selective_builder)

    with open(
        os.path.join(options.output_dir, "selected_operators.yaml"), "wb"
    ) as out_file:
        out_file.write(
            yaml.safe_dump(
                selective_builder.to_dict(), default_flow_style=False
            ).encode("utf-8"),
        )

    write_selected_mobile_ops(
        os.path.join(options.output_dir, "selected_mobile_ops.h"),
        selective_builder,
    )


if __name__ == "__main__":
    main(sys.argv[1:])