1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/aot/codegen.h"

#include <algorithm>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/str_split.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/aot/embedded_protocol_buffers.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/cpu_function_runtime.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
36
37 namespace tensorflow {
38 namespace tfcompile {
39
40 namespace {
41
42 using BufferInfo = xla::cpu_function_runtime::BufferInfo;
43
// Returns true iff `c` is an ASCII letter (A-Z or a-z).
bool IsAlpha(char c) {
  const bool is_upper = c >= 'A' && c <= 'Z';
  const bool is_lower = c >= 'a' && c <= 'z';
  return is_upper || is_lower;
}
47
IsAlphaNum(char c)48 bool IsAlphaNum(char c) { return IsAlpha(c) || (c >= '0' && c <= '9'); }
49
50 // Convert an XLA type into a C++ type.
XLATypeToCpp(xla::PrimitiveType type,string * str)51 Status XLATypeToCpp(xla::PrimitiveType type, string* str) {
52 switch (type) {
53 case xla::PRED:
54 *str = "bool";
55 break;
56 case xla::S8:
57 *str = "tensorflow::int8";
58 break;
59 case xla::S16:
60 *str = "tensorflow::int16";
61 break;
62 case xla::S32:
63 *str = "tensorflow::int32";
64 break;
65 case xla::S64:
66 *str = "tensorflow::int64";
67 break;
68 case xla::U8:
69 *str = "tensorflow::uint8";
70 break;
71 case xla::U16:
72 *str = "tensorflow::uint16";
73 break;
74 case xla::U32:
75 *str = "tensorflow::uint32";
76 break;
77 case xla::U64:
78 *str = "tensorflow::uint64";
79 break;
80 case xla::F32:
81 *str = "float";
82 break;
83 case xla::F64:
84 *str = "double";
85 break;
86 default:
87 return errors::Unimplemented("XLA type ", xla::PrimitiveType_Name(type),
88 " has no equivalent in C++");
89 }
90 return Status::OK();
91 }
92
93 // Returns the sum of the size of each buffer in `buffer_infos`.
TotalBufferBytes(const std::vector<BufferInfo> & buffer_infos)94 size_t TotalBufferBytes(const std::vector<BufferInfo>& buffer_infos) {
95 return std::accumulate(buffer_infos.begin(), buffer_infos.end(), size_t{0},
96 [](size_t size, const BufferInfo& buffer_info) {
97 return size + buffer_info.size();
98 });
99 }
100
101 // Returns a vector of BufferInfo instances in `buffer_infos` that are entry
102 // parameter buffers.
ExtractEntryParamBufferInfos(const std::vector<BufferInfo> & buffer_infos)103 std::vector<BufferInfo> ExtractEntryParamBufferInfos(
104 const std::vector<BufferInfo>& buffer_infos) {
105 std::vector<BufferInfo> result;
106 std::copy_if(buffer_infos.begin(), buffer_infos.end(),
107 std::back_inserter(result), [](const BufferInfo& buffer_info) {
108 return buffer_info.is_entry_parameter();
109 });
110 return result;
111 }
112
113 // Returns a vector of BufferInfo instances in `buffer_infos` that are temp
114 // buffers.
ExtractTempBufferInfos(const std::vector<BufferInfo> & buffer_infos)115 std::vector<BufferInfo> ExtractTempBufferInfos(
116 const std::vector<BufferInfo>& buffer_infos) {
117 std::vector<BufferInfo> result;
118 std::copy_if(buffer_infos.begin(), buffer_infos.end(),
119 std::back_inserter(result), [](const BufferInfo& buffer_info) {
120 return buffer_info.is_temp_buffer();
121 });
122 return result;
123 }
124
125 // Add (from,to) rewrite pairs based on the given shape. These rewrite pairs
126 // are used to generate methods for args and results.
AddRewritesForShape(int i,const xla::Shape & shape,std::vector<std::pair<string,string>> * rewrites)127 Status AddRewritesForShape(int i, const xla::Shape& shape,
128 std::vector<std::pair<string, string>>* rewrites) {
129 string type;
130 TF_RETURN_IF_ERROR(XLATypeToCpp(shape.element_type(), &type));
131 std::vector<string> dim_vars;
132 string dim_sizes, indices;
133 if (shape.rank() == 0 ||
134 (shape.dimensions_size() == 1 && shape.dimensions(0) == 1)) {
135 dim_sizes = "[1]";
136 indices = "[0]";
137 } else {
138 for (int dim = 0; dim < shape.dimensions_size(); ++dim) {
139 dim_vars.push_back(absl::StrCat("size_t dim", dim));
140 dim_sizes += absl::StrCat("[", shape.dimensions(dim), "]");
141 indices += absl::StrCat("[dim", dim, "]");
142 }
143 }
144 rewrites->push_back({"{{I}}", absl::StrCat(i)});
145 rewrites->push_back({"{{TYPE}}", type});
146 rewrites->push_back({"{{DIM_VARS}}", absl::StrJoin(dim_vars, ", ")});
147 rewrites->push_back({"{{DIM_SIZES}}", dim_sizes});
148 rewrites->push_back({"{{INDICES}}", indices});
149 return Status::OK();
150 }
151
152 // Returns code rewritten by replacing all rewrite pairs, with an extra rewrite
153 // for the name. Note that the rewriting strategy is roughly O(N*M), where N is
154 // the size of the code and M is the number of rewrites. It's fine for now
155 // since N and M are pretty small.
156 //
157 // TODO(toddw): If this becomes a problem, we should be able to change the
158 // algorithm to O(N) by using a state machine, e.g. regexps or a real
159 // text-templating mechanism.
// Returns code rewritten by replacing all rewrite pairs, with an extra rewrite
// for the name. Note that the rewriting strategy is roughly O(N*M), where N is
// the size of the code and M is the number of rewrites. It's fine for now
// since N and M are pretty small.
//
// TODO(toddw): If this becomes a problem, we should be able to change the
// algorithm to O(N) by using a state machine, e.g. regexps or a real
// text-templating mechanism.
string RewriteWithName(const string& name, string code,
                       const std::vector<std::pair<string, string>>& rewrites) {
  // Apply the caller-supplied rewrites first, then splice in the name.
  absl::StrReplaceAll(rewrites, &code);
  const std::vector<std::pair<string, string>> name_rewrite = {
      {"{{NAME}}", name}};
  absl::StrReplaceAll(name_rewrite, &code);
  return code;
}
166
167 // Generate methods for args (inputs).
GenArgMethods(const tf2xla::Config & config,const xla::ProgramShapeProto & ps,const CompileResult & compile_result,string * methods)168 Status GenArgMethods(const tf2xla::Config& config,
169 const xla::ProgramShapeProto& ps,
170 const CompileResult& compile_result, string* methods) {
171 size_t num_args = ps.parameters_size();
172 if (config.feed_size() + config.variable_size() != num_args) {
173 return errors::InvalidArgument(
174 "mismatch between feed_size(", config.feed_size(), ")+variable_size(",
175 config.variable_size(), ") and num_args(", num_args, ")");
176 }
177 for (int i = 0; i < config.feed_size(); ++i) {
178 std::vector<std::pair<string, string>> rewrites;
179 TF_RETURN_IF_ERROR(
180 AddRewritesForShape(i, xla::Shape(ps.parameters(i)), &rewrites));
181 const string code = R"(
182 void set_arg{{NAME}}_data(const void* data) {
183 set_arg_data({{I}}, data);
184 }
185 {{TYPE}}* arg{{NAME}}_data() {
186 return static_cast<{{TYPE}}*>(arg_data({{I}}));
187 }
188 {{TYPE}}& arg{{NAME}}({{DIM_VARS}}) {
189 return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>(
190 arg_data({{I}}))){{INDICES}};
191 }
192 const {{TYPE}}* arg{{NAME}}_data() const {
193 return static_cast<const {{TYPE}}*>(arg_data({{I}}));
194 }
195 const {{TYPE}}& arg{{NAME}}({{DIM_VARS}}) const {
196 return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
197 arg_data({{I}}))){{INDICES}};
198 }
199 )";
200 *methods += RewriteWithName(absl::StrCat(i), code, rewrites);
201 if (!config.feed(i).name().empty()) {
202 *methods += RewriteWithName("_" + config.feed(i).name(), code, rewrites);
203 }
204 }
205 return Status::OK();
206 }
207
208 // Generate methods for results (outputs).
GenResultMethods(const tf2xla::Config & config,const xla::ProgramShapeProto & ps,string * methods)209 Status GenResultMethods(const tf2xla::Config& config,
210 const xla::ProgramShapeProto& ps, string* methods) {
211 if (ps.result().element_type() != xla::TUPLE) {
212 // The XlaCompiler we use to build the xla computation always generates a
213 // tuple result, and we rely on this to simplify code generation.
214 return errors::Internal("codegen requires the XLA result to be a tuple");
215 }
216 size_t num_results = ps.result().tuple_shapes_size();
217 int readonly_variables = absl::c_count_if(
218 config.variable(),
219 [](const tf2xla::Variable& var) { return var.readonly(); });
220 if (config.fetch_size() + config.variable_size() - readonly_variables !=
221 num_results) {
222 return errors::InvalidArgument("mismatch between fetch_size(",
223 config.fetch_size(), ")+variable_size(",
224 config.variable_size(), ") and tuple_size(",
225 ps.result().tuple_shapes_size(), ")");
226 }
227 for (int i = 0; i < config.fetch_size(); ++i) {
228 std::vector<std::pair<string, string>> rewrites;
229 TF_RETURN_IF_ERROR(AddRewritesForShape(
230 i, xla::Shape(ps.result().tuple_shapes(i)), &rewrites));
231 string code = R"(
232 {{TYPE}}* result{{NAME}}_data() {
233 return static_cast<{{TYPE}}*>(result_data({{I}}));
234 }
235 {{TYPE}}& result{{NAME}}({{DIM_VARS}}) {
236 return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>(
237 result_data({{I}}))){{INDICES}};
238 }
239 const {{TYPE}}* result{{NAME}}_data() const {
240 return static_cast<const {{TYPE}}*>(result_data({{I}}));
241 }
242 const {{TYPE}}& result{{NAME}}({{DIM_VARS}}) const {
243 return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
244 result_data({{I}}))){{INDICES}};
245 }
246 )";
247 *methods += RewriteWithName(absl::StrCat(i), code, rewrites);
248 if (!config.fetch(i).name().empty()) {
249 *methods += RewriteWithName("_" + config.fetch(i).name(), code, rewrites);
250 }
251 }
252 return Status::OK();
253 }
254
255 // Generate methods for variables.
GenVariableMethods(const tf2xla::Config & config,const xla::ProgramShapeProto & ps,string * methods)256 Status GenVariableMethods(const tf2xla::Config& config,
257 const xla::ProgramShapeProto& ps, string* methods) {
258 size_t num_args = ps.parameters_size();
259 for (int i = config.feed_size(); i < num_args; ++i) {
260 std::vector<std::pair<string, string>> rewrites;
261 TF_RETURN_IF_ERROR(
262 AddRewritesForShape(i, xla::Shape(ps.parameters(i)), &rewrites));
263 const string code = R"(
264 void set_var_{{NAME}}_data({{MAYBE_CONST}}{{TYPE}}* data) {
265 set_arg_data({{I}}, data);
266 }
267 {{MAYBE_CONST}}{{TYPE}}* var_{{NAME}}_data() {
268 return static_cast<{{MAYBE_CONST}}{{TYPE}}*>(arg_data({{I}}));
269 }
270 {{MAYBE_CONST}}{{TYPE}}& var_{{NAME}}({{DIM_VARS}}) {
271 return (*static_cast<{{MAYBE_CONST}}{{TYPE}}(*){{DIM_SIZES}}>(
272 arg_data({{I}}))){{INDICES}};
273 }
274 const {{TYPE}}* var_{{NAME}}_data() const {
275 return static_cast<const {{TYPE}}*>(arg_data({{I}}));
276 }
277 const {{TYPE}}& var_{{NAME}}({{DIM_VARS}}) const {
278 return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
279 arg_data({{I}}))){{INDICES}};
280 }
281 )";
282 const tf2xla::Variable& var = config.variable(i - config.feed_size());
283 rewrites.emplace_back("{{MAYBE_CONST}}", var.readonly() ? "const " : "");
284 *methods += RewriteWithName(
285 var.name().empty() ? var.node_name() : var.name(), code, rewrites);
286 }
287 return Status::OK();
288 }
289
290 // Generates code implementing {Arg,Result}Names(), where T is one of
291 // tf2xla::{Feed,Fetch}. Each feed or fetch name results in a C-style string
292 // literal in the array, with nullptr terminating the array.
293 template <typename T>
// Generates code implementing {Arg,Result}Names(), where T is one of
// tf2xla::{Feed,Fetch}. Each feed or fetch name results in a C-style string
// literal in the array, with nullptr terminating the array.
template <typename T>
std::string GenNameToIndexCode(const T& entries, bool generate) {
  // No need for a static array if we're not supposed to generate the data.
  if (!generate) {
    return "{\n return nullptr;\n }";
  }
  // Trim trailing entries with empty names: `end` becomes one past the last
  // entry whose name is non-empty.
  int end = entries.size();
  while (end > 0 && entries[end - 1].name().empty()) {
    --end;
  }
  // Emit string literals up to the last non-empty name.
  std::string code = "{\n static const char* kNames[] = {";
  for (int i = 0; i < end; ++i) {
    if (i > 0) {
      code += ", ";
    }
    code += "\"";
    code += entries[i].name();
    code += "\"";
  }
  if (end > 0) {
    code += ", ";
  }
  code += "nullptr};\n return kNames;\n }";
  return code;
}
324
ValidateFeedFetchCppNames(const tf2xla::Config & config)325 Status ValidateFeedFetchCppNames(const tf2xla::Config& config) {
326 for (const tf2xla::Feed& feed : config.feed()) {
327 if (!feed.name().empty()) {
328 TF_RETURN_IF_ERROR(ValidateCppIdent(feed.name(), "feed name"));
329 }
330 }
331 for (const tf2xla::Fetch& fetch : config.fetch()) {
332 if (!fetch.name().empty()) {
333 TF_RETURN_IF_ERROR(ValidateCppIdent(fetch.name(), "fetch name"));
334 }
335 }
336 for (const tf2xla::Variable& variable : config.variable()) {
337 if (!variable.name().empty()) {
338 TF_RETURN_IF_ERROR(ValidateCppIdent(variable.name(), "variable name"));
339 } else {
340 TF_RETURN_IF_ERROR(
341 ValidateCppIdent(variable.node_name(), "variable name"));
342 }
343 }
344 return Status::OK();
345 }
346
347 // Returns a list of C++ expressions that, when executed, will construct the
348 // BufferInfo instances in `buffer_infos`.
BufferInfosToCppExpression(const std::vector<BufferInfo> & buffer_infos)349 std::vector<string> BufferInfosToCppExpression(
350 const std::vector<BufferInfo>& buffer_infos) {
351 std::vector<string> buffer_infos_as_strings;
352 std::transform(buffer_infos.begin(), buffer_infos.end(),
353 std::back_inserter(buffer_infos_as_strings),
354 [](const BufferInfo& buffer_info) {
355 std::pair<uint64, uint64> encoded = buffer_info.Encode();
356 string encoded_second_as_str =
357 encoded.second == ~0ULL
358 ? "~0ULL"
359 : absl::StrCat(encoded.second, "ULL");
360 return absl::StrCat(
361 "::xla::cpu_function_runtime::BufferInfo({",
362 encoded.first, "ULL, ", encoded_second_as_str, "})");
363 });
364 return buffer_infos_as_strings;
365 }
366 } // namespace
367
// Generates into `*header` the complete C++ header that exposes the
// AOT-compiled computation described by `compile_result` as a statically
// typed subclass of XlaCompiledCpuFunction. `metadata_result` supplies the
// embedded-protobuf shim expressions and variable declarations produced by
// GenerateMetadata.
Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
                      const CompileResult& compile_result,
                      const MetadataResult& metadata_result, string* header) {
  TF_RETURN_IF_ERROR(ValidateConfig(config));
  TF_RETURN_IF_ERROR(ValidateFeedFetchCppNames(config));
  const int64 result_index = compile_result.aot->result_buffer_index();
  const std::vector<BufferInfo>& buffer_infos =
      compile_result.aot->buffer_infos();
  const std::vector<int32> arg_index_table =
      ::xla::cpu::CreateArgIndexTableFromBufferInfos(buffer_infos);
  std::vector<string> buffer_infos_as_strings =
      BufferInfosToCppExpression(buffer_infos);
  // The result tuple lives in one of the buffers; reject an out-of-range
  // index before baking it into the generated code.
  if (result_index < 0 || result_index >= buffer_infos.size()) {
    return errors::InvalidArgument("result index: ", result_index,
                                   " is outside the range of temp sizes: [0,",
                                   buffer_infos.size(), ")");
  }

  // Compute sizes and generate methods.
  std::vector<BufferInfo> buffer_infos_for_args =
      ExtractEntryParamBufferInfos(buffer_infos);
  std::vector<BufferInfo> buffer_infos_for_temps =
      ExtractTempBufferInfos(buffer_infos);
  const xla::ProgramShapeProto& ps = compile_result.program_shape;
  string methods_arg, methods_result, methods_variable;
  TF_RETURN_IF_ERROR(GenArgMethods(config, ps, compile_result, &methods_arg));
  TF_RETURN_IF_ERROR(GenResultMethods(config, ps, &methods_result));
  TF_RETURN_IF_ERROR(GenVariableMethods(config, ps, &methods_variable));
  const size_t arg_bytes_aligned =
      xla::cpu_function_runtime::AlignedBufferBytes(
          buffer_infos_for_args.data(), buffer_infos_for_args.size(),
          /*allocate_entry_params=*/true);
  const size_t arg_bytes_total = TotalBufferBytes(buffer_infos_for_args);
  const size_t temp_bytes_aligned =
      xla::cpu_function_runtime::AlignedBufferBytes(
          buffer_infos_for_temps.data(), buffer_infos_for_temps.size(),
          /*allocate_entry_params=*/true);
  const size_t temp_bytes_total = TotalBufferBytes(buffer_infos_for_temps);

  // Create rewrite strings for namespace start and end.
  string ns_start;
  for (const string& n : opts.namespaces) {
    ns_start += absl::StrCat("namespace ", n, " {\n");
  }
  ns_start += "\n";
  string ns_end("\n");
  for (int i = opts.namespaces.size() - 1; i >= 0; --i) {
    const string& n = opts.namespaces[i];
    ns_end += absl::StrCat("} // end namespace ", n, "\n");
  }

  // Generate metadata.
  const string arg_names_code =
      GenNameToIndexCode(config.feed(), opts.gen_name_to_index);
  const string result_names_code =
      GenNameToIndexCode(config.fetch(), opts.gen_name_to_index);
  const string include_xla_data_proto =
      opts.gen_program_shape
          ? R"(#include "tensorflow/compiler/xla/xla_data.pb.h")"
          : "";

  const string include_hlo_profile_printer_data_proto =
      opts.gen_hlo_profile_printer_data
          ? R"(#include "tensorflow/compiler/xla/service/hlo_profile_printer_data.pb.h")"
          : "";

  // When HLO profiling is disabled we only forward declare the
  // HloProfilePrinter protobuf. So we can only conditionally emit this code
  // calling HloProfilePrinter::profile_counters_size.
  const string assign_profile_counters_size =
      opts.gen_hlo_profile_printer_data
          ? "set_static_data_profile_counters_size(data, "
            "get_static_data_hlo_profile_printer_data(data)->"
            "profile_counters_size());"
          : "";

  // Use a poor-man's text templating mechanism; first populate the full header
  // with placeholder tokens, and then rewrite the tokens with real values.
  *header =
      R"(// Generated by tfcompile, the TensorFlow graph compiler. DO NOT EDIT!
//
// This header was generated via ahead-of-time compilation of a TensorFlow
// graph. An object file corresponding to this header was also generated.
// This header gives access to the functionality in that object file.
//
// clang-format off

#ifndef TFCOMPILE_GENERATED_{{ENTRY}}_H_  // NOLINT(build/header_guard)
#define TFCOMPILE_GENERATED_{{ENTRY}}_H_  // NOLINT(build/header_guard)

{{INCLUDE_XLA_DATA_PROTO}}
{{INCLUDE_HLO_PROFILE_PRINTER_DATA_PROTO}}
#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
#include "tensorflow/core/platform/types.h"

namespace Eigen { struct ThreadPoolDevice; }
namespace xla { class ExecutableRunOptions; }

// (Implementation detail) Entry point to the function in the object file.
extern "C" void {{ENTRY}}(
    void* result, const ::xla::ExecutableRunOptions* run_options,
    const void** args, void** temps, tensorflow::int64* profile_counters);

{{DECLS_FROM_OBJ_FILE}}

{{NS_START}}
// {{CLASS}} represents a computation previously specified in a
// TensorFlow graph, now compiled into executable code. This extends the generic
// XlaCompiledCpuFunction class with statically type-safe arg and result
// methods. Usage example:
//
//   {{CLASS}} computation;
//   // ...set args using computation.argN methods
//   CHECK(computation.Run());
//   // ...inspect results using computation.resultN methods
//
// The Run method invokes the actual computation, with inputs read from arg
// buffers, and outputs written to result buffers. Each Run call may also use
// a set of temporary buffers for the computation.
//
// By default each instance of this class manages its own arg, result and temp
// buffers. The AllocMode constructor parameter may be used to modify the
// buffer allocation strategy.
//
// Under the default allocation strategy, this class is thread-compatible:
// o Calls to non-const methods require exclusive access to the object.
// o Concurrent calls to const methods are OK, if those calls are made while it
//   is guaranteed that no thread may call a non-const method.
//
// The logical function signature is:
//   {{PROGRAM_SHAPE}}
//
// Memory stats:
//   arg bytes total:    {{ARG_BYTES_TOTAL}}
//   arg bytes aligned:  {{ARG_BYTES_ALIGNED}}
//   temp bytes total:   {{TEMP_BYTES_TOTAL}}
//   temp bytes aligned: {{TEMP_BYTES_ALIGNED}}
class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
 public:
  // Number of input arguments for the compiled computation.
  static constexpr size_t kNumArgs = {{ARG_NUM}};

  // Byte size of each argument buffer. There are kNumArgs entries.
  static const ::tensorflow::int64 ArgSize(::tensorflow::int32 index) {
    return BufferInfos()[ArgIndexToBufferIndex()[index]].size();
  }

  // Returns static data used to create an XlaCompiledCpuFunction.
  static const tensorflow::XlaCompiledCpuFunction::StaticData& StaticData() {
    static XlaCompiledCpuFunction::StaticData* kStaticData = [](){
      XlaCompiledCpuFunction::StaticData* data =
          new XlaCompiledCpuFunction::StaticData;
      set_static_data_raw_function(data, {{ENTRY}});
      set_static_data_buffer_infos(data, BufferInfos());
      set_static_data_num_buffers(data, kNumBuffers);
      set_static_data_arg_index_table(data, ArgIndexToBufferIndex());
      set_static_data_num_args(data, kNumArgs);
      set_static_data_result_index(data, kResultIndex);
      set_static_data_arg_names(data, StaticArgNames());
      set_static_data_result_names(data, StaticResultNames());
      set_static_data_program_shape(data, StaticProgramShape());
      set_static_data_hlo_profile_printer_data(
          data, StaticHloProfilePrinterData());
      {{ASSIGN_PROFILE_COUNTERS_SIZE}}
      return data;
    }();
    return *kStaticData;
  }

  {{CLASS}}(AllocMode alloc_mode =
            AllocMode::ARGS_VARIABLES_RESULTS_PROFILES_AND_TEMPS)
      : XlaCompiledCpuFunction(StaticData(), alloc_mode) {}

  {{CLASS}}(const {{CLASS}}&) = delete;
  {{CLASS}}& operator=(const {{CLASS}}&) = delete;

  // Arg methods for managing input buffers. Buffers are in row-major order.
  // There is a set of methods for each positional argument, with the following
  // general form:
  //
  // void set_argN_data(void* data)
  //   Sets the buffer of type T for positional argument N. May be called in
  //   any AllocMode. Must be called before Run to have an affect. Must be
  //   called in AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY for each positional
  //   argument, to set the argument buffers.
  //
  // T* argN_data()
  //   Returns the buffer of type T for positional argument N.
  //
  // T& argN(...dim indices...)
  //   Returns a reference to the value of type T for positional argument N,
  //   with dim indices specifying which value. No bounds checking is performed
  //   on dim indices.
  {{METHODS_ARG}}

  // Result methods for managing output buffers. Buffers are in row-major order.
  // Must only be called after a successful Run call. There is a set of methods
  // for each positional result, with the following general form:
  //
  // T* resultN_data()
  //   Returns the buffer of type T for positional result N.
  //
  // T& resultN(...dim indices...)
  //   Returns a reference to the value of type T for positional result N,
  //   with dim indices specifying which value. No bounds checking is performed
  //   on dim indices.
  //
  // Unlike the arg methods, there is no set_resultN_data method. The result
  // buffers are managed internally, and may change after each call to Run.
  {{METHODS_RESULT}}

  // Methods for managing variable buffers. Buffers are in row-major order.
  //
  // For read-write variables we generate the following methods:
  //
  // void set_var_X_data(T* data)
  //   Sets the buffer for variable X. Must be called before Run if the
  //   allocation mode is RESULTS_PROFILES_AND_TEMPS_ONLY.
  //
  // T* var_X_data()
  //   Returns the buffer of type T for variable X. If the allocation mode is
  //   RESULTS_PROFILES_AND_TEMPS_ONLY then this buffer is the same as the
  //   buffer passed to set_var_X_data.
  //
  // T& var_X(...dim indices...)
  //   Returns a reference to the value of type T for variable X,
  //   with dim indices specifying which value. No bounds checking is performed
  //   on dim indices.
  //
  // For readonly variables we generate the same set of methods, except that we
  // use `const T` instead of `T`. We use `const T` to avoid erasing the
  // constness of the buffer passed to `set_var_X_data` but the underlying
  // buffer is not const (and thus the const can be safely const-cast'ed away)
  // unless `set_var_X_data` is called with a pointer to constant storage.
  {{METHODS_VARIABLE}}

 private:
  // Number of buffers for the compiled computation.
  static constexpr size_t kNumBuffers = {{NUM_BUFFERS}};

  static const ::xla::cpu_function_runtime::BufferInfo* BufferInfos() {
    static const ::xla::cpu_function_runtime::BufferInfo
        kBufferInfos[kNumBuffers] = {
{{BUFFER_INFOS_AS_STRING}}
        };
    return kBufferInfos;
  }

  static const ::tensorflow::int32* ArgIndexToBufferIndex() {
    static constexpr ::tensorflow::int32 kArgIndexToBufferIndex[kNumArgs] = {
{{ARG_INDEX_TABLE}}
    };
    return kArgIndexToBufferIndex;
  }

  // The 0-based index of the result tuple in the temporary buffers.
  static constexpr size_t kResultIndex = {{RESULT_INDEX}};

  // Array of names of each positional argument, terminated by nullptr.
  static const char** StaticArgNames() {{ARG_NAMES_CODE}}

  // Array of names of each positional result, terminated by nullptr.
  static const char** StaticResultNames() {{RESULT_NAMES_CODE}}

  // Shape of the args and results.
  static const ::xla::ProgramShapeProto* StaticProgramShape() {
    static const ::xla::ProgramShapeProto* kShape = {{PROGRAM_SHAPE_SHIM_EXPRESSION}};
    return kShape;
  }

  // Metadata that can be used to pretty-print profile counters.
  static const ::xla::HloProfilePrinterData* StaticHloProfilePrinterData() {
    static const ::xla::HloProfilePrinterData* kHloProfilePrinterData =
        {{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}};
    return kHloProfilePrinterData;
  }
};
{{NS_END}}

