/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_compilation_cache.h"

#include <numeric>

#include "tensorflow/compiler/mlir/mlir_bridge_rollout_policy.h"
#include "absl/base/call_once.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_activity.pb.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
#include "tensorflow/compiler/mlir/utils/array_container_utils.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/metrics.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {

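// Out-of-line definitions: prior to C++17, odr-used static constexpr data
// members still require a namespace-scope definition.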
constexpr int64_t XlaCompilationCache::kDefaultCompilationThreshold;
constexpr int64_t
    XlaCompilationCache::AsyncCompilationState::kNumCompilerThreads;
constexpr int64_t
    XlaCompilationCache::AsyncCompilationState::kMaxNumOngoingCompilations;

XlaCompilationCache::XlaCompilationCache(xla::LocalClient* client,
                                         DeviceType device_type)
    : client_(client), device_type_(std::move(device_type)) {}

XlaCompilationCache::~XlaCompilationCache() {
  // Ensure that any use of our programs has completed by waiting for all
  // stream executors to finish.
  for (auto* executor : client_->backend().stream_executors()) {
    bool ok = executor->SynchronizeAllActivity();
    if (!ok) {
      LOG(ERROR) << "Error synchronizing activity while waiting for all "
                    "programs to complete";
    }
  }
  // Wait for all outstanding compilations to finish.
  // Reset the pointer explicitly here in the top-level destructor. Without
  // this, the pointer would be reset when the AsyncCompilationState is
  // destructed, which depends on the order of the members in the
  // XlaCompilationCache class; that would be error prone if the order ever
  // changed.
  async_compilation_state_.compiler_threads.reset();
  // TODO(b/110813685): Think about the program ownership model. Programs are
  // currently owned by the compilation cache, which means we must wait for
  // program completion in the destructor. There are multiple compilation
  // caches around, which complicates things a little. Perhaps having programs
  // be shared_ptrs (an invasive change) would make the model easier to reason
  // about?
}

string XlaCompilationCache::DebugString() const {
  return "XLA JIT compilation cache";
}

// Compute a string signature which encodes the shapes of the
// arguments in the supplied list.
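// For example (hypothetical values), a cluster named "cluster_0" with one
// float parameter of shape [2,3] and one int32 constant argument would
// produce something like: "cluster_0,float [2,3]; Tensor<type: int32 ...>".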
string XlaCompilationCache::Signature::HumanString() const {
  string result = name;
  for (const auto& a : arg_shapes) {
    absl::StrAppend(&result, ",", DataTypeString(a.first));
    absl::StrAppend(&result, " [", absl::StrJoin(a.second, ","), "]");
  }

  for (const auto& v : arg_values) {
    absl::StrAppend(&result, "; ", v.DebugString());
  }
  return result;
}

bool XlaCompilationCache::Signature::operator==(const Signature& other) const {
  if (name != other.name) return false;
  if (arg_shapes != other.arg_shapes) return false;

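  // Constant arguments must match in dtype, shape, and raw byte contents.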
  if (arg_values.size() != other.arg_values.size()) return false;
  for (int i = 0, end = arg_values.size(); i < end; ++i) {
    if (arg_values[i].dtype() != other.arg_values[i].dtype() ||
        arg_values[i].shape() != other.arg_values[i].shape() ||
        arg_values[i].tensor_data() != other.arg_values[i].tensor_data()) {
      return false;
    }
  }
  return true;
}

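// Hash function for Signature. It must be kept consistent with operator==
// above: signatures that compare equal must hash to the same value.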
uint64 XlaCompilationCache::Signature::Hash::operator()(
    const XlaCompilationCache::Signature& signature) const {
  uint64 h = std::hash<string>()(signature.name);
  for (const auto& arg : signature.arg_shapes) {
    h = Hash64Combine(h, std::hash<int>()(static_cast<int>(arg.first)));
    h = Hash64Combine(h, std::hash<int>()(arg.second.size()));
    for (int dim : arg.second) {
      h = Hash64Combine(h, std::hash<int>()(dim));
    }
  }
  for (const auto& arg : signature.arg_values) {
    h = Hash64Combine(
        h, Hash64(arg.tensor_data().data(), arg.tensor_data().size()));
  }
  return h;
}

StatusOr<XlaCompilationCache::Signature> XlaCompilationCache::BuildSignature(
    const NameAttrList& function,
    absl::Span<const XlaCompiler::Argument> args) {
  Signature signature;
  signature.name = Canonicalize(function.name(), AttrSlice(&function.attr()));

  for (const XlaCompiler::Argument& arg : args) {
    switch (arg.kind) {
      case XlaCompiler::Argument::kConstant:
      case XlaCompiler::Argument::kConstantResource:
        signature.arg_values.push_back(arg.constant_value);
        break;
      case XlaCompiler::Argument::kParameter:
      case XlaCompiler::Argument::kResource:
        signature.arg_shapes.emplace_back(arg.type,
                                          arg.DimensionSizesAsInlinedVector());
        break;
      default:
        return errors::InvalidArgument(
            "Unhandled argument kind in XlaCompilationCache: ",
            arg.HumanString());
    }
  }
  return std::move(signature);
}

Status XlaCompilationCache::BuildExecutable(
    const XlaCompiler::Options& options,
    const XlaCompiler::CompilationResult& result,
    std::unique_ptr<xla::LocalExecutable>* executable) {
  VLOG(2) << "Compiling to local executable";

  std::vector<const xla::Shape*> argument_layouts(
      result.xla_input_shapes.size());
  for (int i = 0, end = result.xla_input_shapes.size(); i < end; ++i) {
    argument_layouts[i] = &result.xla_input_shapes[i];
  }
  xla::ExecutableBuildOptions build_options;
  if (result.collective_reduce_info) {
    build_options.set_num_replicas(result.collective_reduce_info->group_size);
  }
  build_options.set_device_ordinal(options.device_ordinal != -1
                                       ? options.device_ordinal
                                       : client_->default_device_ordinal());
  build_options.set_result_layout(result.xla_output_shape);
  build_options.set_device_allocator(options.device_allocator.get());
  build_options.set_alias_passthrough_params(options.alias_passthrough_params);
  build_options.mutable_debug_options()->set_xla_detailed_logging_and_dumping(
      options.detailed_logging);
  TF_ASSIGN_OR_RETURN(
      auto executables,
      client_->Compile(*result.computation, argument_layouts, build_options));
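  // A single computation is compiled above, so exactly one executable is
  // expected back from the client.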
  TF_RET_CHECK(executables.size() == 1);
  *executable = std::move(executables[0]);
  return Status::OK();
}

Status XlaCompilationCache::Compile(
    const XlaCompiler::Options& options, const NameAttrList& function,
    const std::vector<XlaCompiler::Argument>& args,
    const XlaCompiler::CompileOptions& compile_options,
    CompileMode compile_mode,
    const XlaCompiler::CompilationResult** out_compilation_result,
    xla::LocalExecutable** out_executable) {
  return CompileImpl(compile_options, options, function, args, /*ctx=*/nullptr,
                     CompileScope::kFunction, compile_mode,
                     out_compilation_result, out_executable);
}

static bool ShouldBeMegamorphic(int64_t compile_count,
                                int64_t execution_count) {
  const int64_t kCompileThreshold = 10;
  const int64_t kMinExecutionsPerCompile = 50;

  // This heuristic is trying to capture the following property: have we sunk a
  // certain minimum amount of compile time into the cluster that didn't quite
  // "pay off"?
  return compile_count > kCompileThreshold &&
         execution_count < kMinExecutionsPerCompile * compile_count;
}

StatusOr<std::unique_ptr<Graph>> CreateGraph(
    const NodeDef& node_def, absl::Span<const XlaCompiler::Argument> args,
    absl::Span<const DataType> result_types) {
  // TODO(b/74182462): We implement this by creating a new dummy Graph including
  // _Arg nodes, and let CompileGraph walk it. This could be optimized.
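  // The resulting graph has the shape
  //   _SOURCE -(ctrl)-> _arg_i -(data)-> main_node -(data)-> _retval_j,
  // where `main_node` is the single operator being compiled.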
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  Status status;
  // First create the actual node we care about computing.
  Node* main_node = graph->AddNode(node_def, &status);
  TF_RETURN_IF_ERROR(status);

  // Create dummy _Arg nodes. Link these to `main_node` and also via a control
  // dependency edge to the _SOURCE node.
  for (int64_t i = 0, end = args.size(); i < end; ++i) {
    Node* node;
    string arg_name = absl::StrCat("_arg", i);
    Status status =
        NodeBuilder(arg_name, FunctionLibraryDefinition::kArgOp)
            .ControlInput(graph->source_node())
            .Attr("T", args[i].kind == XlaCompiler::Argument::kResource
                           ? DT_RESOURCE
                           : args[i].type)
            .Attr("index", i)
            .Finalize(graph.get(), &node);
    TF_RETURN_IF_ERROR(status);
    graph->AddEdge(node, 0, main_node, i);
  }

  // Similarly with return values: create dummy _Retval nodes fed by
  // `main_node`.
  for (int64_t i = 0, end = result_types.size(); i < end; ++i) {
    Node* node;
    string retval_name = absl::StrCat("_retval", i);
    Status status = NodeBuilder(retval_name, FunctionLibraryDefinition::kRetOp)
                        .Input(main_node, i)
                        .Attr("T", result_types[i])
                        .Attr("index", i)
                        .Finalize(graph.get(), &node);
    TF_RETURN_IF_ERROR(status);
  }
  FixupSourceAndSinkEdges(graph.get());
  return graph;
}

Status XlaSingleOpToHlo(XlaCompiler* compiler,
                        const XlaCompiler::Options& options,
                        const std::vector<XlaCompiler::Argument>& args,
                        OpKernelContext* ctx,
                        const XlaCompiler::CompileOptions& compile_options,
                        XlaCompiler::CompilationResult* compilation_result) {
  std::vector<DataType> result_dtypes(ctx->num_outputs());
  for (int i = 0, end = result_dtypes.size(); i < end; ++i) {
    result_dtypes[i] = ctx->expected_output_dtype(i);
  }

  const NodeDef& node_def = ctx->op_kernel().def();
  TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(node_def, args, result_dtypes));

  auto compile_with_old_bridge = [&]() {
    return compiler->CompileGraph(compile_options, node_def.name(),
                                  std::move(graph), args, compilation_result);
  };

  const ConfigProto* config = ctx->function_library()->config_proto();
  auto bridge_rollout = GetMlirBridgeRolloutState(
      config ? absl::optional<ConfigProto>(*config) : absl::nullopt);
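  // Fall back to the old bridge when the MLIR bridge is explicitly disabled,
  // when compiling VarIsInitializedOp, or when the bridge is not explicitly
  // enabled and the device is not a TPU.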
  if (bridge_rollout ==
          ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED ||
      node_def.op() == "VarIsInitializedOp" ||
      (bridge_rollout !=
           ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED &&
       options.device_type.type_string() != DEVICE_TPU_XLA_JIT)) {
    return compile_with_old_bridge();
  }

  GraphDebugInfo debug_info;
  std::vector<std::string> control_rets;
  if (result_dtypes.empty()) {
    control_rets.push_back(node_def.name());
  }

  bool mlir_enabled = (bridge_rollout ==
                       ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED);
  VLOG(1) << "Attempting MLIR bridge."
          << (mlir_enabled ? " MLIR is explicitly enabled." : "");
  auto mlir_result = CompileGraphToXlaHlo(
      *graph, mlir::SpanToArrayRef<XlaCompiler::Argument>(args), control_rets,
      options.device_type.type_string(), compile_options.use_tuple_arg,
      /*analyse_graph=*/!mlir_enabled, *options.flib_def, debug_info,
      options.shape_representation_fn, compilation_result);

  if (mlir_result.ok() || mlir_enabled) {
    return mlir_result;
  }

  LOG_FIRST_N(WARNING, 5)
      << "Failed second phase of the MLIR bridge. Will "
         "retry with the old bridge. MLIR bridge compilation status: "
      << mlir_result;
  return compile_with_old_bridge();
}

Status XlaCompilationCache::CompileSingleOp(
    const XlaCompiler::Options& options,
    const std::vector<XlaCompiler::Argument>& args, OpKernelContext* ctx,
    const XlaCompiler::CompileOptions& compile_options,
    const XlaCompiler::CompilationResult** out_compilation_result,
    xla::LocalExecutable** out_executable) {
  const NodeDef& def = ctx->op_kernel().def();
  NameAttrList name;
  name.set_name(def.op());
  *name.mutable_attr() = def.attr();
  // Remove the "_class" attribute from the attribute set used to create the
  // compilation cache key. This attribute is information for the colocator
  // and causes false uniqueness between nodes.
  name.mutable_attr()->erase("_class");
  return CompileImpl(compile_options, options, name, args, ctx,
                     CompileScope::kOp, CompileMode::kStrict,
                     out_compilation_result, out_executable);
}

namespace {
// Print something that users can search for to definitively ascertain that XLA
// was used for their TF model.
//
// Prints only once to avoid spamming LOG(INFO).
void LogOnceXlaCompiledFirstCluster() {
  static absl::once_flag log_once;
  absl::call_once(log_once, [] {
    LOG(INFO) << "Compiled cluster using XLA!  This line is logged at most "
                 "once for the lifetime of the process.";
  });
}
}  // namespace

Status XlaCompilationCache::CompileStrict(
    Entry* entry, const XlaCompiler::CompileOptions& compile_options,
    const XlaCompiler::Options& options,
    const std::vector<XlaCompiler::Argument>& args,
    const NameAttrList& function, OpKernelContext* ctx, CompileScope scope) {
  tensorflow::Env* env = tensorflow::Env::Default();
  const uint64 compile_start_us = env->NowMicros();

  XlaCompiler compiler(options);
  entry->compile_state = CompileState::kCompiled;
  entry->compilation_status = [&] {
    if (scope == CompileScope::kOp) {
      return XlaSingleOpToHlo(&compiler, options, args, ctx, compile_options,
                              &entry->compilation_result);
    } else {
      CHECK(scope == CompileScope::kFunction);  // Crash OK
      return compiler.CompileFunction(compile_options, function, args,
                                      &entry->compilation_result);
    }
  }();
  TF_RETURN_IF_ERROR(entry->compilation_status);
  TF_RET_CHECK(entry->executable.get() == nullptr);
  entry->compilation_status =
      BuildExecutable(options, entry->compilation_result, &entry->executable);

  const uint64 compile_end_us = env->NowMicros();
  const uint64 compile_time_us = compile_end_us - compile_start_us;
  metrics::UpdateXlaCompilationTime(compile_time_us);

  mutex_lock lock(cluster_compile_stats_mu_);
  const std::string& function_name = function.name();
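  // CompileImpl creates the ClusterCompileStats entry for this function before
  // CompileStrict is reached, so the lookup below always finds an entry.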
  auto it = cluster_compile_stats_.find(function_name);
  // Use a double so sub-second compile times survive the conversion below.
  const double compile_time_s = compile_time_us / 1.0e6;
  it->second.compile_count++;
  it->second.cumulative_compile_time_us += compile_time_us;
  LogOnceXlaCompiledFirstCluster();
  VLOG(1) << "compiled " << function_name << " " << it->second.compile_count
          << " times, compile time: " << compile_time_us
          << " us, cumulative: " << it->second.cumulative_compile_time_us
          << " us ("
          << tensorflow::strings::HumanReadableElapsedTime(compile_time_s)
          << " / "
          << tensorflow::strings::HumanReadableElapsedTime(
                 it->second.cumulative_compile_time_us / 1.0e6)
          << ")";

  XlaJitCompilationActivity jit_compilation_activity;
  jit_compilation_activity.set_cluster_name(function_name);
  jit_compilation_activity.set_compile_count(it->second.compile_count);
  jit_compilation_activity.set_compile_time_us(compile_time_us);
  jit_compilation_activity.set_cumulative_compile_time_us(
      it->second.cumulative_compile_time_us);
  TF_RETURN_IF_ERROR(BroadcastXlaActivity(std::move(jit_compilation_activity)));

  return Status::OK();
}

Status XlaCompilationCache::CompileAsynchronous(
    Entry* entry, const XlaCompiler::CompileOptions& compile_options,
    const XlaCompiler::Options& options,
    const std::vector<XlaCompiler::Argument>& args,
    const NameAttrList& function, OpKernelContext* ctx, CompileScope scope) {
  // Explicitly capture all required data by value for async compilation.
  entry->compile_state = CompileState::kCompiling;
  {
    mutex_lock lock(async_compilation_state_.async_compilation_state_mu);
    async_compilation_state_.num_ongoing_compilations++;
  }
  // Don't move the above code into the thread function, as it synchronously
  // updates the async compilation state!

  // When the ThreadPool for the compilation cache is destroyed, it waits for
  // compilations to have finished. This means that both 'entry' and 'this'
  // will be alive for the duration of the compilation.
  // !!Pay attention when additional variables must be captured by this
  // lambda!! All values are captured by value. Make sure that all pointer
  // values (like entry) do not get freed until the lambda has finished.
  const std::string& function_name = function.name();
  async_compilation_state_.compiler_threads->Schedule([=] {
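    // Compile into a local Entry so that the shared entry's mutex is not held
    // while the (potentially long) compilation runs; the results are copied
    // into `entry` under its lock below.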
    Entry local_entry;
    VLOG(2) << "Starting asynchronous compilation of cluster " << function_name
            << '.';
    // We don't need to lock local_entry.mu, but do it anyway to satisfy
    // thread safety analysis.
    mutex_lock entry_lock(local_entry.mu);
    Status s = CompileStrict(&local_entry, compile_options, options, args,
                             function, ctx, scope);
    VLOG(2) << "Finished asynchronous compilation of cluster " << function_name
            << '.';
    {
      mutex_lock lock(async_compilation_state_.async_compilation_state_mu);
      async_compilation_state_.num_ongoing_compilations--;
    }
    {  // Populate original entry with compilation result.
      mutex_lock entry_lock(entry->mu);
      if (!s.ok()) {
        entry->compilation_status = s;
      } else {
        entry->compilation_status = local_entry.compilation_status;
      }
      entry->compilation_result = local_entry.compilation_result;
      entry->compile_state = local_entry.compile_state;
      entry->executable = std::move(local_entry.executable);
    }
  });
  return Status::OK();
}

bool XlaCompilationCache::ShouldCompileCluster(CompileMode compile_mode,
                                               bool is_megamorphic,
                                               bool is_first_execution,
                                               int64_t current_request_count,
                                               const NameAttrList& function) {
  absl::optional<int64> compile_threshold;
  if (compile_mode == CompileMode::kLazy) {
    compile_threshold = kDefaultCompilationThreshold;
  } else if (compile_mode == CompileMode::kAsync) {
    compile_threshold = 0;  // for now, always compile right away.
  }

  if (compile_mode == CompileMode::kStrict) {
    // Lazy compilation is disabled.
    return true;
  }

  if (is_megamorphic) {
    BroadcastOptimizationRemark(XlaOptimizationRemark::MEGAMORPHIC_FUNCTION,
                                function.name())
        .IgnoreError();
    VLOG(2) << "Not compiling cluster " << function.name()
            << " because it is megamorphic.";
    return false;
  }

  if (is_first_execution) {
    return true;
  }

  if (compile_mode == CompileMode::kAsync) {
    // Asynchronous compilation is enabled.
    mutex_lock lock(async_compilation_state_.async_compilation_state_mu);
    if (async_compilation_state_.num_ongoing_compilations >=
        async_compilation_state_.kMaxNumOngoingCompilations) {
      VLOG(2) << "Not asynchronously compiling cluster " << function.name()
              << " because of too many ongoing compilations.";
      return false;
    }
  }

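  // compile_mode is kLazy or kAsync at this point (kStrict returned above),
  // so compile_threshold is guaranteed to hold a value.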
  bool reached_compile_threshold = current_request_count >= *compile_threshold;
  if (!reached_compile_threshold) {
    VLOG(2) << "Not compiling cluster " << function.name()
            << " because it has not reached compile threshold; threshold is "
            << *compile_threshold << ", execution count "
            << current_request_count << ".";
  }
  return reached_compile_threshold;
}

Status XlaCompilationCache::CompileImpl(
    const XlaCompiler::CompileOptions& compile_options,
    const XlaCompiler::Options& options, const NameAttrList& function,
    const std::vector<XlaCompiler::Argument>& args, OpKernelContext* ctx,
    CompileScope scope, CompileMode compile_mode,
    const XlaCompiler::CompilationResult** out_compilation_result,
    xla::LocalExecutable** out_executable) {
  if (FailOnXlaCompilation()) {
    return errors::Internal("XLA compilation disabled");
  }
  DCHECK_NE(out_executable, nullptr);
  VLOG(2) << "XlaCompilationCache::Compile " << DebugString();

  if (VLOG_IS_ON(2)) {
    VLOG(2) << "num_inputs=" << args.size();
    for (int i = 0, end = args.size(); i < end; i++) {
      VLOG(3) << i << ": " << args[i].HumanString();
    }
  }
  TF_ASSIGN_OR_RETURN(Signature signature, BuildSignature(function, args));

  // The outer lock protects the existence of the cache entry. It does not
  // protect the contents of the cache entry.
  Entry* entry;
  {
    mutex_lock lock(compile_cache_mu_);
    // Find or create a cache entry.
    std::unique_ptr<Entry>& e = cache_[signature];
    if (!e) {
      e.reset(new Entry);
    }
    entry = e.get();
  }

  // We always compile a cluster the very first time it is executed.  This is an
  // optimistic guess that pays off for statically shaped TensorFlow graphs
  // (since they get the benefit of XLA right away without waiting for warmup)
  // and doesn't hurt much for dynamically shaped TensorFlow graphs (we "pay" at
  // most one cluster-compilation's worth of compile time).
  bool is_first_execution;

  // We avoid compiling clusters that have "gone megamorphic", i.e. have an
  // excessive amount of shape dynamism.
  bool is_megamorphic;

  {
    mutex_lock lock(cluster_compile_stats_mu_);
    auto it =
        cluster_compile_stats_.emplace(function.name(), ClusterCompileStats{})
            .first;
    is_first_execution = it->second.execution_count++ == 0;

    // The is_megamorphic bit is "sticky".  We assume clusters that have been
    // observed to be megamorphic once stay megamorphic forever.
    if (!it->second.is_megamorphic &&
        ShouldBeMegamorphic(/*compile_count=*/it->second.compile_count,
                            /*execution_count=*/it->second.execution_count)) {
      VLOG(1) << "Marking " << function.name()
              << " as megamorphic, compile_count=" << it->second.compile_count
              << " execution_count=" << it->second.execution_count;
      it->second.is_megamorphic = true;
    }

    is_megamorphic = it->second.is_megamorphic;
  }

  string human_signature;
  if (VLOG_IS_ON(2)) {
    human_signature = VLOG_IS_ON(3) ? signature.HumanString() : function.name();
    VLOG(2) << "Signature: " << human_signature;
  }

  // Acquire the cache entry lock and compile, if necessary.
  // TODO(phawkins): this locking will need to be restructured when we implement
  // cache eviction.
  mutex_lock entry_lock(entry->mu);
  int64_t current_request_count = ++entry->request_count;
  VLOG(2) << "Compilation cache entry hit: "
          << static_cast<int>(entry->compile_state)
          << " signature: " << human_signature << " with request count "
          << current_request_count;

  CompileState state = entry->compile_state;
  *out_compilation_result = nullptr;
  *out_executable = nullptr;

  if (state == CompileState::kUncompiled) {
    XLA_SCOPED_LOGGING_TIMER("Compilation of XLA executable");
    if (!ShouldCompileCluster(compile_mode, is_megamorphic, is_first_execution,
                              current_request_count, function)) {
      VLOG(2) << "Not compiling for signature: " << human_signature;
      return Status::OK();
    } else if (compile_mode == CompileMode::kAsync) {
      VLOG(2) << "Queueing asynchronous compilation for signature: "
              << human_signature;
      TF_RETURN_IF_ERROR(CompileAsynchronous(entry, compile_options, options,
                                             args, function, ctx, scope));
      return Status::OK();
    } else {
      VLOG(2) << "Instantly compiling for signature: " << human_signature;
      TF_RETURN_IF_ERROR(CompileStrict(entry, compile_options, options, args,
                                       function, ctx, scope));
    }
  } else if (state == CompileState::kCompiling) {
    VLOG(2) << "Ongoing asynchronous compilation for signature: "
            << human_signature;
    return Status::OK();
  } else if (state == CompileState::kCompiled) {
    VLOG(2) << "Already compiled for signature: " << human_signature;
  }

  TF_RETURN_IF_ERROR(entry->compilation_status);
  *out_compilation_result = &entry->compilation_result;
  *out_executable = entry->executable.get();
  return Status::OK();
}

}  // namespace tensorflow