# Generates C++ functions that wrap ATen tensor factory methods to turn them into Variables. # # This writes one file: variable_factories.h from __future__ import annotations import re import torchgen.api.python as python from torchgen.api import cpp from torchgen.api.types import CppSignatureGroup from torchgen.context import with_native_function from torchgen.gen import parse_native_yaml from torchgen.model import NativeFunction, TensorOptionsArguments, Variant from torchgen.utils import FileManager, mapMaybe OPTIONAL_TYPE_PATTERN = re.compile(r"std::optional<(.+)>") TYPE_PATTERN = re.compile(r"(?:const\s+)?([A-Z]\w+)") # Add 'at::' to types defined in ATen namespace, e.g. Tensor, TensorList, IntArrayRef and etc. # TODO: maybe update the cpp argument API to take optional namespace argument? def fully_qualified_type(argument_type: str) -> str: def maybe_optional_type(type: str, is_opt: bool) -> str: return f"std::optional<{type}>" if is_opt else type opt_match = OPTIONAL_TYPE_PATTERN.match(argument_type) is_opt = opt_match is not None if opt_match: argument_type = argument_type[opt_match.start(1) : opt_match.end(1)] match = TYPE_PATTERN.match(argument_type) if match is None: return maybe_optional_type(argument_type, is_opt) index = match.start(1) qualified_type = f"{argument_type[:index]}at::{argument_type[index:]}" return maybe_optional_type(qualified_type, is_opt) def gen_variable_factories( out: str, native_yaml_path: str, tags_yaml_path: str, template_path: str ) -> None: native_functions = parse_native_yaml( native_yaml_path, tags_yaml_path ).native_functions factory_functions = [fn for fn in native_functions if is_factory_function(fn)] fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False) fm.write_with_template( "variable_factories.h", "variable_factories.h", lambda: { "generated_comment": "@" + f"generated from {fm.template_dir_for_comments()}/variable_factories.h", "ops_headers": [ f"#include " for fn in factory_functions ], "function_definitions": list(mapMaybe(process_function, factory_functions)), }, ) @with_native_function def is_factory_function(f: NativeFunction) -> bool: if Variant.function not in f.variants: return False name = cpp.name(f.func) has_tensor_options = python.has_tensor_options(f) return has_tensor_options or name.endswith("_like") @with_native_function def process_function(f: NativeFunction) -> str | None: name = cpp.name(f.func) has_tensor_options = python.has_tensor_options(f) is_factory = has_tensor_options or name.endswith("_like") if Variant.function not in f.variants or not is_factory: return None cpp_sigs = CppSignatureGroup.from_native_function(f, method=False) sigs = [cpp_sigs.signature] if cpp_sigs.symint_signature is not None: sigs.append(cpp_sigs.symint_signature) r = "" for sig in sigs: formals: list[str] = [] exprs: list[str] = [] requires_grad = "false" for arg in sig.arguments(): qualified_type = fully_qualified_type(arg.type) if arg.default: formals.append(f"{qualified_type} {arg.name} = {arg.default}") else: formals.append(f"{qualified_type} {arg.name}") if isinstance(arg.argument, TensorOptionsArguments): # note: we remove the requires_grad setting from the TensorOptions because # it is ignored anyways (and we actually have an assertion that it isn't set # which would fail otherwise). We handle requires_grad explicitly here # instead of passing it through to the kernel. exprs.append( f"at::TensorOptions({arg.name}).requires_grad(::std::nullopt)" ) # Manually set the requires_grad bit on the result tensor. requires_grad = f"{arg.name}.requires_grad()" else: exprs.append(arg.name) r += f"""\ inline at::Tensor {sig.name()}({', '.join(formals)}) {{ at::AutoDispatchBelowADInplaceOrView guard; return autograd::make_variable(at::{sig.name()}({', '.join(exprs)}), /*requires_grad=*/{requires_grad}); }} """ return r