# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Generate the ``op_benchmarks.cpp`` source file for the Vulkan op test suites.

The script parses a ``native_functions.yaml`` / ``tags.yaml`` pair, matches
every entry of the imported ``test_suites`` mapping to its parsed native
function and writes the C++ produced by ``VkBenchmarkFileGen`` into the
requested output directory.
"""

import argparse
import os
from typing import Dict, List, Union

from executorch.backends.vulkan.test.op_tests.cases import test_suites
from executorch.backends.vulkan.test.op_tests.utils.gen_benchmark_vk import (
    VkBenchmarkFileGen,
)
from executorch.backends.vulkan.test.op_tests.utils.gen_computegraph import (
    ComputeGraphGen,
)
from executorch.backends.vulkan.test.op_tests.utils.test_suite import TestSuite
from torchgen import local
from torchgen.gen import parse_native_yaml, ParsedYaml
from torchgen.model import DispatchKey, NativeFunction


def registry_name(f: NativeFunction) -> str:
    """Return the lookup key used for *f* in the function map.

    The key is ``"<namespace>.<func name>"``; when the function has an empty
    overload name, ``".default"`` is appended.
    """
    name = f"{f.namespace}.{f.func.name}"
    if not f.func.name.overload_name:
        name += ".default"
    return name


def construct_f_map(parsed_yaml: ParsedYaml) -> Dict[str, NativeFunction]:
    """Map the registry name of every parsed native function to that function.

    If two functions produce the same registry name, the later one in
    ``parsed_yaml.native_functions`` wins.
    """
    return {registry_name(f): f for f in parsed_yaml.native_functions}


def process_test_suites(
    cpp_generator: VkBenchmarkFileGen,
    f_map: Dict[str, NativeFunction],
    test_suites: Dict[str, Union[TestSuite, List[TestSuite]]],
) -> None:
    """Register every test suite in *test_suites* with *cpp_generator*.

    Args:
        cpp_generator: Generator that receives one ``add_suite`` call per suite.
        f_map: Registry name -> native function, see ``construct_f_map``.
        test_suites: Registry name -> a single suite or a list of suites.

    Raises:
        KeyError: If a key of *test_suites* has no entry in *f_map*.
    """
    # The loop variable is deliberately not called ``registry_name`` so that it
    # does not shadow the module-level function of that name.
    for op_name, op_test_suites in test_suites.items():
        try:
            f = f_map[op_name]
        except KeyError:
            raise KeyError(
                f"test suite refers to {op_name!r}, which is not among the "
                "parsed native functions"
            ) from None

        # Normalize the single-suite case so both shapes share one code path.
        if isinstance(op_test_suites, list):
            suites = op_test_suites
        else:
            suites = [op_test_suites]
        for suite in suites:
            cpp_generator.add_suite(op_name, f, suite)


@local.parametrize(
    use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
)
def generate_cpp(
    native_functions_yaml_path: str, tags_path: str, output_dir: str
) -> None:
    """Write ``op_benchmarks.cpp`` into *output_dir*.

    Args:
        native_functions_yaml_path: Path to ``native_functions.yaml``.
        tags_path: Path to ``tags.yaml``.
        output_dir: Directory that receives the generated file; it is created
            if it does not exist yet.
    """
    output_file = os.path.join(output_dir, "op_benchmarks.cpp")
    cpp_generator = VkBenchmarkFileGen(output_file)

    parsed_yaml = parse_native_yaml(native_functions_yaml_path, tags_path)
    f_map = construct_f_map(parsed_yaml)

    # This assigns a class attribute on ComputeGraphGen, so the chosen backend
    # index is visible to every user of that class, not just this call.
    ComputeGraphGen.backend_key = parsed_yaml.backend_indices[DispatchKey.CPU]

    process_test_suites(cpp_generator, f_map, test_suites)

    # Make sure the destination exists; abspath() keeps this valid even when
    # output_dir is an empty string (i.e. the current directory).
    os.makedirs(os.path.dirname(os.path.abspath(output_file)), exist_ok=True)
    with open(output_file, "w", encoding="utf-8") as file:
        file.write(cpp_generator.generate_cpp())


def main() -> None:
    """Parse the command line and generate the benchmark source file."""
    parser = argparse.ArgumentParser()
    # Both yaml paths are passed straight into the yaml parser, so reject a
    # missing one here instead of failing later inside parsing.
    parser.add_argument(
        "--aten-yaml-path",
        help="path to native_functions.yaml file.",
        required=True,
    )
    parser.add_argument(
        "--tags-path",
        help="Path to tags.yaml. Required by yaml parsing in gen_correctness_vk system.",
        required=True,
    )
    parser.add_argument("-o", "--output", help="Output directory", required=True)
    args = parser.parse_args()
    generate_cpp(args.aten_yaml_path, args.tags_path, args.output)


if __name__ == "__main__":
    main()