• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1from __future__ import annotations
2
3import argparse
4import itertools
5import os
6from typing import Sequence, TypeVar, Union
7
8from libfb.py.log import set_simple_logging  # type: ignore[import]
9
10from torchgen import gen
11from torchgen.context import native_function_manager
12from torchgen.model import DispatchKey, NativeFunctionsGroup, NativeFunctionsViewGroup
13from torchgen.static_runtime import config, generator
14
15
# `group_functions_by_op_name` (defined below, after the TypeVar it needs)
# takes a list of `grouped_native_functions` sorted by op name and returns a
# list of lists, each of which groups ops that share the base name. For
# example, `mean` and `mean.dim` are grouped together by this function.
19
# Type variable ranging over the two kinds of native-function groupings
# produced by torchgen. It lets `group_functions_by_op_name` preserve the
# concrete element type of its input in its return type.
NativeGroupT = TypeVar(
    "NativeGroupT",
    bound=Union[NativeFunctionsGroup, NativeFunctionsViewGroup],
)
24
25
def group_functions_by_op_name(
    grouped_native_functions: Sequence[NativeGroupT],
) -> Sequence[Sequence[NativeGroupT]]:
    """Group ops that share a base name.

    Given ``grouped_native_functions`` sorted by op name, return a list of
    lists where each inner list holds ops with the same base name (e.g.
    ``mean`` and ``mean.dim`` end up in one group). Ops that the static
    runtime generator does not support are filtered out first.
    """
    if not grouped_native_functions:
        return []

    def is_supported(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> bool:
        # native_function_manager sets up the per-function context that
        # generator.is_supported expects to run under.
        with native_function_manager(g):
            return generator.is_supported(g)

    eligible_ops = (g for g in grouped_native_functions if is_supported(g))
    # NOTE: itertools.groupby only merges *consecutive* items with equal keys,
    # so this relies on the input being sorted by op name (see docstring).
    return [
        list(group)
        for _, group in itertools.groupby(eligible_ops, key=config.func_name_base_str)
    ]
49
50
def clang_format(cpp_file_path: str) -> None:
    """Run ``clang-format -i`` on the file at ``cpp_file_path``, formatting it in place."""
    import subprocess

    cmd = ["clang-format", "-i", cpp_file_path]
    subprocess.check_call(cmd)
55
56
def write_cpp(cpp_ops: Sequence[str], file_path: str) -> None:
    """Write the generated op dispatcher implementations to ``file_path``.

    Joins the per-op C++ snippets in ``cpp_ops``, wraps them in the fixed
    header/namespace boilerplate below, writes the result, and then runs
    clang-format on the output file.
    """
    code = "\n".join(cpp_ops)
    # Note: the doubled braces ({{ / }}) are f-string escapes for literal
    # braces in the emitted C++ namespaces.
    generated = f"""// @lint-ignore-every CLANGTIDY HOWTOEVEN
// AUTO-GENERATED FROM: torchgen/static_runtime/gen_static_runtime_ops.py
#include <torch/csrc/jit/runtime/static/ops.h>

#include <ATen/CPUFunctions.h>
#include <ATen/InferSize.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Parallel.h>
#include <ATen/ScalarOps.h>
#include <ATen/TensorUtils.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/EmbeddingBag.h>
#include <ATen/native/Fill.h>
#include <ATen/native/IndexingUtils.h>
#include <ATen/native/NonSymbolicBC.h>
#include <ATen/native/Resize.h>
#include <ATen/native/SharedReduceOps.h>
#include <ATen/native/TensorAdvancedIndexing.h>
#include <ATen/native/cpu/SerialStackImpl.h>
#include <ATen/native/layer_norm.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/cpu/qembeddingbag.h>
#include <ATen/native/quantized/cpu/qembeddingbag_prepack.h>
#include <ATen/quantized/QTensorImpl.h>
#include <ATen/quantized/Quantizer.h>
#include <c10/core/ScalarType.h>
#include <c10/core/WrapDimMinimal.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/static/impl.h>
#include <torch/csrc/jit/runtime/static/te_wrapper.h>
#include <torch/csrc/jit/runtime/vararg_functions.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>

namespace torch {{
namespace jit {{

{code}

}} // namespace jit
}} // namespace torch
"""
    with open(file_path, "w") as f:
        f.write(generated)
    clang_format(file_path)
108
109
def write_test_cpp(cpp_ops: Sequence[str], file_path: str) -> None:
    """Write the generated gtest test cases to ``file_path``.

    Joins the per-op test snippets in ``cpp_ops``, wraps them in the fixed
    gtest include/using boilerplate below, writes the result, and then runs
    clang-format on the output file.
    """
    code = "\n".join(cpp_ops)
    generated = f"""// @lint-ignore-every CLANGTIDY HOWTOEVEN
// AUTO-GENERATED FROM: torchgen/static_runtime/gen_static_runtime_ops.py
#include <gtest/gtest.h>
#include <torch/csrc/jit/runtime/static/impl.h>
#include <torch/torch.h>

#include "test_utils.h"

using namespace caffe2;
using namespace torch;
using namespace torch::jit;
using namespace torch::jit::test;
using c10::IValue;

{code}

"""
    with open(file_path, "w") as f:
        f.write(generated)
    clang_format(file_path)
132
133
def main() -> None:
    """Generate the static runtime op dispatcher .cpp and its gtest .cc file.

    Parses the native functions YAML, filters down to the op groups supported
    by the static runtime generator, renders the out-variant and view-op
    implementations plus their test cases, and writes/formats both files.
    Prints summary counts at the end.
    """
    parser = argparse.ArgumentParser(description="Generate ATen source files")
    parser.add_argument(
        "-s",
        "--source-path",
        help="path to source directory for ATen",
        default="caffe2/aten/src/ATen",
    )
    parser.add_argument(
        "-p",
        "--generated-ops-cpp-path",
        help="path to directory to generate op dispatcher .cpp file",
        default="caffe2/torch/csrc/jit/runtime/static/generated_ops.cpp",
    )
    parser.add_argument(
        "-t",
        "--generated-ops-test-cpp-path",
        help="path to directory to generate op dispatcher .cpp file",
        default="caffe2/benchmarks/static_runtime/test_generated_ops.cc",
    )
    options = parser.parse_args()
    native_yaml_path = os.path.join(options.source_path, "native/native_functions.yaml")
    tags_yaml_path = os.path.join(options.source_path, "native/tags.yaml")
    parsed_yaml = gen.parse_native_yaml(native_yaml_path, tags_yaml_path)
    native_functions, backend_indices = (
        parsed_yaml.native_functions,
        parsed_yaml.backend_indices,
    )

    op_generator = generator.GenOpDispatcher()
    test_case_generator = generator.GenOpTestCase()

    # Compute the full grouping once; it is used both for filtering below and
    # for the summary printout at the end (previously it was recomputed there).
    grouped_native_functions = gen.get_grouped_native_functions(native_functions)

    native_functions_groups = [
        g for g in grouped_native_functions if isinstance(g, NativeFunctionsGroup)
    ]

    supported_functions_groups = group_functions_by_op_name(native_functions_groups)

    out_variant_op_result = [
        op_generator.out_variant(groups, backend_indices[DispatchKey.CPU])
        for groups in supported_functions_groups
    ]
    out_variant_test_result = [
        test_case_generator.out_variant(groups) for groups in supported_functions_groups
    ]

    native_functions_view_groups = [
        g
        for g in gen.get_grouped_by_view_native_functions(native_functions)
        if isinstance(g, NativeFunctionsViewGroup)
    ]

    supported_functions_view_groups = group_functions_by_op_name(
        native_functions_view_groups
    )

    view_op_result = [
        op_generator.view(groups, backend_indices[DispatchKey.CPU])
        for groups in supported_functions_view_groups
    ]
    view_test_result = [
        test_case_generator.view(groups) for groups in supported_functions_view_groups
    ]

    # Blank separator between the out-variant section and the view section.
    op_result = out_variant_op_result + ["\n\n"] + view_op_result
    test_result = out_variant_test_result + ["\n\n"] + view_test_result

    write_cpp(op_result, options.generated_ops_cpp_path)
    write_test_cpp(test_result, options.generated_ops_test_cpp_path)

    print("\ntotal grouped native ops: %d" % len(grouped_native_functions))

    print("grouped native ops with out variant: %d" % len(native_functions_groups))
    supported_functions_num = sum(len(groups) for groups in supported_functions_groups)
    print("generated functions groups with out variant: %d" % supported_functions_num)

    print("\nview grouped native ops: %d" % len(native_functions_view_groups))
    supported_view_functions_num = sum(
        len(groups) for groups in supported_functions_view_groups
    )
    print("generated functions view groups: %d" % supported_view_functions_num)

    print(
        "\noverall generated : %d"
        % (supported_functions_num + supported_view_functions_num)
    )
225
226
if __name__ == "__main__":
    # Configure logging via libfb's helper before generation starts.
    # NOTE(review): escape_newlines=False presumably keeps multi-line log
    # messages readable — confirm against libfb.py.log docs.
    set_simple_logging(escape_newlines=False)
    main()
230