/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_compilation_cache.h"

#include <numeric>

#include "tensorflow/compiler/mlir/mlir_bridge_rollout_policy.h"
#include "absl/base/call_once.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_activity.pb.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
#include "tensorflow/compiler/mlir/utils/array_container_utils.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/metrics.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {

constexpr int64_t XlaCompilationCache::kDefaultCompilationThreshold;
constexpr int64_t
    XlaCompilationCache::AsyncCompilationState::kNumCompilerThreads;
constexpr int64_t
    XlaCompilationCache::AsyncCompilationState::kMaxNumOngoingCompilations;

XlaCompilationCache::XlaCompilationCache(xla::LocalClient* client,
                                         DeviceType device_type)
    : client_(client), device_type_(std::move(device_type)) {}

XlaCompilationCache::~XlaCompilationCache() {
  // Ensure any use of our programs has completed by waiting for all stream
  // executors to complete.
  for (auto* executor : client_->backend().stream_executors()) {
    bool ok = executor->SynchronizeAllActivity();
    if (!ok) {
      LOG(ERROR) << "Error synchronizing activity while waiting for all "
                    "programs to complete";
    }
  }
  // Wait for all outstanding compilations to finish.
  // Resetting the pointer explicitly in the top-level destructor.
  // Without this, the pointer would be reset when the AsyncCompilationState
  // is destructed, which is dependent on the order of the members in the
  // XlaCompilationCache class, which is error prone if the order changes.
  async_compilation_state_.compiler_threads.reset();
  // TODO(b/110813685): Think about the program ownership model. Programs are
  // currently owned by the compilation cache which means we must wait for
  // program completion in the destructor. There are multiple compilation caches
  // around, which complicates things a little. Perhaps having programs be
  // shared_ptrs (an invasive change) would make the model easier to reason
  // about?
}

string XlaCompilationCache::DebugString() const {
  return "XLA JIT compilation cache";
}

// Compute a string signature which encodes the shapes of the arguments in the
// supplied list, and the values of any constant arguments.
string XlaCompilationCache::Signature::HumanString() const {
  string result = name;
  for (const auto& a : arg_shapes) {
    absl::StrAppend(&result, ",", DataTypeString(a.first));
    absl::StrAppend(&result, " [", absl::StrJoin(a.second, ","), "]");
  }

  for (const auto& v : arg_values) {
    absl::StrAppend(&result, "; ", v.DebugString());
  }
  return result;
}

bool XlaCompilationCache::Signature::operator==(const Signature& other) const {
  if (name != other.name) return false;
  if (arg_shapes != other.arg_shapes) return false;

  if (arg_values.size() != other.arg_values.size()) return false;
  for (int i = 0, end = arg_values.size(); i < end; ++i) {
    if (arg_values[i].dtype() != other.arg_values[i].dtype() ||
        arg_values[i].shape() != other.arg_values[i].shape() ||
        arg_values[i].tensor_data() != other.arg_values[i].tensor_data()) {
      return false;
    }
  }
  return true;
}

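// Hashes a Signature by combining the function name, each argument's dtype
// and dimension sizes, and the raw bytes of any constant argument values.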
uint64 XlaCompilationCache::Signature::Hash::operator()(
    const XlaCompilationCache::Signature& signature) const {
  uint64 h = std::hash<string>()(signature.name);
  for (const auto& arg : signature.arg_shapes) {
    h = Hash64Combine(h, std::hash<int>()(static_cast<int>(arg.first)));
    h = Hash64Combine(h, std::hash<int>()(arg.second.size()));
    for (int dim : arg.second) {
      h = Hash64Combine(h, std::hash<int>()(dim));
    }
  }
  for (const auto& arg : signature.arg_values) {
    h = Hash64Combine(
        h, Hash64(arg.tensor_data().data(), arg.tensor_data().size()));
  }
  return h;
}

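// Builds the cache key Signature for `function` called with `args`: constant
// arguments contribute their values, while parameters and resources
// contribute their type and shape.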
StatusOr<XlaCompilationCache::Signature> XlaCompilationCache::BuildSignature(
    const NameAttrList& function,
    absl::Span<const XlaCompiler::Argument> args) {
  Signature signature;
  signature.name = Canonicalize(function.name(), AttrSlice(&function.attr()));

  for (const XlaCompiler::Argument& arg : args) {
    switch (arg.kind) {
      case XlaCompiler::Argument::kConstant:
      case XlaCompiler::Argument::kConstantResource:
        signature.arg_values.push_back(arg.constant_value);
        break;
      case XlaCompiler::Argument::kParameter:
      case XlaCompiler::Argument::kResource:
        signature.arg_shapes.emplace_back(arg.type,
                                          arg.DimensionSizesAsInlinedVector());
        break;
      default:
        return errors::InvalidArgument(
            "Unhandled argument kind in XlaCompilationCache: ",
            arg.HumanString());
    }
  }
  return std::move(signature);
}

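// Builds an xla::LocalExecutable from an already-lowered HLO computation,
// forwarding layouts, device ordinal, allocator, and aliasing settings from
// `options` and `result` into the executable build options.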
Status XlaCompilationCache::BuildExecutable(
    const XlaCompiler::Options& options,
    const XlaCompiler::CompilationResult& result,
    std::unique_ptr<xla::LocalExecutable>* executable) {
  VLOG(2) << "Compiling to local executable";

  std::vector<const xla::Shape*> argument_layouts(
      result.xla_input_shapes.size());
  for (int i = 0, end = result.xla_input_shapes.size(); i < end; ++i) {
    argument_layouts[i] = &result.xla_input_shapes[i];
  }
  xla::ExecutableBuildOptions build_options;
  if (result.collective_reduce_info) {
    build_options.set_num_replicas(result.collective_reduce_info->group_size);
  }
  build_options.set_device_ordinal(options.device_ordinal != -1
                                       ? options.device_ordinal
                                       : client_->default_device_ordinal());
  build_options.set_result_layout(result.xla_output_shape);
  build_options.set_device_allocator(options.device_allocator.get());
  build_options.set_alias_passthrough_params(options.alias_passthrough_params);
  build_options.mutable_debug_options()->set_xla_detailed_logging_and_dumping(
      options.detailed_logging);
  TF_ASSIGN_OR_RETURN(
      auto executables,
      client_->Compile(*result.computation, argument_layouts, build_options));
  TF_RET_CHECK(executables.size() == 1);
  *executable = std::move(executables[0]);
  return Status::OK();
}

Status XlaCompilationCache::Compile(
    const XlaCompiler::Options& options, const NameAttrList& function,
    const std::vector<XlaCompiler::Argument>& args,
    const XlaCompiler::CompileOptions& compile_options,
    CompileMode compile_mode,
    const XlaCompiler::CompilationResult** out_compilation_result,
    xla::LocalExecutable** out_executable) {
  return CompileImpl(compile_options, options, function, args, /*ctx=*/nullptr,
                     CompileScope::kFunction, compile_mode,
                     out_compilation_result, out_executable);
}

static bool ShouldBeMegamorphic(int64_t compile_count,
                                int64_t execution_count) {
  const int64_t kCompileThreshold = 10;
  const int64_t kMinExecutionsPerCompile = 50;

  // This heuristic is trying to capture the following property: have we sunk a
  // certain minimum amount of compile time into the cluster that didn't quite
  // "pay off"?
  return compile_count > kCompileThreshold &&
         execution_count < kMinExecutionsPerCompile * compile_count;
}

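// Builds a dummy single-node Graph for `node_def`, feeding it through _Arg
// nodes and reading its outputs through _Retval nodes, so that a lone op can
// be compiled via the graph compilation path.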
StatusOr<std::unique_ptr<Graph>> CreateGraph(
    const NodeDef& node_def, absl::Span<const XlaCompiler::Argument> args,
    absl::Span<const DataType> result_types) {
  // TODO(b/74182462): We implement this by creating a new dummy Graph including
  // _Arg nodes, and letting CompileGraph walk it. This could be optimized.
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  Status status;
  // First create the actual node we care about computing.
  Node* main_node = graph->AddNode(node_def, &status);
  TF_RETURN_IF_ERROR(status);

  // Create dummy _Arg nodes. Link these to `main_node` and also via a control
  // dependency edge to the _SOURCE node.
  for (int64_t i = 0, end = args.size(); i < end; ++i) {
    Node* node;
    string arg_name = absl::StrCat("_arg", i);
    Status status =
        NodeBuilder(arg_name, FunctionLibraryDefinition::kArgOp)
            .ControlInput(graph->source_node())
            .Attr("T", args[i].kind == XlaCompiler::Argument::kResource
                           ? DT_RESOURCE
                           : args[i].type)
            .Attr("index", i)
            .Finalize(graph.get(), &node);
    TF_RETURN_IF_ERROR(status);
    graph->AddEdge(node, 0, main_node, i);
  }

  // Similarly with return values, create dummy _Retval nodes fed by
  // `main_node`.
  for (int64_t i = 0, end = result_types.size(); i < end; ++i) {
    Node* node;
    string retval_name = absl::StrCat("_retval", i);
    Status status = NodeBuilder(retval_name, FunctionLibraryDefinition::kRetOp)
                        .Input(main_node, i)
                        .Attr("T", result_types[i])
                        .Attr("index", i)
                        .Finalize(graph.get(), &node);
    TF_RETURN_IF_ERROR(status);
  }
  FixupSourceAndSinkEdges(graph.get());
  return graph;
}

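// Lowers a single TensorFlow op to HLO. The MLIR bridge is attempted when it
// is explicitly enabled, or (with graph analysis) for TPU clusters; otherwise
// the op is compiled with the old tf2xla bridge, which is also used as a
// fallback when the MLIR path fails and was not explicitly enabled.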
Status XlaSingleOpToHlo(XlaCompiler* compiler,
                        const XlaCompiler::Options& options,
                        const std::vector<XlaCompiler::Argument>& args,
                        OpKernelContext* ctx,
                        const XlaCompiler::CompileOptions& compile_options,
                        XlaCompiler::CompilationResult* compilation_result) {
  std::vector<DataType> result_dtypes(ctx->num_outputs());
  for (int i = 0, end = result_dtypes.size(); i < end; ++i) {
    result_dtypes[i] = ctx->expected_output_dtype(i);
  }

  const NodeDef& node_def = ctx->op_kernel().def();
  TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(node_def, args, result_dtypes));

  auto compile_with_old_bridge = [&]() {
    return compiler->CompileGraph(compile_options, node_def.name(),
                                  std::move(graph), args, compilation_result);
  };

  const ConfigProto* config = ctx->function_library()->config_proto();
  auto bridge_rollout = GetMlirBridgeRolloutState(
      config ? absl::optional<ConfigProto>(*config) : absl::nullopt);
  if (bridge_rollout ==
          ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED ||
      node_def.op() == "VarIsInitializedOp" ||
      (bridge_rollout !=
           ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED &&
       options.device_type.type_string() != DEVICE_TPU_XLA_JIT)) {
    return compile_with_old_bridge();
  }

  GraphDebugInfo debug_info;
  std::vector<std::string> control_rets;
  if (result_dtypes.empty()) {
    control_rets.push_back(node_def.name());
  }

  bool mlir_enabled = (bridge_rollout ==
                       ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED);
  VLOG(1) << "Attempting MLIR bridge."
          << (mlir_enabled ? " MLIR is explicitly enabled." : "");
  auto mlir_result = CompileGraphToXlaHlo(
      *graph, mlir::SpanToArrayRef<XlaCompiler::Argument>(args), control_rets,
      options.device_type.type_string(), compile_options.use_tuple_arg,
      /*analyse_graph=*/!mlir_enabled, *options.flib_def, debug_info,
      options.shape_representation_fn, compilation_result);

  if (mlir_result.ok() || mlir_enabled) {
    return mlir_result;
  }

  LOG_FIRST_N(WARNING, 5)
      << "Failed second phase of the MLIR bridge. Will "
         "retry with the old bridge. MLIR bridge compilation status: "
      << mlir_result;
  return compile_with_old_bridge();
}

Status XlaCompilationCache::CompileSingleOp(
    const XlaCompiler::Options& options,
    const std::vector<XlaCompiler::Argument>& args, OpKernelContext* ctx,
    const XlaCompiler::CompileOptions& compile_options,
    const XlaCompiler::CompilationResult** out_compilation_result,
    xla::LocalExecutable** out_executable) {
  const NodeDef& def = ctx->op_kernel().def();
  NameAttrList name;
  name.set_name(def.op());
  *name.mutable_attr() = def.attr();
  // Remove the "_class" attribute from the attribute set used to create the
  // compilation cache key. This attribute is information for the colocator
  // and causes false uniqueness between nodes.
  name.mutable_attr()->erase("_class");
  return CompileImpl(compile_options, options, name, args, ctx,
                     CompileScope::kOp, CompileMode::kStrict,
                     out_compilation_result, out_executable);
}

namespace {
// Print something that users can search for to definitively ascertain that XLA
// was used for their TF model.
//
// Prints only once to avoid spamming LOG(INFO).
void LogOnceXlaCompiledFirstCluster() {
  static absl::once_flag log_once;
  absl::call_once(log_once, [] {
    LOG(INFO) << "Compiled cluster using XLA! This line is logged at most "
                 "once for the lifetime of the process.";
  });
}
}  // namespace

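// Synchronously compiles `entry` (the caller must hold entry->mu): lowers the
// op or function to HLO, builds the LocalExecutable, and records compile time
// in the per-cluster statistics and the XLA activity log.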
Status XlaCompilationCache::CompileStrict(
    Entry* entry, const XlaCompiler::CompileOptions& compile_options,
    const XlaCompiler::Options& options,
    const std::vector<XlaCompiler::Argument>& args,
    const NameAttrList& function, OpKernelContext* ctx, CompileScope scope) {
  tensorflow::Env* env = tensorflow::Env::Default();
  const uint64 compile_start_us = env->NowMicros();

  XlaCompiler compiler(options);
  entry->compile_state = CompileState::kCompiled;
  entry->compilation_status = [&] {
    if (scope == CompileScope::kOp) {
      return XlaSingleOpToHlo(&compiler, options, args, ctx, compile_options,
                              &entry->compilation_result);

    } else {
      CHECK(scope == CompileScope::kFunction);  // Crash OK
      return compiler.CompileFunction(compile_options, function, args,
                                      &entry->compilation_result);
    }
  }();
  TF_RETURN_IF_ERROR(entry->compilation_status);
  TF_RET_CHECK(entry->executable.get() == nullptr);
  entry->compilation_status =
      BuildExecutable(options, entry->compilation_result, &entry->executable);

  const uint64 compile_end_us = env->NowMicros();
  const uint64 compile_time_us = compile_end_us - compile_start_us;
  metrics::UpdateXlaCompilationTime(compile_time_us);

  mutex_lock lock(cluster_compile_stats_mu_);
  const std::string& function_name = function.name();
  auto it = cluster_compile_stats_.find(function_name);
  const uint64 compile_time_s = compile_time_us / 1.0e6;
  it->second.compile_count++;
  it->second.cumulative_compile_time_us += compile_time_us;
  LogOnceXlaCompiledFirstCluster();
  VLOG(1) << "compiled " << function_name << " " << it->second.compile_count
          << " times, compile time: " << compile_time_us
          << " us, cumulative: " << it->second.cumulative_compile_time_us
          << " us ("
          << tensorflow::strings::HumanReadableElapsedTime(compile_time_s)
          << " / "
          << tensorflow::strings::HumanReadableElapsedTime(
                 it->second.cumulative_compile_time_us / 1.0e6)
          << ")";

  XlaJitCompilationActivity jit_compilation_activity;
  jit_compilation_activity.set_cluster_name(function_name);
  jit_compilation_activity.set_compile_count(it->second.compile_count);
  jit_compilation_activity.set_compile_time_us(compile_time_us);
  jit_compilation_activity.set_cumulative_compile_time_us(
      it->second.cumulative_compile_time_us);
  TF_RETURN_IF_ERROR(BroadcastXlaActivity(std::move(jit_compilation_activity)));

  return Status::OK();
}

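// Schedules compilation of `entry` on the compiler thread pool. The work runs
// against a temporary Entry, and the result is copied back into `entry` under
// its mutex once compilation finishes.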
Status XlaCompilationCache::CompileAsynchronous(
    Entry* entry, const XlaCompiler::CompileOptions& compile_options,
    const XlaCompiler::Options& options,
    const std::vector<XlaCompiler::Argument>& args,
    const NameAttrList& function, OpKernelContext* ctx, CompileScope scope) {
  // Explicitly capture all required data by value for async compilation.
  entry->compile_state = CompileState::kCompiling;
  {
    mutex_lock lock(async_compilation_state_.async_compilation_state_mu);
    async_compilation_state_.num_ongoing_compilations++;
  }
  // Don't move the above code into the thread function as it synchronously
  // updates the async compilation state!

  // When the ThreadPool for the compilation cache is destroyed, it waits for
  // compilations to have finished. This means that both 'entry' and 'this' will
  // be alive for the duration of the compilation.
  // !!Pay attention when additional variables must be captured by this lambda!!
  // All values are captured by value. Make sure that all pointer values (like
  // entry) do not get freed until the lambda has finished.
  const std::string& function_name = function.name();
  async_compilation_state_.compiler_threads->Schedule([=] {
    Entry local_entry;
    VLOG(2) << "Starting asynchronous compilation of cluster " << function_name
            << '.';
    // We don't need to lock local_entry.mu, but do it anyway to satisfy
    // thread safety analysis.
    mutex_lock entry_lock(local_entry.mu);
    Status s = CompileStrict(&local_entry, compile_options, options, args,
                             function, ctx, scope);
    VLOG(2) << "Finished asynchronous compilation of cluster " << function_name
            << '.';
    {
      mutex_lock lock(async_compilation_state_.async_compilation_state_mu);
      async_compilation_state_.num_ongoing_compilations--;
    }
    {  // Populate original entry with compilation result.
      mutex_lock entry_lock(entry->mu);
      if (!s.ok()) {
        entry->compilation_status = s;
      } else {
        entry->compilation_status = local_entry.compilation_status;
      }
      entry->compilation_result = local_entry.compilation_result;
      entry->compile_state = local_entry.compile_state;
      entry->executable = std::move(local_entry.executable);
    }
  });
  return Status::OK();
}

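// Decides whether a cluster should be compiled now, based on the compile mode,
// whether the cluster has gone megamorphic, whether this is its first
// execution, the number of ongoing asynchronous compilations, and the
// per-mode request-count threshold.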
bool XlaCompilationCache::ShouldCompileCluster(CompileMode compile_mode,
                                               bool is_megamorphic,
                                               bool is_first_execution,
                                               int64_t current_request_count,
                                               const NameAttrList& function) {
  absl::optional<int64> compile_threshold;
  if (compile_mode == CompileMode::kLazy) {
    compile_threshold = kDefaultCompilationThreshold;
  } else if (compile_mode == CompileMode::kAsync) {
    compile_threshold = 0;  // for now, always compile right away.
  }

  if (compile_mode == CompileMode::kStrict) {
    // Lazy compilation is disabled.
    return true;
  }

  if (is_megamorphic) {
    BroadcastOptimizationRemark(XlaOptimizationRemark::MEGAMORPHIC_FUNCTION,
                                function.name())
        .IgnoreError();
    VLOG(2) << "Not compiling cluster " << function.name()
            << " because it is megamorphic.";
    return false;
  }

  if (is_first_execution) {
    return true;
  }

  if (compile_mode == CompileMode::kAsync) {
    // Asynchronous compilation is enabled.
    mutex_lock lock(async_compilation_state_.async_compilation_state_mu);
    if (async_compilation_state_.num_ongoing_compilations >=
        async_compilation_state_.kMaxNumOngoingCompilations) {
      VLOG(2) << "Not asynchronously compiling cluster " << function.name()
              << " because of too many ongoing compilations.";
      return false;
    }
  }

  bool reached_compile_threshold = current_request_count >= *compile_threshold;
  if (!reached_compile_threshold) {
    VLOG(2) << "Not compiling cluster " << function.name()
            << " because it has not reached compile threshold; threshold is "
            << *compile_threshold << ", execution count "
            << current_request_count << ".";
  }
  return reached_compile_threshold;
}

Status XlaCompilationCache::CompileImpl(
    const XlaCompiler::CompileOptions& compile_options,
    const XlaCompiler::Options& options, const NameAttrList& function,
    const std::vector<XlaCompiler::Argument>& args, OpKernelContext* ctx,
    CompileScope scope, CompileMode compile_mode,
    const XlaCompiler::CompilationResult** out_compilation_result,
    xla::LocalExecutable** out_executable) {
  if (FailOnXlaCompilation()) {
    return errors::Internal("XLA compilation disabled");
  }
  DCHECK_NE(out_executable, nullptr);
  VLOG(2) << "XlaCompilationCache::Compile " << DebugString();

  if (VLOG_IS_ON(2)) {
    VLOG(2) << "num_inputs=" << args.size();
    for (int i = 0, end = args.size(); i < end; i++) {
      VLOG(3) << i << ": " << args[i].HumanString();
    }
  }
  TF_ASSIGN_OR_RETURN(Signature signature, BuildSignature(function, args));

  // The outer lock protects the existence of the cache entry. It does not
  // protect the contents of the cache entry.
  Entry* entry;
  {
    mutex_lock lock(compile_cache_mu_);
    // Find or create a cache entry.
    std::unique_ptr<Entry>& e = cache_[signature];
    if (!e) {
      e.reset(new Entry);
    }
    entry = e.get();
  }

  // We always compile a cluster the very first time it is executed. This is an
  // optimistic guess that pays off for statically shaped TensorFlow graphs
  // (since they get the benefit of XLA right away without waiting for warmup)
  // and doesn't hurt much for dynamically shaped TensorFlow graphs (we "pay" at
  // most one cluster-compilation's worth of compile time).
  bool is_first_execution;

  // We avoid compiling clusters that have "gone megamorphic" i.e. have an
  // excessive amount of shape dynamism.
  bool is_megamorphic;

  {
    mutex_lock lock(cluster_compile_stats_mu_);
    auto it =
        cluster_compile_stats_.emplace(function.name(), ClusterCompileStats{})
            .first;
    is_first_execution = it->second.execution_count++ == 0;

    // The is_megamorphic bit is "sticky". We assume clusters that have been
    // observed to be megamorphic once stay megamorphic forever.
    if (!it->second.is_megamorphic &&
        ShouldBeMegamorphic(/*compile_count=*/it->second.compile_count,
                            /*execution_count=*/it->second.execution_count)) {
      VLOG(1) << "Marking " << function.name()
              << " as megamorphic, compile_count=" << it->second.compile_count
              << " execution_count=" << it->second.execution_count;
      it->second.is_megamorphic = true;
    }

    is_megamorphic = it->second.is_megamorphic;
  }

  string human_signature;
  if (VLOG_IS_ON(2)) {
    human_signature = VLOG_IS_ON(3) ? signature.HumanString() : function.name();
    VLOG(2) << "Signature: " << human_signature;
  }

  // Acquire the cache entry lock and compile, if necessary.
  // TODO(phawkins): this locking will need to be restructured when we implement
  // cache eviction.
  mutex_lock entry_lock(entry->mu);
  int64_t current_request_count = ++entry->request_count;
  VLOG(2) << "Compilation cache entry hit: "
          << static_cast<int>(entry->compile_state)
          << " signature: " << human_signature << " with request count "
          << current_request_count;

  CompileState state = entry->compile_state;
  *out_compilation_result = nullptr;
  *out_executable = nullptr;

  if (state == CompileState::kUncompiled) {
    XLA_SCOPED_LOGGING_TIMER("Compilation of XLA executable");
    if (!ShouldCompileCluster(compile_mode, is_megamorphic, is_first_execution,
                              current_request_count, function)) {
      VLOG(2) << "Not compiling for signature: " << human_signature;
      return Status::OK();
    } else if (compile_mode == CompileMode::kAsync) {
      VLOG(2) << "Queueing asynchronous compilation for signature: "
              << human_signature;
      TF_RETURN_IF_ERROR(CompileAsynchronous(entry, compile_options, options,
                                             args, function, ctx, scope));
      return Status::OK();
    } else {
      VLOG(2) << "Instantly compiling for signature: " << human_signature;
      TF_RETURN_IF_ERROR(CompileStrict(entry, compile_options, options, args,
                                       function, ctx, scope));
    }
  } else if (state == CompileState::kCompiling) {
    VLOG(2) << "Ongoing asynchronous compilation for signature: "
            << human_signature;
    return Status::OK();
  } else if (state == CompileState::kCompiled) {
    VLOG(2) << "Already compiled for signature: " << human_signature;
  }

  TF_RETURN_IF_ERROR(entry->compilation_status);
  *out_compilation_result = &entry->compilation_result;
  *out_executable = entry->executable.get();
  return Status::OK();
}

}  // namespace tensorflow