#endif  // TFCOMPILE_GENERATED_{{ENTRY}}_H_

// clang-format on
)";
  // The replacement strategy is naive, but good enough for our purposes.
  // NOTE(review): the {{METHODS_*}} and {{NS_*}} keys include the trailing
  // "\n" — presumably so an empty replacement also consumes the placeholder's
  // line break; confirm before changing any of those keys.
  const std::vector<std::pair<string, string>> rewrites = {
      {"{{ARG_BYTES_ALIGNED}}", absl::StrCat(arg_bytes_aligned)},
      {"{{ARG_BYTES_TOTAL}}", absl::StrCat(arg_bytes_total)},
      {"{{ARG_NAMES_CODE}}", arg_names_code},
      {"{{ARG_NUM}}", absl::StrCat(arg_index_table.size())},
      {"{{ARG_INDEX_TABLE}}", absl::StrJoin(arg_index_table, ", ")},
      {"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size},
      {"{{CLASS}}", opts.class_name},
      {"{{DECLS_FROM_OBJ_FILE}}",
       absl::StrJoin(metadata_result.header_variable_decls, "\n")},
      {"{{ENTRY}}", compile_result.entry_point},
      {"{{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}}",
       metadata_result.hlo_profile_printer_data_access_shim},
      {"{{INCLUDE_XLA_DATA_PROTO}}", include_xla_data_proto},
      {"{{INCLUDE_HLO_PROFILE_PRINTER_DATA_PROTO}}",
       include_hlo_profile_printer_data_proto},
      {"{{METHODS_ARG}}\n", methods_arg},
      {"{{METHODS_RESULT}}\n", methods_result},
      {"{{METHODS_VARIABLE}}\n", methods_variable},
      {"{{NS_END}}\n", ns_end},
      {"{{NS_START}}\n", ns_start},
      {"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(xla::ProgramShape(ps))},
      {"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}",
       metadata_result.program_shape_access_shim},
      {"{{RESULT_INDEX}}", absl::StrCat(result_index)},
      {"{{RESULT_NAMES_CODE}}", result_names_code},
      {"{{TEMP_BYTES_ALIGNED}}", absl::StrCat(temp_bytes_aligned)},
      {"{{TEMP_BYTES_TOTAL}}", absl::StrCat(temp_bytes_total)},
      {"{{NUM_BUFFERS}}", absl::StrCat(buffer_infos.size())},
      {"{{BUFFER_INFOS_AS_STRING}}",
       absl::StrJoin(buffer_infos_as_strings, ",\n")}};
  absl::StrReplaceAll(rewrites, header);
  return Status::OK();
}
686
CreateUniqueIdentifier(const CodegenOpts & opts,absl::string_view suffix)687 static string CreateUniqueIdentifier(const CodegenOpts& opts,
688 absl::string_view suffix) {
689 string result = "__tfcompile";
690 for (const string& n : opts.namespaces) {
691 absl::StrAppend(&result, "_", n);
692 }
693
694 absl::StrAppend(&result, "_", opts.class_name, "_", suffix);
695 return result;
696 }
697
// Embeds the program shape and HLO profile printer data protobufs into object
// code for the target triple, and records the resulting access shims, header
// declarations and object-file bytes in `metadata_result`.
Status GenerateMetadata(const CodegenOpts& opts,
                        const CompileResult& compile_result,
                        MetadataResult* metadata_result) {
  // Left null when gen_program_shape is off, which embeds a nullptr shim.
  std::unique_ptr<xla::ProgramShapeProto> program_shape;

  if (opts.gen_program_shape) {
    program_shape =
        absl::make_unique<xla::ProgramShapeProto>(compile_result.program_shape);

    // The parameter names are currently meaningless, and redundant with the
    // rest of our metadata, so clear them out to avoid confusion and save
    // space.
    program_shape->clear_parameter_names();
  }

  // When asked to serialize a null protobuf, CreateEmbeddedProtocolBuffer gives
  // a shim that evaluates to nullptr, which is what we want.

  ProtobufToEmbed program_shape_protobuf{
      CreateUniqueIdentifier(opts, "ProgramShapeProto"),
      "::xla::ProgramShapeProto", program_shape.get()};

  ProtobufToEmbed hlo_profile_printer_data_protobuf{
      CreateUniqueIdentifier(opts, "HloProfilePrinterData"),
      "::xla::HloProfilePrinterData",
      compile_result.aot->hlo_profile_printer_data()};

  TF_ASSIGN_OR_RETURN(
      EmbeddedProtocolBuffers embedded_protobufs,
      CreateEmbeddedProtocolBuffers(
          opts.target_triple,
          {program_shape_protobuf, hlo_profile_printer_data_protobuf}));

  // cpp_shims[0] is the program shape and cpp_shims[1] the profile printer
  // data, matching the order the protobufs were passed in above.
  metadata_result->program_shape_access_shim =
      std::move(embedded_protobufs.cpp_shims[0].expression);
  metadata_result->hlo_profile_printer_data_access_shim =
      std::move(embedded_protobufs.cpp_shims[1].expression);
  metadata_result->header_variable_decls.emplace_back(
      std::move(embedded_protobufs.cpp_shims[0].variable_decl));
  metadata_result->header_variable_decls.emplace_back(
      std::move(embedded_protobufs.cpp_shims[1].variable_decl));
  metadata_result->object_file_data =
      std::move(embedded_protobufs.object_file_data);
  return Status::OK();
}
743
ParseCppClass(const string & cpp_class,string * class_name,std::vector<string> * namespaces)744 Status ParseCppClass(const string& cpp_class, string* class_name,
745 std::vector<string>* namespaces) {
746 class_name->clear();
747 namespaces->clear();
748 if (cpp_class.empty()) {
749 return errors::InvalidArgument("empty cpp_class: " + cpp_class);
750 }
751 std::vector<string> parts = absl::StrSplit(cpp_class, "::");
752 if (parts.front().empty()) {
753 // Allow a fully qualified name that starts with "::".
754 parts.erase(parts.begin());
755 }
756 for (int i = 0; i < parts.size(); ++i) {
757 if (i < parts.size() - 1) {
758 TF_RETURN_IF_ERROR(ValidateCppIdent(
759 parts[i], "in namespace component of cpp_class: " + cpp_class));
760 namespaces->push_back(parts[i]);
761 } else {
762 TF_RETURN_IF_ERROR(ValidateCppIdent(
763 parts[i], "in class name of cpp_class: " + cpp_class));
764 *class_name = parts[i];
765 }
766 }
767 return Status::OK();
768 }
769
ValidateCppIdent(absl::string_view ident,absl::string_view msg)770 Status ValidateCppIdent(absl::string_view ident, absl::string_view msg) {
771 if (ident.empty()) {
772 return errors::InvalidArgument("empty identifier: ", msg);
773 }
774 // Require that the identifier starts with a nondigit, and is composed of
775 // nondigits and digits, as specified in section [2.11 Identifiers] of the
776 // C++11 Standard. Note that nondigit is defined as [_a-zA-Z] and digit is
777 // defined as [0-9].
778 //
779 // Technically the standard also allows for `universal-character-name`, with a
780 // table of allowed unicode ranges, as well as `other implementation-defined
781 // characters`. We disallow those here to give better error messages, at the
782 // expensive of being more restrictive than the standard.
783 if (ident[0] != '_' && !IsAlpha(ident[0])) {
784 return errors::InvalidArgument("illegal leading char: ", msg);
785 }
786 for (size_t pos = 1; pos < ident.size(); ++pos) {
787 if (ident[pos] != '_' && !IsAlphaNum(ident[pos])) {
788 return errors::InvalidArgument("illegal char: ", msg);
789 }
790 }
791 return Status::OK();
792 }
793
794 } // namespace tfcompile
795 } // namespace tensorflow
796