# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import re
from dataclasses import dataclass
from typing import List, Optional, Union

from executorch.backends.vulkan.test.op_tests.utils.aten_types import (
    AT_INT_ARRAY_REF,
    AT_SCALAR,
    AT_TENSOR,
    AT_TENSOR_LIST,
    BOOL,
    DOUBLE,
    INT,
    OPT_AT_DOUBLE_ARRAY_REF,
    OPT_AT_INT_ARRAY_REF,
    OPT_AT_TENSOR,
    OPT_BOOL,
    OPT_DEVICE,
    OPT_INT64,
    OPT_LAYOUT,
    OPT_MEMORY_FORMAT,
    OPT_SCALAR_TYPE,
    STRING,
    TENSOR_VECTOR,
    THREE_TENSOR_TUPLE,
    TWO_TENSOR_TUPLE,
)
from executorch.backends.vulkan.test.op_tests.utils.test_suite import TestSuite

from torchgen.api import cpp
from torchgen.api.types import CppSignatureGroup
from torchgen.gen import generate_static_dispatch_backend_call, translate_args
from torchgen.gen_aoti_c_shim import gen_static_dispatch_backend_call_signature
from torchgen.model import NativeFunction, Variant

###################################
## Compute Graph Code Generation ##
###################################


@dataclass
class ATenArg:
    name: str
    cpp_type: str
    default: Optional[str]


@dataclass
class ValueRef:
    name: str
    src_cpp_name: str
    src_cpp_type: str
    is_in: bool = False
    is_out: bool = False
    requires_prepack: bool = False
    supports_prepack: bool = False
    # When is_dynamic_size is true, the underlying object's size is not known
    # during code-gen. An example is the out value for aten.split, where the
    # out value is a vector. In these cases, we need to use an additional
    # vector or at::TensorList to track these values.
    is_dynamic_size: bool = False

    @property
    def io_value_list_name(self):
        assert self.is_dynamic_size
        return f"{self.name}_io_value_list"

    @property
    def value_list_name(self):
        assert self.is_dynamic_size
        return f"{self.name}_value_list"

    @property
    def vk_out(self):
        assert self.is_out
        return f"vk_{self.name}"


ValueRefList = Union[ValueRef, List[ValueRef]]

InableCppType = frozenset([AT_TENSOR, AT_TENSOR_LIST])


class ComputeGraphGen:
    backend_key = None

    def __init__(self, op_reg_name: str, f: NativeFunction, suite_def: TestSuite):
        self.op_reg_name = op_reg_name
        self.f = f
        self.suite_def = suite_def

        self.f_sig = CppSignatureGroup.from_native_function(
            self.f, method=False, fallback_binding=self.f.manual_cpp_binding
        ).most_faithful_signature()

        self.graph = "graph"
        self.dot = "->"

        self.args = []
        self.refs = {}

        self.should_prepack = False

        for binding in self.f_sig.arguments():
            arg = binding.argument
            ctype = cpp.argumenttype_type(
                arg.type, mutable=arg.is_write, binds=arg.name
            )
            cpp_type = ctype.cpp_type(strip_ref=True)

            self.args.append(
                ATenArg(name=arg.name, cpp_type=cpp_type, default=arg.default)
            )

            # These are the arguments that will be passed as a "weight" tensor;
            # the corresponding object will be a TensorRef in the compute graph.
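            # Illustrative example (the schema below is hypothetical, not
            # taken from a real op registry entry): for an op declared as
            #     conv2d(Tensor input, Tensor weight, Tensor? bias, ...)
            # the "weight" and "bias" arguments match the substring checks
            # below, so they get requires_prepack=True and are later emitted
            # as TensorRef weights instead of graph inputs.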
            requires_prepack = (
                "weight" in arg.name
                or "bias" in arg.name
                or "running_mean" in arg.name
                or "running_var" in arg.name
            )
            supports_prepack = False
            if arg.name in self.suite_def.prepacked_args:
                supports_prepack = True

            self.refs[arg.name] = ValueRef(
                name=f"{arg.name}_ref",
                src_cpp_name=arg.name,
                src_cpp_type=cpp_type,
                is_in=(cpp_type in InableCppType),
                requires_prepack=requires_prepack,
                supports_prepack=supports_prepack,
            )

        ret_type = cpp.returns_type(self.f.func.returns, symint=False).cpp_type()
        self.out = ATenArg(name="out", cpp_type=ret_type, default=None)
        if ret_type == AT_TENSOR:
            self.refs["out"] = ValueRef(
                name="out_ref", src_cpp_name="out", src_cpp_type=ret_type, is_out=True
            )
        elif ret_type == TWO_TENSOR_TUPLE:
            self.refs["out"] = [
                ValueRef(
                    name="out_ref_first",
                    src_cpp_name="std::get<0>(out)",
                    src_cpp_type="at::Tensor",
                    is_out=True,
                ),
                ValueRef(
                    name="out_ref_second",
                    src_cpp_name="std::get<1>(out)",
                    src_cpp_type="at::Tensor",
                    is_out=True,
                ),
                ValueRef(
                    name="out_ref",
                    src_cpp_name="out",
                    src_cpp_type=ret_type,
                    is_out=False,
                ),
            ]
        elif ret_type == THREE_TENSOR_TUPLE:
            self.refs["out"] = [
                ValueRef(
                    name="out_ref_first",
                    src_cpp_name="std::get<0>(out)",
                    src_cpp_type="at::Tensor",
                    is_out=True,
                ),
                ValueRef(
                    name="out_ref_second",
                    src_cpp_name="std::get<1>(out)",
                    src_cpp_type="at::Tensor",
                    is_out=True,
                ),
                ValueRef(
                    name="out_ref_third",
                    src_cpp_name="std::get<2>(out)",
                    src_cpp_type="at::Tensor",
                    is_out=True,
                ),
                ValueRef(
                    name="out_ref",
                    src_cpp_name="out",
                    src_cpp_type=ret_type,
                    is_out=False,
                ),
            ]
        elif ret_type == TENSOR_VECTOR:
            self.refs["out"] = ValueRef(
                name="out_ref",
                src_cpp_name="out",
                src_cpp_type=ret_type,
                is_out=True,
                is_dynamic_size=True,
            )
        else:
            raise NotImplementedError(
                f"ret_type: {ret_type} not supported for out value"
            )

    ## ATen code generation

    def gen_decl(self, fn_name: str, ret_type: str = "void") -> str:
        cpp_args = [a.decl() for a in self.f_sig.arguments()]
        cpp_args_str = ", ".join(cpp_args)
        return f"{ret_type} {fn_name}({cpp_args_str})"

    def create_aten_fn_call(self) -> str:
        # generate_static_dispatch_backend_call() returns a statement of the
        # form "return at::{backend}::{op}(...);". Slicing off the first 7
        # characters strips the leading "return ", and dropping "::cpu"
        # retargets the call at the default ATen entry point.
        func_call = generate_static_dispatch_backend_call(
            self.f_sig, self.f, ComputeGraphGen.backend_key
        )[7:].replace("::cpu", "")
        return func_call
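    # A rough sketch of what the two call-generation paths produce
    # (create_aten_fn_call above, create_aten_method_call below); the op and
    # argument names here are illustrative, not captured generator output:
    #   create_aten_fn_call()     -> "at::add_out(out, self, other, alpha);"
    #   create_aten_method_call() -> "ATEN_FN(add_)(self, other, alpha);"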
    def create_aten_method_call(self) -> str:
        # For functions with only a Method variant, we fall back to the
        # function declared in MethodOperators.h. The method is declared as
        # at::_ops::{name}::call(*), and ATEN_FN is a handy macro.
        cpp_sig = gen_static_dispatch_backend_call_signature(self.f_sig, self.f)
        exprs = translate_args(self.f_sig, cpp_sig)
        func_call = f"ATEN_FN({self.f_sig.name()})({exprs});"
        return func_call

    def create_out_src(self, include_declarations: bool = True) -> str:
        cpp_type = self.out.cpp_type if include_declarations else ""
        if Variant.function in self.f.variants:
            return f"{cpp_type} out = " + self.create_aten_fn_call() + "\n"
        else:
            return f"{cpp_type} out = " + self.create_aten_method_call() + "\n"

    ## Graph code generation utils

    def prepack_ref(self, ref: ValueRef) -> bool:
        if ref.requires_prepack:
            return True
        else:
            return ref.supports_prepack and self.should_prepack

    def create_value_decl_for(self, ref: ValueRefList) -> str:  # noqa: C901
        if isinstance(ref, list):
            ret_str = ""
            for r in ref:
                ret_str += self.create_value_decl_for(r)
            return ret_str

        cpp_type = "IOValueRef" if (ref.is_in or ref.requires_prepack) else "ValueRef"
        if ref.src_cpp_type == AT_TENSOR_LIST:
            ret_str = f"std::vector<IOValueRef> {ref.name}_io_value_refs;\n"
            ret_str += f"std::vector<ValueRef> {ref.name}_value_refs;\n"
            return ret_str
        elif ref.src_cpp_type == TENSOR_VECTOR:
            ret_str = f"std::vector<IOValueRef> {ref.io_value_list_name};\n"
            ret_str += f"std::vector<ValueRef> {ref.value_list_name};\n"
            return ret_str
        else:
            return f"{cpp_type} {ref.name};\n"

    def create_value_for(  # noqa: C901
        self, ref: ValueRefList, include_declarations: bool = True
    ) -> str:
        if isinstance(ref, list):
            ret_str = ""
            for r in ref:
                ret_str += self.create_value_for(r, include_declarations)
            return ret_str

        prepack = self.prepack_ref(ref)
        ref_is_view = self.suite_def.is_view_op and ref.is_out

        cpp_type = "IOValueRef" if (ref.is_in and not prepack) else "ValueRef"
        if not include_declarations:
            cpp_type = ""

        if ref.src_cpp_type == OPT_AT_TENSOR:
            ret_str = f"{cpp_type} {ref.name} = "
            if prepack:
                ret_str = ""
                if include_declarations:
                    ret_str += f"IOValueRef {ref.name};\n"
                ret_str += f"{ref.name}.value = "
            ret_str += f"!{ref.src_cpp_name}.has_value() ? "
            ret_str += f"{self.graph}{self.dot}add_none() : "
            if not prepack:
                ret_str += f"{self.graph}{self.dot}"
                ret_str += "add_input_tensor(" if ref.is_in else "add_tensor("
                ret_str += f"{ref.src_cpp_name}->sizes().vec(), "
                ret_str += f"from_at_scalartype({ref.src_cpp_name}->scalar_type())); \n"
            elif prepack:
                ret_str += f"{self.graph}{self.dot}"
                ret_str += f"add_tensorref({ref.src_cpp_name}->sizes().vec(), "
                ret_str += f"from_at_scalartype({ref.src_cpp_name}->scalar_type()), "
                ret_str += f"{ref.src_cpp_name}->const_data_ptr()); \n"
            return ret_str
        elif ref.src_cpp_type == OPT_INT64:
            ret_str = f"{cpp_type} {ref.name} = "
            ret_str += f"!{ref.src_cpp_name}.has_value() ? "
            ret_str += f"{self.graph}{self.dot}add_none() : "
            ret_str += f"{self.graph}{self.dot}add_scalar<int64_t>"
            ret_str += f"({ref.src_cpp_name}.value());\n"
            return ret_str
        elif (
            ref.src_cpp_type == OPT_AT_DOUBLE_ARRAY_REF
            or ref.src_cpp_type == OPT_AT_INT_ARRAY_REF
        ):
            ret_str = f"{cpp_type} {ref.name} = "
            ret_str += f"!{ref.src_cpp_name}.has_value() ? "
            ret_str += f"{self.graph}{self.dot}add_none() : "
            ret_str += f"{self.graph}{self.dot}add_scalar_list"
            ret_str += f"({ref.src_cpp_name}->vec());\n"
            return ret_str
        elif ref.src_cpp_type == AT_TENSOR_LIST:
            assert ref.is_in, "AT_TENSOR_LIST must be an input"
            # This logic is a bit convoluted. We need to create an IOValueRef
            # for each tensor to facilitate staging. On the other hand, we will
            # use the .value tensor to create a ValueList, which will be passed
            # to the corresponding ops.
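            # Sketch of the C++ this branch emits for a TensorList argument
            # named "tensors" (sizes/dtypes elided; illustrative only):
            #   std::vector<IOValueRef> tensors_ref_io_value_refs;
            #   std::vector<ValueRef> tensors_ref_value_refs;
            #   for (int i=0; i < tensors.size(); i++) {
            #     IOValueRef io_value_ref = graph->add_input_tensor(...);
            #     tensors_ref_value_refs.emplace_back(io_value_ref.value);
            #     tensors_ref_io_value_refs.emplace_back(io_value_ref);
            #   }
            #   ValueRef tensors_ref =
            #       graph->add_value_list(std::move(tensors_ref_value_refs));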
ret_str = "" if include_declarations: ret_str += f"std::vector {ref.name}_io_value_refs;\n" ret_str += f"std::vector {ref.name}_value_refs;\n" ret_str += f"for (int i=0; i < {ref.src_cpp_name}.size(); i++) {{\n" ret_str += ( f" IOValueRef io_value_ref = {self.graph}{self.dot}add_input_tensor(\n" ) ret_str += f" {ref.src_cpp_name}[i].sizes().vec(),\n" ret_str += ( f" from_at_scalartype({ref.src_cpp_name}[i].scalar_type())); \n" ) ret_str += f" {ref.name}_value_refs.emplace_back(io_value_ref.value);\n" ret_str += f" {ref.name}_io_value_refs.emplace_back(io_value_ref);\n" ret_str += "}\n" ret_str += f"ValueRef {ref.name} = {self.graph}{self.dot}add_value_list(std::move({ref.name}_value_refs));\n" return ret_str elif ref.src_cpp_type == TENSOR_VECTOR: ret_str = "" if include_declarations: ret_str += f"std::vector {ref.io_value_list_name};\n" ret_str += f"std::vector {ref.value_list_name};\n" ret_str += f""" for (int i=0; i({ref.src_cpp_name}.toDouble()); \n" elif ref.src_cpp_type == AT_INT_ARRAY_REF: ret_str += f"add_scalar_list({ref.src_cpp_name}.vec()); \n" elif ref.src_cpp_type == BOOL: ret_str += f"add_scalar({ref.src_cpp_name}); \n" elif ref.src_cpp_type == INT: ret_str += f"add_scalar({ref.src_cpp_name}); \n" elif ref.src_cpp_type == DOUBLE: ret_str += f"add_scalar({ref.src_cpp_name}); \n" elif ( ref.src_cpp_type == OPT_SCALAR_TYPE or ref.src_cpp_type == OPT_LAYOUT or ref.src_cpp_type == OPT_DEVICE or ref.src_cpp_type == OPT_BOOL or ref.src_cpp_type == OPT_MEMORY_FORMAT ): ret_str += "add_none(); \n" elif ref.src_cpp_type == STRING: ret_str += f"add_string(std::string({ref.src_cpp_name})); \n" elif ref.src_cpp_type == TWO_TENSOR_TUPLE: ret_str += f"add_value_list({{{ref.name}_first, {ref.name}_second}}); \n" elif ref.src_cpp_type == THREE_TENSOR_TUPLE: ret_str += f"add_value_list({{{ref.name}_first, {ref.name}_second, {ref.name}_third}}); \n" else: raise RuntimeError(f"Unsupported cpp type {ref.src_cpp_type}") return ret_str def create_op_call(self) -> str: deref = "*" if self.dot == "->" else "" op_create_code = f'VK_GET_OP_FN("{self.op_reg_name}")({deref}{self.graph}, {{' for aten_arg in self.args: ref = self.refs[aten_arg.name] if ref.src_cpp_type == AT_TENSOR_LIST: # Special case. Underlying tensors are input tensors, but the # container itself is just a normal value. 
op_create_code += f"{ref.name}, " else: op_create_code += ( f"{ref.name}.value, " if ref.is_in or ref.requires_prepack or ref.is_out else f"{ref.name}, " ) # op_create_code += f"{ref.name}, " op_create_code += "out_ref});\n" return op_create_code def gen_output_staging_valueref_decl(self, ref: ValueRefList) -> str: if isinstance(ref, list): ret_str = "" for r in ref[:-1]: ret_str += self.gen_output_staging_valueref_decl(r) return ret_str elif ref.src_cpp_type == TENSOR_VECTOR: assert ref.is_out ret_str = "" return ret_str assert ref.src_cpp_type == AT_TENSOR and ref.is_out return f"ValueRef {ref.name}_staging;\n" def set_output(self, ref: ValueRefList, include_declarations: bool = True) -> str: if isinstance(ref, list): ret_str = "" for r in ref[:-1]: ret_str += self.set_output(r, include_declarations) return ret_str elif ref.src_cpp_type == TENSOR_VECTOR: assert ref.is_out ret_str = f""" for (int i=0; i str: assert isinstance(ref, ValueRef) assert ref.src_cpp_type in InableCppType and ref.is_in if self.prepack_ref(ref): return "" if ref.src_cpp_type == AT_TENSOR: ret_str = f"{self.graph}{self.dot}get_tensor({ref.name}.value)" ret_str += f"->virtual_resize({ref.src_cpp_name}.sizes().vec());\n" elif ref.src_cpp_type == AT_TENSOR_LIST: ret_str = "" ret_str += f"for (int i=0; i < {ref.name}_io_value_refs.size(); i++) {{\n" ret_str += ( f" {self.graph}{self.dot}get_tensor({ref.name}_io_value_refs[i].value)" ) ret_str += f"->virtual_resize({ref.src_cpp_name}[i].sizes().vec());\n" ret_str += "}\n" else: raise AssertionError(f"{ref.src_cpp_type} not expected") return ret_str def copy_into_staging(self, ref: ValueRefList) -> str: assert isinstance(ref, ValueRef) assert ref.src_cpp_type in InableCppType and ref.is_in if self.prepack_ref(ref): return "" if ref.src_cpp_type == AT_TENSOR: ret_str = f"{self.graph}{self.dot}copy_into_staging(" ret_str += f"{ref.name}.staging, " ret_str += f"{ref.src_cpp_name}.const_data_ptr(), " ret_str += f"{ref.src_cpp_name}.numel());\n" elif ref.src_cpp_type == AT_TENSOR_LIST: ret_str = "" ret_str += f"for (int i=0; i < {ref.name}_io_value_refs.size(); i++) {{\n" ret_str += f" {self.graph}{self.dot}copy_into_staging(" ret_str += f"{ref.name}_io_value_refs[i].staging, " ret_str += f"{ref.src_cpp_name}[i].const_data_ptr(), " ret_str += f"{ref.src_cpp_name}[i].numel());\n" ret_str += "}\n" else: raise AssertionError(f"{ref.src_cpp_type} not expected") return ret_str def declare_vk_out_for(self, ref: Union[ValueRef, List[ValueRef]]) -> str: if isinstance(ref, list): ret_str = "" for r in ref[:-1]: ret_str += self.declare_vk_out_for(r) return ret_str elif ref.src_cpp_type == TENSOR_VECTOR: assert ref.is_out ret_str = f""" std::vector {ref.vk_out}; for (int i=0; i str: if isinstance(ref, list): ret_str = "" for r in ref[:-1]: ret_str += self.copy_from_staging(r) return ret_str elif ref.src_cpp_type == TENSOR_VECTOR: assert ref.is_out ret_str = f""" for (int i=0; i str: if isinstance(ref, list): ret_str = "" for r in ref[:-1]: ret_str += self.check_graph_out(r) return ret_str elif ref.src_cpp_type == TENSOR_VECTOR: assert ref.is_out ret_str = f""" for (int i=0; i str: ret_str = "" for aten_arg in self.args: ref = self.refs[aten_arg.name] ret_str += self.create_value_decl_for(ref) ret_str += self.create_value_decl_for(self.refs["out"]) ret_str += f"{self.out.cpp_type} out;\n" ret_str += self.gen_output_staging_valueref_decl(self.refs["out"]) return ret_str def gen_graph_build_code(self, include_declarations: bool = True) -> str: graph_build = 
        for aten_arg in self.args:
            graph_build += self.create_value_for(
                self.refs[aten_arg.name], include_declarations
            )

        graph_build += self.create_value_for(self.refs["out"], include_declarations)
        graph_build += self.create_op_call()
        graph_build += self.set_output(self.refs["out"], include_declarations)

        graph_build += f"{self.graph}{self.dot}prepare();\n"
        graph_build += f"{self.graph}{self.dot}encode_prepack();\n"
        graph_build += f"{self.graph}{self.dot}prepack();\n"
        graph_build += f"{self.graph}{self.dot}encode_execute();\n"

        graph_build += "\n"
        return graph_build

    def gen_graph_exec_code(self, check_output=True) -> str:
        graph_exec = ""
        for aten_arg in self.args:
            ref = self.refs[aten_arg.name]
            if ref.is_in:
                graph_exec += self.virtual_resize(ref)
                graph_exec += self.copy_into_staging(ref)
        graph_exec += f"{self.graph}{self.dot}propagate_resize();\n"
        graph_exec += f"{self.graph}{self.dot}execute();\n"

        graph_exec += self.declare_vk_out_for(self.refs["out"])
        graph_exec += self.copy_from_staging(self.refs["out"])
        if check_output:
            graph_exec += self.check_graph_out(self.refs["out"])

        graph_exec = re.sub(r"^", " ", graph_exec, flags=re.M)
        graph_exec = "{\n" + graph_exec + "\n}"
        return graph_exec

    def gen_conditional_skips(self, skip_str: str = "GTEST_SKIP();") -> str:
        fp16_skip = f"if (!{self.graph}{self.dot}context()->adapter_ptr()->has_full_float16_buffers_support()) {{\n"
        fp16_skip += f" {skip_str}\n"
        fp16_skip += "}"
        fp16_skip = re.sub(r"^", " ", fp16_skip, flags=re.M) + "\n"

        int8_skip = f"if (!{self.graph}{self.dot}context()->adapter_ptr()->has_full_int8_buffers_support()) {{\n"
        int8_skip += f" {skip_str};\n"
        int8_skip += "}\n"

        skips = ""

        skips += "if (test_dtype == at::kHalf) {\n"
        skips += fp16_skip
        skips += "}\n"

        for _, dtype in self.suite_def.arg_dtype.items():
            if dtype == "at::kChar" or dtype == "at::kQInt8":
                skips += int8_skip
                continue

        skips += "\n"
        return skips

    def gen_op_check_fn(self) -> str:
        op_name = self.f.func.name.unambiguous_name()
        if self.suite_def.test_name_suffix is not None:
            op_name += "_" + self.suite_def.test_name_suffix

        op_check_fn = self.gen_decl(f"check_{op_name}") + " {\n"
        if self.should_prepack:
            op_check_fn = self.gen_decl(f"prepacked_check_{op_name}") + " {\n"

        op_check_fn_body = ""
        op_check_fn_body += self.gen_conditional_skips()
        op_check_fn_body += self.gen_graph_build_code()
        op_check_fn_body += self.gen_graph_exec_code()

        op_check_fn_body = re.sub(r"^", " ", op_check_fn_body, flags=re.M)

        op_check_fn += op_check_fn_body
        op_check_fn += "\n }"

        return op_check_fn

    def gen_build_graph_fn(self, include_declarations: bool = False) -> str:
        op_name = self.f.func.name.unambiguous_name()
        if self.suite_def.test_name_suffix is not None:
            op_name += "_" + self.suite_def.test_name_suffix

        op_build_graph_fn = self.gen_decl(f"build_graph_{op_name}") + " {\n"
        if self.should_prepack:
            op_build_graph_fn = (
                self.gen_decl(f"prepacked_build_graph_{op_name}") + " {\n"
            )

        op_build_graph_fn_body = ""
        op_build_graph_fn_body += self.gen_graph_build_code(include_declarations)

        op_build_graph_fn += op_build_graph_fn_body
        op_build_graph_fn += "\n }"
        return op_build_graph_fn
    def gen_op_exec_graph_fn(self) -> str:
        op_name = self.f.func.name.unambiguous_name()
        if self.suite_def.test_name_suffix is not None:
            op_name += "_" + self.suite_def.test_name_suffix

        op_benchmark_fn = self.gen_decl(f"benchmark_{op_name}") + " {\n"
        if self.should_prepack:
            op_benchmark_fn = self.gen_decl(f"prepacked_benchmark_{op_name}") + " {\n"

        op_benchmark_fn_body = ""
        op_benchmark_fn_body += self.gen_graph_exec_code(False)

        op_benchmark_fn_body = re.sub(r"^", " ", op_benchmark_fn_body, flags=re.M)

        op_benchmark_fn += op_benchmark_fn_body
        op_benchmark_fn += "\n }"
        return op_benchmark_fn
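
# Typical usage (sketch; the driver that supplies the NativeFunction and
# TestSuite instances and writes the generated C++ to disk lives in the
# surrounding generator scripts, and the op name below is illustrative):
#   gen = ComputeGraphGen("aten.add.Tensor", native_fn, suite_def)
#   check_fn_src = gen.gen_op_check_fn()        # correctness check function
#   benchmark_src = gen.gen_op_exec_graph_fn()  # benchmark function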