1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_AOT_CODEGEN_H_ 17 #define TENSORFLOW_COMPILER_AOT_CODEGEN_H_ 18 19 #include <string> 20 #include <vector> 21 22 #include "absl/strings/string_view.h" 23 #include "tensorflow/compiler/aot/compile.h" 24 #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" 25 26 namespace tensorflow { 27 namespace tfcompile { 28 29 // CodegenOpts specifies code generation options for the generated header file 30 // and the generated metadata object file. 31 struct CodegenOpts { 32 // The name of the generated C++ class, wrapping the generated function. 33 string class_name; 34 35 // Target triple for the architecture we're targeting. 36 string target_triple; 37 38 // Namespaces specifies a list of C++ namespaces to add to the generated 39 // header. If empty, all symbols will be in the global namespace. 40 std::vector<string> namespaces; 41 42 // If true, generate name-to-index data for Lookup{Arg,Result}Index methods. 43 bool gen_name_to_index = false; 44 45 // If true, generate program shape data for the ProgramShape method. 46 bool gen_program_shape = false; 47 48 // If true, emit a serialized HloProfilePrinterData protobuf that can be used 49 // to pretty print HLO profile counters. 50 bool gen_hlo_profile_printer_data = false; 51 }; 52 53 // Describes a generated metadata object file. 54 struct MetadataResult { 55 // These are top level "extern C" declarations that are expected to be visible 56 // wherever program_shape_access_shim is emitted. 57 std::vector<string> header_variable_decls; 58 59 // program_shape_access_shim is a C++ expression that constructs the 60 // xla::ProgramShapeProto instance for the CompileResult passed to 61 // GenerateMetadata. 62 string program_shape_access_shim; 63 64 // hlo_profile_printer_data_access_shim is a C++ expression that constructs 65 // the xla::HloProfilePrinterData instance for the CompileResult passed to 66 // GenerateMetadata. If the xla::HloProfilePrinterData is null then this is a 67 // C++ expression that evaluates to nullptr at runtime. 68 string hlo_profile_printer_data_access_shim; 69 70 // The contents of the object (".o") file. 71 string object_file_data; 72 }; 73 74 // Generates a metadata object file according to `opts` and `compile_result`. 75 // The generated object file is returned via `metadata_result`. 76 Status GenerateMetadata(const CodegenOpts& opts, 77 const CompileResult& compile_result, 78 MetadataResult* metadata_result); 79 80 // GenerateHeader uses the meta-information from compile_result to generate a 81 // C++ header giving access to the function in the generated object file. The 82 // header includes API usage documentation. 83 // 84 // metadata_result is an instance of MetadataResult obtained by a previous 85 // invocation to GenerateMetadata. 86 Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, 87 const CompileResult& compile_result, 88 const MetadataResult& metadata_result, string* header); 89 90 // ParseCppClass parses `cpp_class` into its `class_name` and `namespaces` 91 // components. The syntax is [[<optional_namespace>::],...]<class_name>. This 92 // mirrors the C++ syntax for referring to a class, where multiple namespaces 93 // may precede the class name, separated by double-colons. 94 Status ParseCppClass(const string& cpp_class, string* class_name, 95 std::vector<string>* namespaces); 96 97 // ValidateCppIdent returns OK iff ident is a valid C++ identifier. The msg is 98 // appended to error messages. 99 Status ValidateCppIdent(absl::string_view ident, absl::string_view msg); 100 101 } // namespace tfcompile 102 } // namespace tensorflow 103 104 #endif // TENSORFLOW_COMPILER_AOT_CODEGEN_H_ 105