1from __future__ import annotations 2 3import argparse 4import itertools 5import os 6from typing import Sequence, TypeVar, Union 7 8from libfb.py.log import set_simple_logging # type: ignore[import] 9 10from torchgen import gen 11from torchgen.context import native_function_manager 12from torchgen.model import DispatchKey, NativeFunctionsGroup, NativeFunctionsViewGroup 13from torchgen.static_runtime import config, generator 14 15 16# Given a list of `grouped_native_functions` sorted by their op names, return a list of 17# lists each of which groups ops that share the base name. For example, `mean` and 18# `mean.dim` are grouped together by this function. 19 20NativeGroupT = TypeVar( 21 "NativeGroupT", 22 bound=Union[NativeFunctionsGroup, NativeFunctionsViewGroup], 23) 24 25 26def group_functions_by_op_name( 27 grouped_native_functions: Sequence[NativeGroupT], 28) -> Sequence[Sequence[NativeGroupT]]: 29 if not grouped_native_functions: 30 return [] 31 groups = [] 32 33 def is_supported(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> bool: 34 with native_function_manager(g): 35 return generator.is_supported(g) 36 37 eligible_ops = (g for g in grouped_native_functions if is_supported(g)) 38 groups = [ 39 list(group) 40 for k, group in ( 41 itertools.groupby( 42 eligible_ops, 43 key=config.func_name_base_str, 44 ) 45 ) 46 ] 47 48 return groups 49 50 51def clang_format(cpp_file_path: str) -> None: 52 import subprocess 53 54 subprocess.check_call(["clang-format", "-i", cpp_file_path]) 55 56 57def write_cpp(cpp_ops: Sequence[str], file_path: str) -> None: 58 code = "\n".join(cpp_ops) 59 generated = f"""// @lint-ignore-every CLANGTIDY HOWTOEVEN 60// AUTO-GENERATED FROM: torchgen/static_runtime/gen_static_runtime_ops.py 61#include <torch/csrc/jit/runtime/static/ops.h> 62 63#include <ATen/CPUFunctions.h> 64#include <ATen/InferSize.h> 65#include <ATen/NativeFunctions.h> 66#include <ATen/Parallel.h> 67#include <ATen/ScalarOps.h> 68#include <ATen/TensorUtils.h> 69#include <ATen/cpu/vec/functional.h> 70#include <ATen/cpu/vec/vec.h> 71#include <ATen/native/EmbeddingBag.h> 72#include <ATen/native/Fill.h> 73#include <ATen/native/IndexingUtils.h> 74#include <ATen/native/NonSymbolicBC.h> 75#include <ATen/native/Resize.h> 76#include <ATen/native/SharedReduceOps.h> 77#include <ATen/native/TensorAdvancedIndexing.h> 78#include <ATen/native/cpu/SerialStackImpl.h> 79#include <ATen/native/layer_norm.h> 80#include <ATen/native/quantized/cpu/fbgemm_utils.h> 81#include <ATen/native/quantized/cpu/qembeddingbag.h> 82#include <ATen/native/quantized/cpu/qembeddingbag_prepack.h> 83#include <ATen/quantized/QTensorImpl.h> 84#include <ATen/quantized/Quantizer.h> 85#include <c10/core/ScalarType.h> 86#include <c10/core/WrapDimMinimal.h> 87#include <c10/util/irange.h> 88#include <torch/csrc/jit/ir/ir.h> 89#include <torch/csrc/jit/runtime/static/impl.h> 90#include <torch/csrc/jit/runtime/static/te_wrapper.h> 91#include <torch/csrc/jit/runtime/vararg_functions.h> 92#include <torch/csrc/jit/tensorexpr/ir.h> 93#include <torch/csrc/jit/tensorexpr/ir_simplifier.h> 94#include <torch/csrc/jit/tensorexpr/llvm_codegen.h> 95#include <torch/csrc/jit/tensorexpr/loopnest.h> 96 97namespace torch {{ 98namespace jit {{ 99 100{code} 101 102}} // namespace jit 103}} // namespace torch 104""" 105 with open(file_path, "w") as f: 106 f.write(generated) 107 clang_format(file_path) 108 109 110def write_test_cpp(cpp_ops: Sequence[str], file_path: str) -> None: 111 code = "\n".join(cpp_ops) 112 generated = f"""// @lint-ignore-every CLANGTIDY HOWTOEVEN 113// AUTO-GENERATED FROM: torchgen/static_runtime/gen_static_runtime_ops.py 114#include <gtest/gtest.h> 115#include <torch/csrc/jit/runtime/static/impl.h> 116#include <torch/torch.h> 117 118#include "test_utils.h" 119 120using namespace caffe2; 121using namespace torch; 122using namespace torch::jit; 123using namespace torch::jit::test; 124using c10::IValue; 125 126{code} 127 128""" 129 with open(file_path, "w") as f: 130 f.write(generated) 131 clang_format(file_path) 132 133 134def main() -> None: 135 parser = argparse.ArgumentParser(description="Generate ATen source files") 136 parser.add_argument( 137 "-s", 138 "--source-path", 139 help="path to source directory for ATen", 140 default="caffe2/aten/src/ATen", 141 ) 142 parser.add_argument( 143 "-p", 144 "--generated-ops-cpp-path", 145 help="path to directory to generate op dispatcher .cpp file", 146 default="caffe2/torch/csrc/jit/runtime/static/generated_ops.cpp", 147 ) 148 parser.add_argument( 149 "-t", 150 "--generated-ops-test-cpp-path", 151 help="path to directory to generate op dispatcher .cpp file", 152 default="caffe2/benchmarks/static_runtime/test_generated_ops.cc", 153 ) 154 options = parser.parse_args() 155 native_yaml_path = os.path.join(options.source_path, "native/native_functions.yaml") 156 tags_yaml_path = os.path.join(options.source_path, "native/tags.yaml") 157 parsed_yaml = gen.parse_native_yaml(native_yaml_path, tags_yaml_path) 158 native_functions, backend_indices = ( 159 parsed_yaml.native_functions, 160 parsed_yaml.backend_indices, 161 ) 162 163 op_generator = generator.GenOpDispatcher() 164 test_case_generator = generator.GenOpTestCase() 165 166 native_functions_groups = [ 167 g 168 for g in gen.get_grouped_native_functions(native_functions) 169 if isinstance(g, NativeFunctionsGroup) 170 ] 171 172 supported_functions_groups = group_functions_by_op_name(native_functions_groups) 173 174 out_variant_op_result = [ 175 op_generator.out_variant(groups, backend_indices[DispatchKey.CPU]) 176 for groups in supported_functions_groups 177 ] 178 out_variant_test_result = [ 179 test_case_generator.out_variant(groups) for groups in supported_functions_groups 180 ] 181 182 native_functions_view_groups = [ 183 g 184 for g in gen.get_grouped_by_view_native_functions(native_functions) 185 if isinstance(g, NativeFunctionsViewGroup) 186 ] 187 188 supported_functions_view_groups = group_functions_by_op_name( 189 native_functions_view_groups 190 ) 191 192 view_op_result = [ 193 op_generator.view(groups, backend_indices[DispatchKey.CPU]) 194 for groups in supported_functions_view_groups 195 ] 196 view_test_result = [ 197 test_case_generator.view(groups) for groups in supported_functions_view_groups 198 ] 199 200 op_result = out_variant_op_result + ["\n\n"] + view_op_result 201 test_result = out_variant_test_result + ["\n\n"] + view_test_result 202 203 write_cpp(op_result, options.generated_ops_cpp_path) 204 write_test_cpp(test_result, options.generated_ops_test_cpp_path) 205 206 print( 207 "\ntotal grouped native ops: %d" 208 % len(gen.get_grouped_native_functions(native_functions)) 209 ) 210 211 print("grouped native ops with out variant: %d" % len(native_functions_groups)) 212 supported_functions_num = sum(len(groups) for groups in supported_functions_groups) 213 print("generated functions groups with out variant: %d" % supported_functions_num) 214 215 print("\nview grouped native ops: %d" % len(native_functions_view_groups)) 216 supported_view_functions_num = sum( 217 len(groups) for groups in supported_functions_view_groups 218 ) 219 print("generated functions view groups: %d" % supported_view_functions_num) 220 221 print( 222 "\noverall generated : %d" 223 % (supported_functions_num + supported_view_functions_num) 224 ) 225 226 227if __name__ == "__main__": 228 set_simple_logging(escape_newlines=False) 229 main() 230