/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/aot/codegen.h"

#include <algorithm>  // for std::copy_if, std::transform
#include <iterator>   // for std::back_inserter
#include <numeric>    // for std::accumulate
#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/str_split.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/aot/embedded_protocol_buffers.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/cpu_function_runtime.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
namespace tfcompile {

namespace {

using BufferInfo = xla::cpu_function_runtime::BufferInfo;

bool IsAlpha(char c) {
  return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z');
}

bool IsAlphaNum(char c) { return IsAlpha(c) || (c >= '0' && c <= '9'); }

// Convert an XLA type into a C++ type.
Status XLATypeToCpp(xla::PrimitiveType type, string* str) {
  switch (type) {
    case xla::PRED:
      *str = "bool";
      break;
    case xla::S8:
      *str = "tensorflow::int8";
      break;
    case xla::S16:
      *str = "tensorflow::int16";
      break;
    case xla::S32:
      *str = "tensorflow::int32";
      break;
    case xla::S64:
      *str = "tensorflow::int64";
      break;
    case xla::U8:
      *str = "tensorflow::uint8";
      break;
    case xla::U16:
      *str = "tensorflow::uint16";
      break;
    case xla::U32:
      *str = "tensorflow::uint32";
      break;
    case xla::U64:
      *str = "tensorflow::uint64";
      break;
    case xla::F32:
      *str = "float";
      break;
    case xla::F64:
      *str = "double";
      break;
    default:
      return errors::Unimplemented("XLA type ", xla::PrimitiveType_Name(type),
                                   " has no equivalent in C++");
  }
  return Status::OK();
}

// Returns the sum of the size of each buffer in `buffer_infos`.
size_t TotalBufferBytes(const std::vector<BufferInfo>& buffer_infos) {
  return std::accumulate(buffer_infos.begin(), buffer_infos.end(), size_t{0},
                         [](size_t size, const BufferInfo& buffer_info) {
                           return size + buffer_info.size();
                         });
}

// Returns a vector of BufferInfo instances in `buffer_infos` that are entry
// parameter buffers.
std::vector<BufferInfo> ExtractEntryParamBufferInfos(
    const std::vector<BufferInfo>& buffer_infos) {
  std::vector<BufferInfo> result;
  std::copy_if(buffer_infos.begin(), buffer_infos.end(),
               std::back_inserter(result), [](const BufferInfo& buffer_info) {
                 return buffer_info.is_entry_parameter();
               });
  return result;
}

// Returns a vector of BufferInfo instances in `buffer_infos` that are temp
// buffers.
std::vector<BufferInfo> ExtractTempBufferInfos(
    const std::vector<BufferInfo>& buffer_infos) {
  std::vector<BufferInfo> result;
  std::copy_if(buffer_infos.begin(), buffer_infos.end(),
               std::back_inserter(result), [](const BufferInfo& buffer_info) {
                 return buffer_info.is_temp_buffer();
               });
  return result;
}

// Add (from,to) rewrite pairs based on the given shape.  These rewrite pairs
// are used to generate methods for args and results.
Status AddRewritesForShape(int i, const xla::Shape& shape,
                           std::vector<std::pair<string, string>>* rewrites) {
  string type;
  TF_RETURN_IF_ERROR(XLATypeToCpp(shape.element_type(), &type));
  std::vector<string> dim_vars;
  string dim_sizes, indices;
  if (shape.rank() == 0 ||
      (shape.dimensions_size() == 1 && shape.dimensions(0) == 1)) {
    dim_sizes = "[1]";
    indices = "[0]";
  } else {
    for (int dim = 0; dim < shape.dimensions_size(); ++dim) {
      dim_vars.push_back(absl::StrCat("size_t dim", dim));
      dim_sizes += absl::StrCat("[", shape.dimensions(dim), "]");
      indices += absl::StrCat("[dim", dim, "]");
    }
  }
  rewrites->push_back({"{{I}}", absl::StrCat(i)});
  rewrites->push_back({"{{TYPE}}", type});
  rewrites->push_back({"{{DIM_VARS}}", absl::StrJoin(dim_vars, ", ")});
  rewrites->push_back({"{{DIM_SIZES}}", dim_sizes});
  rewrites->push_back({"{{INDICES}}", indices});
  return Status::OK();
}
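
// Illustrative example for AddRewritesForShape (added comment, not part of the
// upstream file): for argument index i = 1 with shape f32[2,3], the rewrites
// above would be
//   {{I}}         -> "1"
//   {{TYPE}}      -> "float"
//   {{DIM_VARS}}  -> "size_t dim0, size_t dim1"
//   {{DIM_SIZES}} -> "[2][3]"
//   {{INDICES}}   -> "[dim0][dim1]"
// while a scalar (or single-element) shape collapses to dim_sizes "[1]" and
// indices "[0]".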

// Returns code rewritten by replacing all rewrite pairs, with an extra rewrite
// for the name.  Note that the rewriting strategy is roughly O(N*M), where N is
// the size of the code and M is the number of rewrites.  It's fine for now
// since N and M are pretty small.
//
// TODO(toddw): If this becomes a problem, we should be able to change the
// algorithm to O(N) by using a state machine, e.g. regexps or a real
// text-templating mechanism.
string RewriteWithName(const string& name, string code,
                       const std::vector<std::pair<string, string>>& rewrites) {
  absl::StrReplaceAll(rewrites, &code);
  absl::StrReplaceAll({{"{{NAME}}", name}}, &code);
  return code;
}

// Generate methods for args (inputs).
Status GenArgMethods(const tf2xla::Config& config,
                     const xla::ProgramShapeProto& ps,
                     const CompileResult& compile_result, string* methods) {
  size_t num_args = ps.parameters_size();
  if (config.feed_size() + config.variable_size() != num_args) {
    return errors::InvalidArgument(
        "mismatch between feed_size(", config.feed_size(), ")+variable_size(",
        config.variable_size(), ") and num_args(", num_args, ")");
  }
  for (int i = 0; i < config.feed_size(); ++i) {
    std::vector<std::pair<string, string>> rewrites;
    TF_RETURN_IF_ERROR(
        AddRewritesForShape(i, xla::Shape(ps.parameters(i)), &rewrites));
    const string code = R"(
  void set_arg{{NAME}}_data(const void* data) {
    set_arg_data({{I}}, data);
  }
  {{TYPE}}* arg{{NAME}}_data() {
    return static_cast<{{TYPE}}*>(arg_data({{I}}));
  }
  {{TYPE}}& arg{{NAME}}({{DIM_VARS}}) {
    return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>(
        arg_data({{I}}))){{INDICES}};
  }
  const {{TYPE}}* arg{{NAME}}_data() const {
    return static_cast<const {{TYPE}}*>(arg_data({{I}}));
  }
  const {{TYPE}}& arg{{NAME}}({{DIM_VARS}}) const {
    return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
        arg_data({{I}}))){{INDICES}};
  }
)";
    *methods += RewriteWithName(absl::StrCat(i), code, rewrites);
    if (!config.feed(i).name().empty()) {
      *methods += RewriteWithName("_" + config.feed(i).name(), code, rewrites);
    }
  }
  return Status::OK();
}
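
// Illustrative example for GenArgMethods (added comment, not part of the
// upstream file): for feed 0 named "x" with shape f32[2,3], the loop above
// emits both positional "arg0" and named "arg_x" accessors; the named variant
// expands roughly to:
//
//   void set_arg_x_data(const void* data) {
//     set_arg_data(0, data);
//   }
//   float* arg_x_data() {
//     return static_cast<float*>(arg_data(0));
//   }
//   float& arg_x(size_t dim0, size_t dim1) {
//     return (*static_cast<float(*)[2][3]>(
//         arg_data(0)))[dim0][dim1];
//   }
//
// plus the equivalent const overloads.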

// Generate methods for results (outputs).
Status GenResultMethods(const tf2xla::Config& config,
                        const xla::ProgramShapeProto& ps, string* methods) {
  if (ps.result().element_type() != xla::TUPLE) {
    // The XlaCompiler we use to build the xla computation always generates a
    // tuple result, and we rely on this to simplify code generation.
    return errors::Internal("codegen requires the XLA result to be a tuple");
  }
  size_t num_results = ps.result().tuple_shapes_size();
  int readonly_variables = absl::c_count_if(
      config.variable(),
      [](const tf2xla::Variable& var) { return var.readonly(); });
  if (config.fetch_size() + config.variable_size() - readonly_variables !=
      num_results) {
    return errors::InvalidArgument("mismatch between fetch_size(",
                                   config.fetch_size(), ")+variable_size(",
                                   config.variable_size(), ") and tuple_size(",
                                   ps.result().tuple_shapes_size(), ")");
  }
  for (int i = 0; i < config.fetch_size(); ++i) {
    std::vector<std::pair<string, string>> rewrites;
    TF_RETURN_IF_ERROR(AddRewritesForShape(
        i, xla::Shape(ps.result().tuple_shapes(i)), &rewrites));
    string code = R"(
  {{TYPE}}* result{{NAME}}_data() {
    return static_cast<{{TYPE}}*>(result_data({{I}}));
  }
  {{TYPE}}& result{{NAME}}({{DIM_VARS}}) {
    return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>(
        result_data({{I}}))){{INDICES}};
  }
  const {{TYPE}}* result{{NAME}}_data() const {
    return static_cast<const {{TYPE}}*>(result_data({{I}}));
  }
  const {{TYPE}}& result{{NAME}}({{DIM_VARS}}) const {
    return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
        result_data({{I}}))){{INDICES}};
  }
)";
    *methods += RewriteWithName(absl::StrCat(i), code, rewrites);
    if (!config.fetch(i).name().empty()) {
      *methods += RewriteWithName("_" + config.fetch(i).name(), code, rewrites);
    }
  }
  return Status::OK();
}
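
// Illustrative example for GenResultMethods (added comment, not part of the
// upstream file): a fetch named "y" at tuple index 0 with shape s32[4] would
// produce "tensorflow::int32* result_y_data()" and
// "tensorflow::int32& result_y(size_t dim0)", alongside the positional
// "result0" variants and the const overloads.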

// Generate methods for variables.
Status GenVariableMethods(const tf2xla::Config& config,
                          const xla::ProgramShapeProto& ps, string* methods) {
  size_t num_args = ps.parameters_size();
  for (int i = config.feed_size(); i < num_args; ++i) {
    std::vector<std::pair<string, string>> rewrites;
    TF_RETURN_IF_ERROR(
        AddRewritesForShape(i, xla::Shape(ps.parameters(i)), &rewrites));
    const string code = R"(
  void set_var_{{NAME}}_data({{MAYBE_CONST}}{{TYPE}}* data) {
    set_arg_data({{I}}, data);
  }
  {{MAYBE_CONST}}{{TYPE}}* var_{{NAME}}_data() {
    return static_cast<{{MAYBE_CONST}}{{TYPE}}*>(arg_data({{I}}));
  }
  {{MAYBE_CONST}}{{TYPE}}& var_{{NAME}}({{DIM_VARS}}) {
    return (*static_cast<{{MAYBE_CONST}}{{TYPE}}(*){{DIM_SIZES}}>(
        arg_data({{I}}))){{INDICES}};
  }
  const {{TYPE}}* var_{{NAME}}_data() const {
    return static_cast<const {{TYPE}}*>(arg_data({{I}}));
  }
  const {{TYPE}}& var_{{NAME}}({{DIM_VARS}}) const {
    return (*static_cast<const {{TYPE}}(*){{DIM_SIZES}}>(
        arg_data({{I}}))){{INDICES}};
  }
)";
    const tf2xla::Variable& var = config.variable(i - config.feed_size());
    rewrites.emplace_back("{{MAYBE_CONST}}", var.readonly() ? "const " : "");
    *methods += RewriteWithName(
        var.name().empty() ? var.node_name() : var.name(), code, rewrites);
  }
  return Status::OK();
}
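
// Illustrative example for GenVariableMethods (added comment, not part of the
// upstream file): a readonly f32 scalar variable named "w" gets
// "{{MAYBE_CONST}}" rewritten to "const ", so its accessors become
// "void set_var_w_data(const float* data)", "const float* var_w_data()" and
// "const float& var_w()"; a read-write variable gets the non-const forms
// instead.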

// Generates code implementing {Arg,Result}Names(), where T is one of
// tf2xla::{Feed,Fetch}. Each feed or fetch name results in a C-style string
// literal in the array, with nullptr terminating the array.
template <typename T>
string GenNameToIndexCode(const T& entries, bool generate) {
  // No need for a static array if we're not supposed to generate the data.
  if (!generate) {
    return "{\n    return nullptr;\n  }";
  }
  // Determine when to stop. We stop emitting string literals after the last
  // non-empty name.
  int end = entries.size();
  for (int i = entries.size() - 1; i >= 0; --i) {
    if (!entries[i].name().empty()) {
      break;
    }
    end = i;
  }
  // Emit string literals up to the last non-empty name.
  string code = "{\n    static const char* kNames[] = {";
  for (int i = 0; i < end; ++i) {
    if (i > 0) {
      code += ", ";
    }
    code += "\"";
    code += entries[i].name();
    code += "\"";
  }
  if (end > 0) {
    code += ", ";
  }
  code += "nullptr};\n    return kNames;\n  }";
  return code;
}
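
// Illustrative example for GenNameToIndexCode (added comment, not part of the
// upstream file): for feeds named "x" and "y" this returns, roughly,
//   {
//     static const char* kNames[] = {"x", "y", nullptr};
//     return kNames;
//   }
// while generate == false yields a body that simply returns nullptr.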

Status ValidateFeedFetchCppNames(const tf2xla::Config& config) {
  for (const tf2xla::Feed& feed : config.feed()) {
    if (!feed.name().empty()) {
      TF_RETURN_IF_ERROR(ValidateCppIdent(feed.name(), "feed name"));
    }
  }
  for (const tf2xla::Fetch& fetch : config.fetch()) {
    if (!fetch.name().empty()) {
      TF_RETURN_IF_ERROR(ValidateCppIdent(fetch.name(), "fetch name"));
    }
  }
  for (const tf2xla::Variable& variable : config.variable()) {
    if (!variable.name().empty()) {
      TF_RETURN_IF_ERROR(ValidateCppIdent(variable.name(), "variable name"));
    } else {
      TF_RETURN_IF_ERROR(
          ValidateCppIdent(variable.node_name(), "variable name"));
    }
  }
  return Status::OK();
}

// Returns a list of C++ expressions that, when executed, will construct the
// BufferInfo instances in `buffer_infos`.
std::vector<string> BufferInfosToCppExpression(
    const std::vector<BufferInfo>& buffer_infos) {
  std::vector<string> buffer_infos_as_strings;
  std::transform(buffer_infos.begin(), buffer_infos.end(),
                 std::back_inserter(buffer_infos_as_strings),
                 [](const BufferInfo& buffer_info) {
                   std::pair<uint64, uint64> encoded = buffer_info.Encode();
                   string encoded_second_as_str =
                       encoded.second == ~0ULL
                           ? "~0ULL"
                           : absl::StrCat(encoded.second, "ULL");
                   return absl::StrCat(
                       "::xla::cpu_function_runtime::BufferInfo({",
                       encoded.first, "ULL, ", encoded_second_as_str, "})");
                 });
  return buffer_infos_as_strings;
}
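
// Illustrative example for BufferInfosToCppExpression (added comment, not part
// of the upstream file): a BufferInfo whose Encode() returns the pair {34, 0}
// is emitted as the expression
//   ::xla::cpu_function_runtime::BufferInfo({34ULL, 0ULL})
// and a second value of ~0 is printed as "~0ULL" rather than its full 20-digit
// decimal form.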
}  // namespace

Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
                      const CompileResult& compile_result,
                      const MetadataResult& metadata_result, string* header) {
  TF_RETURN_IF_ERROR(ValidateConfig(config));
  TF_RETURN_IF_ERROR(ValidateFeedFetchCppNames(config));
  const int64 result_index = compile_result.aot->result_buffer_index();
  const std::vector<BufferInfo>& buffer_infos =
      compile_result.aot->buffer_infos();
  const std::vector<int32> arg_index_table =
      ::xla::cpu::CreateArgIndexTableFromBufferInfos(buffer_infos);
  std::vector<string> buffer_infos_as_strings =
      BufferInfosToCppExpression(buffer_infos);
  if (result_index < 0 || result_index >= buffer_infos.size()) {
    return errors::InvalidArgument("result index: ", result_index,
                                   " is outside the range of temp sizes: [0,",
                                   buffer_infos.size(), ")");
  }

  // Compute sizes and generate methods.
  std::vector<BufferInfo> buffer_infos_for_args =
      ExtractEntryParamBufferInfos(buffer_infos);
  std::vector<BufferInfo> buffer_infos_for_temps =
      ExtractTempBufferInfos(buffer_infos);
  const xla::ProgramShapeProto& ps = compile_result.program_shape;
  string methods_arg, methods_result, methods_variable;
  TF_RETURN_IF_ERROR(GenArgMethods(config, ps, compile_result, &methods_arg));
  TF_RETURN_IF_ERROR(GenResultMethods(config, ps, &methods_result));
  TF_RETURN_IF_ERROR(GenVariableMethods(config, ps, &methods_variable));
  const size_t arg_bytes_aligned =
      xla::cpu_function_runtime::AlignedBufferBytes(
          buffer_infos_for_args.data(), buffer_infos_for_args.size(),
          /*allocate_entry_params=*/true);
  const size_t arg_bytes_total = TotalBufferBytes(buffer_infos_for_args);
  const size_t temp_bytes_aligned =
      xla::cpu_function_runtime::AlignedBufferBytes(
          buffer_infos_for_temps.data(), buffer_infos_for_temps.size(),
          /*allocate_entry_params=*/true);
  const size_t temp_bytes_total = TotalBufferBytes(buffer_infos_for_temps);

  // Create rewrite strings for namespace start and end.
  string ns_start;
  for (const string& n : opts.namespaces) {
    ns_start += absl::StrCat("namespace ", n, " {\n");
  }
  ns_start += "\n";
  string ns_end("\n");
  for (int i = opts.namespaces.size() - 1; i >= 0; --i) {
    const string& n = opts.namespaces[i];
    ns_end += absl::StrCat("}  // end namespace ", n, "\n");
  }

  // Generate metadata.
  const string arg_names_code =
      GenNameToIndexCode(config.feed(), opts.gen_name_to_index);
  const string result_names_code =
      GenNameToIndexCode(config.fetch(), opts.gen_name_to_index);
  const string include_xla_data_proto =
      opts.gen_program_shape
          ? R"(#include "tensorflow/compiler/xla/xla_data.pb.h")"
          : "";

  const string include_hlo_profile_printer_data_proto =
      opts.gen_hlo_profile_printer_data
          ? R"(#include "tensorflow/compiler/xla/service/hlo_profile_printer_data.pb.h")"
          : "";

  // When HLO profiling is disabled we only forward declare the
  // HloProfilePrinter protobuf, so we can only conditionally emit this code,
  // which calls HloProfilePrinter::profile_counters_size.
  const string assign_profile_counters_size =
      opts.gen_hlo_profile_printer_data
          ? "set_static_data_profile_counters_size(data, "
            "get_static_data_hlo_profile_printer_data(data)->"
            "profile_counters_size());"
          : "";

  // Use a poor-man's text templating mechanism; first populate the full header
  // with placeholder tokens, and then rewrite the tokens with real values.
  *header =
      R"(// Generated by tfcompile, the TensorFlow graph compiler.  DO NOT EDIT!
//
// This header was generated via ahead-of-time compilation of a TensorFlow
// graph.  An object file corresponding to this header was also generated.
// This header gives access to the functionality in that object file.
//
// clang-format off

#ifndef TFCOMPILE_GENERATED_{{ENTRY}}_H_  // NOLINT(build/header_guard)
#define TFCOMPILE_GENERATED_{{ENTRY}}_H_  // NOLINT(build/header_guard)

{{INCLUDE_XLA_DATA_PROTO}}
{{INCLUDE_HLO_PROFILE_PRINTER_DATA_PROTO}}
#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
#include "tensorflow/core/platform/types.h"

namespace Eigen { struct ThreadPoolDevice; }
namespace xla { class ExecutableRunOptions; }

// (Implementation detail) Entry point to the function in the object file.
extern "C" void {{ENTRY}}(
    void* result, const ::xla::ExecutableRunOptions* run_options,
    const void** args, void** temps, tensorflow::int64* profile_counters);

{{DECLS_FROM_OBJ_FILE}}

{{NS_START}}
// {{CLASS}} represents a computation previously specified in a
// TensorFlow graph, now compiled into executable code. This extends the generic
// XlaCompiledCpuFunction class with statically type-safe arg and result
// methods. Usage example:
//
//   {{CLASS}} computation;
//   // ...set args using computation.argN methods
//   CHECK(computation.Run());
//   // ...inspect results using computation.resultN methods
//
// The Run method invokes the actual computation, with inputs read from arg
// buffers, and outputs written to result buffers. Each Run call may also use
// a set of temporary buffers for the computation.
//
// By default each instance of this class manages its own arg, result and temp
// buffers. The AllocMode constructor parameter may be used to modify the
// buffer allocation strategy.
//
// Under the default allocation strategy, this class is thread-compatible:
// o Calls to non-const methods require exclusive access to the object.
// o Concurrent calls to const methods are OK, if those calls are made while it
//   is guaranteed that no thread may call a non-const method.
//
// The logical function signature is:
//   {{PROGRAM_SHAPE}}
//
// Memory stats:
//   arg bytes total:    {{ARG_BYTES_TOTAL}}
//   arg bytes aligned:  {{ARG_BYTES_ALIGNED}}
//   temp bytes total:   {{TEMP_BYTES_TOTAL}}
//   temp bytes aligned: {{TEMP_BYTES_ALIGNED}}
class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction {
 public:
  // Number of input arguments for the compiled computation.
  static constexpr size_t kNumArgs = {{ARG_NUM}};

  // Byte size of each argument buffer. There are kNumArgs entries.
  static const ::tensorflow::int64 ArgSize(::tensorflow::int32 index) {
    return BufferInfos()[ArgIndexToBufferIndex()[index]].size();
  }

  // Returns static data used to create an XlaCompiledCpuFunction.
  static const tensorflow::XlaCompiledCpuFunction::StaticData& StaticData() {
    static XlaCompiledCpuFunction::StaticData* kStaticData = [](){
      XlaCompiledCpuFunction::StaticData* data =
        new XlaCompiledCpuFunction::StaticData;
      set_static_data_raw_function(data, {{ENTRY}});
      set_static_data_buffer_infos(data, BufferInfos());
      set_static_data_num_buffers(data, kNumBuffers);
      set_static_data_arg_index_table(data, ArgIndexToBufferIndex());
      set_static_data_num_args(data, kNumArgs);
      set_static_data_result_index(data, kResultIndex);
      set_static_data_arg_names(data, StaticArgNames());
      set_static_data_result_names(data, StaticResultNames());
      set_static_data_program_shape(data, StaticProgramShape());
      set_static_data_hlo_profile_printer_data(
          data, StaticHloProfilePrinterData());
{{ASSIGN_PROFILE_COUNTERS_SIZE}}
      return data;
    }();
    return *kStaticData;
  }

  {{CLASS}}(AllocMode alloc_mode =
            AllocMode::ARGS_VARIABLES_RESULTS_PROFILES_AND_TEMPS)
      : XlaCompiledCpuFunction(StaticData(), alloc_mode) {}

  {{CLASS}}(const {{CLASS}}&) = delete;
  {{CLASS}}& operator=(const {{CLASS}}&) = delete;

  // Arg methods for managing input buffers. Buffers are in row-major order.
  // There is a set of methods for each positional argument, with the following
  // general form:
  //
  // void set_argN_data(void* data)
  //   Sets the buffer of type T for positional argument N. May be called in
  //   any AllocMode. Must be called before Run to have an effect. Must be
  //   called in AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY for each positional
  //   argument, to set the argument buffers.
  //
  // T* argN_data()
  //   Returns the buffer of type T for positional argument N.
  //
  // T& argN(...dim indices...)
  //   Returns a reference to the value of type T for positional argument N,
  //   with dim indices specifying which value. No bounds checking is performed
  //   on dim indices.
{{METHODS_ARG}}

  // Result methods for managing output buffers. Buffers are in row-major order.
  // Must only be called after a successful Run call. There is a set of methods
  // for each positional result, with the following general form:
  //
  // T* resultN_data()
  //   Returns the buffer of type T for positional result N.
  //
  // T& resultN(...dim indices...)
  //   Returns a reference to the value of type T for positional result N,
  //   with dim indices specifying which value. No bounds checking is performed
  //   on dim indices.
  //
  // Unlike the arg methods, there is no set_resultN_data method. The result
  // buffers are managed internally, and may change after each call to Run.
{{METHODS_RESULT}}

  // Methods for managing variable buffers. Buffers are in row-major order.
  //
  // For read-write variables we generate the following methods:
  //
  // void set_var_X_data(T* data)
  //   Sets the buffer for variable X.  Must be called before Run if the
  //   allocation mode is RESULTS_PROFILES_AND_TEMPS_ONLY.
  //
  // T* var_X_data()
  //   Returns the buffer of type T for variable X.  If the allocation mode is
  //   RESULTS_PROFILES_AND_TEMPS_ONLY then this buffer is the same as the
  //   buffer passed to set_var_X_data.
  //
  // T& var_X(...dim indices...)
  //   Returns a reference to the value of type T for variable X,
  //   with dim indices specifying which value. No bounds checking is performed
  //   on dim indices.
  //
  // For readonly variables we generate the same set of methods, except that we
  // use `const T` instead of `T`.  We use `const T` to avoid erasing the
  // constness of the buffer passed to `set_var_X_data` but the underlying
  // buffer is not const (and thus the const can be safely const-cast'ed away)
  // unless `set_var_X_data` is called with a pointer to constant storage.
{{METHODS_VARIABLE}}

 private:
  // Number of buffers for the compiled computation.
  static constexpr size_t kNumBuffers = {{NUM_BUFFERS}};

  static const ::xla::cpu_function_runtime::BufferInfo* BufferInfos() {
    static const ::xla::cpu_function_runtime::BufferInfo
      kBufferInfos[kNumBuffers] = {
{{BUFFER_INFOS_AS_STRING}}
      };
    return kBufferInfos;
  }

  static const ::tensorflow::int32* ArgIndexToBufferIndex() {
    static constexpr ::tensorflow::int32 kArgIndexToBufferIndex[kNumArgs] = {
{{ARG_INDEX_TABLE}}
    };
    return kArgIndexToBufferIndex;
  }

  // The 0-based index of the result tuple in the temporary buffers.
  static constexpr size_t kResultIndex = {{RESULT_INDEX}};

  // Array of names of each positional argument, terminated by nullptr.
  static const char** StaticArgNames() {{ARG_NAMES_CODE}}

  // Array of names of each positional result, terminated by nullptr.
  static const char** StaticResultNames() {{RESULT_NAMES_CODE}}

  // Shape of the args and results.
  static const ::xla::ProgramShapeProto* StaticProgramShape() {
    static const ::xla::ProgramShapeProto* kShape = {{PROGRAM_SHAPE_SHIM_EXPRESSION}};
    return kShape;
  }

  // Metadata that can be used to pretty-print profile counters.
  static const ::xla::HloProfilePrinterData* StaticHloProfilePrinterData() {
    static const ::xla::HloProfilePrinterData* kHloProfilePrinterData =
      {{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}};
    return kHloProfilePrinterData;
  }
};
{{NS_END}}

#endif  // TFCOMPILE_GENERATED_{{ENTRY}}_H_

// clang-format on
)";
  // The replacement strategy is naive, but good enough for our purposes.
  const std::vector<std::pair<string, string>> rewrites = {
      {"{{ARG_BYTES_ALIGNED}}", absl::StrCat(arg_bytes_aligned)},
      {"{{ARG_BYTES_TOTAL}}", absl::StrCat(arg_bytes_total)},
      {"{{ARG_NAMES_CODE}}", arg_names_code},
      {"{{ARG_NUM}}", absl::StrCat(arg_index_table.size())},
      {"{{ARG_INDEX_TABLE}}", absl::StrJoin(arg_index_table, ", ")},
      {"{{ASSIGN_PROFILE_COUNTERS_SIZE}}", assign_profile_counters_size},
      {"{{CLASS}}", opts.class_name},
      {"{{DECLS_FROM_OBJ_FILE}}",
       absl::StrJoin(metadata_result.header_variable_decls, "\n")},
      {"{{ENTRY}}", compile_result.entry_point},
      {"{{HLO_PROFILE_PRINTER_DATA_SHIM_EXPRESSION}}",
       metadata_result.hlo_profile_printer_data_access_shim},
      {"{{INCLUDE_XLA_DATA_PROTO}}", include_xla_data_proto},
      {"{{INCLUDE_HLO_PROFILE_PRINTER_DATA_PROTO}}",
       include_hlo_profile_printer_data_proto},
      {"{{METHODS_ARG}}\n", methods_arg},
      {"{{METHODS_RESULT}}\n", methods_result},
      {"{{METHODS_VARIABLE}}\n", methods_variable},
      {"{{NS_END}}\n", ns_end},
      {"{{NS_START}}\n", ns_start},
      {"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(xla::ProgramShape(ps))},
      {"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}",
       metadata_result.program_shape_access_shim},
      {"{{RESULT_INDEX}}", absl::StrCat(result_index)},
      {"{{RESULT_NAMES_CODE}}", result_names_code},
      {"{{TEMP_BYTES_ALIGNED}}", absl::StrCat(temp_bytes_aligned)},
      {"{{TEMP_BYTES_TOTAL}}", absl::StrCat(temp_bytes_total)},
      {"{{NUM_BUFFERS}}", absl::StrCat(buffer_infos.size())},
      {"{{BUFFER_INFOS_AS_STRING}}",
       absl::StrJoin(buffer_infos_as_strings, ",\n")}};
  absl::StrReplaceAll(rewrites, header);
  return Status::OK();
}

static string CreateUniqueIdentifier(const CodegenOpts& opts,
                                     absl::string_view suffix) {
  string result = "__tfcompile";
  for (const string& n : opts.namespaces) {
    absl::StrAppend(&result, "_", n);
  }

  absl::StrAppend(&result, "_", opts.class_name, "_", suffix);
  return result;
}
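
// Illustrative example for CreateUniqueIdentifier (added comment, not part of
// the upstream file): with namespaces {"foo", "bar"}, class_name "MyClass" and
// suffix "ProgramShapeProto", this yields
// "__tfcompile_foo_bar_MyClass_ProgramShapeProto", which is used below to name
// the embedded protobuf data for the generated class.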

Status GenerateMetadata(const CodegenOpts& opts,
                        const CompileResult& compile_result,
                        MetadataResult* metadata_result) {
  std::unique_ptr<xla::ProgramShapeProto> program_shape;

  if (opts.gen_program_shape) {
    program_shape =
        absl::make_unique<xla::ProgramShapeProto>(compile_result.program_shape);

    // The parameter names are currently meaningless, and redundant with the
    // rest of our metadata, so clear them out to avoid confusion and save
    // space.
    program_shape->clear_parameter_names();
  }

  // When asked to serialize a null protobuf, CreateEmbeddedProtocolBuffers
  // produces a shim that evaluates to nullptr, which is what we want.

  ProtobufToEmbed program_shape_protobuf{
      CreateUniqueIdentifier(opts, "ProgramShapeProto"),
      "::xla::ProgramShapeProto", program_shape.get()};

  ProtobufToEmbed hlo_profile_printer_data_protobuf{
      CreateUniqueIdentifier(opts, "HloProfilePrinterData"),
      "::xla::HloProfilePrinterData",
      compile_result.aot->hlo_profile_printer_data()};

  TF_ASSIGN_OR_RETURN(
      EmbeddedProtocolBuffers embedded_protobufs,
      CreateEmbeddedProtocolBuffers(
          opts.target_triple,
          {program_shape_protobuf, hlo_profile_printer_data_protobuf}));

  metadata_result->program_shape_access_shim =
      std::move(embedded_protobufs.cpp_shims[0].expression);
  metadata_result->hlo_profile_printer_data_access_shim =
      std::move(embedded_protobufs.cpp_shims[1].expression);
  metadata_result->header_variable_decls.emplace_back(
      std::move(embedded_protobufs.cpp_shims[0].variable_decl));
  metadata_result->header_variable_decls.emplace_back(
      std::move(embedded_protobufs.cpp_shims[1].variable_decl));
  metadata_result->object_file_data =
      std::move(embedded_protobufs.object_file_data);
  return Status::OK();
}

Status ParseCppClass(const string& cpp_class, string* class_name,
                     std::vector<string>* namespaces) {
  class_name->clear();
  namespaces->clear();
  if (cpp_class.empty()) {
    return errors::InvalidArgument("empty cpp_class: " + cpp_class);
  }
  std::vector<string> parts = absl::StrSplit(cpp_class, "::");
  if (parts.front().empty()) {
    // Allow a fully qualified name that starts with "::".
    parts.erase(parts.begin());
  }
  for (int i = 0; i < parts.size(); ++i) {
    if (i < parts.size() - 1) {
      TF_RETURN_IF_ERROR(ValidateCppIdent(
          parts[i], "in namespace component of cpp_class: " + cpp_class));
      namespaces->push_back(parts[i]);
    } else {
      TF_RETURN_IF_ERROR(ValidateCppIdent(
          parts[i], "in class name of cpp_class: " + cpp_class));
      *class_name = parts[i];
    }
  }
  return Status::OK();
}
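
// Illustrative example for ParseCppClass (added comment, not part of the
// upstream file): passing cpp_class = "::foo::bar::MyClass" sets *class_name
// to "MyClass" and *namespaces to {"foo", "bar"}; a plain "MyClass" leaves
// *namespaces empty.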

Status ValidateCppIdent(absl::string_view ident, absl::string_view msg) {
  if (ident.empty()) {
    return errors::InvalidArgument("empty identifier: ", msg);
  }
  // Require that the identifier starts with a nondigit, and is composed of
  // nondigits and digits, as specified in section [2.11 Identifiers] of the
  // C++11 Standard.  Note that nondigit is defined as [_a-zA-Z] and digit is
  // defined as [0-9].
  //
  // Technically the standard also allows for `universal-character-name`, with a
  // table of allowed unicode ranges, as well as `other implementation-defined
  // characters`.  We disallow those here to give better error messages, at the
  // expense of being more restrictive than the standard.
  if (ident[0] != '_' && !IsAlpha(ident[0])) {
    return errors::InvalidArgument("illegal leading char: ", msg);
  }
  for (size_t pos = 1; pos < ident.size(); ++pos) {
    if (ident[pos] != '_' && !IsAlphaNum(ident[pos])) {
      return errors::InvalidArgument("illegal char: ", msg);
    }
  }
  return Status::OK();
}
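
// Illustrative example for ValidateCppIdent (added comment, not part of the
// upstream file): ValidateCppIdent("arg_0", "feed name") returns OK, while
// ValidateCppIdent("0arg", "feed name") fails with "illegal leading char: feed
// name" and ValidateCppIdent("arg-0", "feed name") fails with "illegal char:
// feed name".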

}  // namespace tfcompile
}  // namespace tensorflow