from __future__ import annotations from collections import defaultdict from dataclasses import dataclass from typing import Sequence, TYPE_CHECKING from torchgen import dest # disable import sorting to avoid circular dependency. from torchgen.api.types import DispatcherSignature # usort: skip from torchgen.context import method_with_native_function from torchgen.model import BaseTy, BaseType, DispatchKey, NativeFunction, Variant from torchgen.utils import concatMap, Target if TYPE_CHECKING: from torchgen.executorch.model import ETKernelIndex from torchgen.selective_build.selector import SelectiveBuilder # Generates RegisterKernelStub.cpp, which provides placeholder kernels for custom operators. This will be used at # model authoring side. @dataclass(frozen=True) class ComputeNativeFunctionStub: @method_with_native_function def __call__(self, f: NativeFunction) -> str | None: if Variant.function not in f.variants: return None sig = DispatcherSignature.from_schema( f.func, prefix=f"wrapper_CPU_{f.func.name.overload_name}_", symint=False ) assert sig is not None if len(f.func.returns) == 0: ret_name = "" elif len(f.func.returns) == 1: if f.func.arguments.out: ret_name = f.func.arguments.out[0].name else: ret_name = next( ( a.name for a in f.func.arguments.flat_non_out if a.type == f.func.returns[0].type ), "", ) if not ret_name: # if return type is tensor if f.func.returns[0].type == BaseType(BaseTy.Tensor): # Returns an empty tensor ret_name = "at::Tensor()" else: raise Exception( # noqa: TRY002 f"Can't handle this return type {f.func}" ) # noqa: TRY002 elif len(f.func.arguments.out) == len(f.func.returns): # Returns a tuple of out arguments tensor_type = "at::Tensor &" comma = ", " ret_name = f"""::std::tuple<{comma.join([tensor_type] * len(f.func.returns))}>( {comma.join([r.name for r in f.func.arguments.out])} )""" else: assert all( a.type == BaseType(BaseTy.Tensor) for a in f.func.returns ), f"Only support tensor returns but got {f.func.returns}" # Returns a tuple of empty tensors tensor_type = "at::Tensor" comma = ", " ret_name = f"""::std::tuple<{comma.join([tensor_type] * len(f.func.returns))}>( {comma.join(["at::Tensor()" for _ in f.func.returns])} )""" ret_str = f"return {ret_name};" if len(f.func.returns) > 0 else "" return f""" {sig.defn()} {{ {ret_str} }} """ def gen_custom_ops_registration( *, native_functions: Sequence[NativeFunction], selector: SelectiveBuilder, kernel_index: ETKernelIndex, rocm: bool, ) -> tuple[str, str]: """ Generate custom ops registration code for dest.RegisterDispatchKey. :param native_functions: a sequence of `NativeFunction` :param selector: for selective build. :param kernel_index: kernels for all the ops. :param rocm: bool for dest.RegisterDispatchKey. :return: generated C++ code to register custom operators into PyTorch """ # convert kernel index to BackendIndex. This is because we can't handle ETKernelIndex yet. # TODO larryliu: evaluate if this code is still needed. If yes let it handle ETKernelIndex. dispatch_key = DispatchKey.CPU backend_index = kernel_index._to_backend_index() static_init_dispatch_registrations = "" ns_grouped_native_functions: dict[str, list[NativeFunction]] = defaultdict(list) for native_function in native_functions: ns_grouped_native_functions[native_function.namespace].append(native_function) for namespace, functions in ns_grouped_native_functions.items(): if len(functions) == 0: continue dispatch_registrations_body = "\n".join( list( concatMap( dest.RegisterDispatchKey( backend_index, Target.REGISTRATION, selector, rocm=rocm, symint=False, class_method_name=None, skip_dispatcher_op_registration=False, ), functions, ) ) ) static_init_dispatch_registrations += f""" TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{ {dispatch_registrations_body} }};""" anonymous_definition = "\n".join( list( concatMap( dest.RegisterDispatchKey( backend_index, Target.ANONYMOUS_DEFINITION, selector, rocm=rocm, symint=False, class_method_name=None, skip_dispatcher_op_registration=False, ), native_functions, ) ) ) return anonymous_definition, static_init_dispatch_registrations