/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_compilation_cache.h"

#include <numeric>

#include "tensorflow/compiler/mlir/mlir_bridge_rollout_policy.h"
#include "absl/base/call_once.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_activity.pb.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
#include "tensorflow/compiler/mlir/utils/array_container_utils.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/metrics.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {

constexpr int64 XlaCompilationCache::kDefaultCompilationThreshold;

XlaCompilationCache::XlaCompilationCache(xla::LocalClient* client,
                                         DeviceType device_type)
    : client_(client), device_type_(std::move(device_type)) {}

XlaCompilationCache::~XlaCompilationCache() {
  // Ensure any use of our programs has completed by waiting for all stream
  // executors to complete.
  for (auto* executor : client_->backend().stream_executors()) {
    bool ok = executor->SynchronizeAllActivity();
    if (!ok) {
      LOG(ERROR) << "Error synchronizing activity while waiting for all "
                    "programs to complete";
    }
  }
  // TODO(b/110813685): Think about the program ownership model. Programs are
  // currently owned by the compilation cache, which means we must wait for
  // program completion in the destructor. There are multiple compilation
  // caches around, which complicates things a little. Perhaps having programs
  // be shared_ptrs (an invasive change) would make the model easier to reason
  // about?
}

string XlaCompilationCache::DebugString() const {
  return "XLA JIT compilation cache";
}

// Compute a string signature which encodes the shapes of the
// arguments in the supplied list.
string XlaCompilationCache::Signature::HumanString() const {
  string result = name;
  for (const auto& a : arg_shapes) {
    absl::StrAppend(&result, ",", DataTypeString(a.first));
    absl::StrAppend(&result, " [", absl::StrJoin(a.second, ","), "]");
  }

  for (const auto& v : arg_values) {
    absl::StrAppend(&result, "; ", v.DebugString());
  }
  return result;
}

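// Signatures compare equal iff the function name, the argument shapes, and
// the constant argument values all match.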
bool XlaCompilationCache::Signature::operator==(const Signature& other) const {
  if (name != other.name) return false;
  if (arg_shapes != other.arg_shapes) return false;

  if (arg_values.size() != other.arg_values.size()) return false;
  for (int i = 0, end = arg_values.size(); i < end; ++i) {
    if (arg_values[i].dtype() != other.arg_values[i].dtype() ||
        arg_values[i].shape() != other.arg_values[i].shape() ||
        arg_values[i].tensor_data() != other.arg_values[i].tensor_data()) {
      return false;
    }
  }
  return true;
}

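// Hashes a signature by combining the function name, each argument's type and
// dimensions, and the raw bytes of each constant argument value.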
uint64 XlaCompilationCache::Signature::Hash::operator()(
    const XlaCompilationCache::Signature& signature) const {
  uint64 h = std::hash<string>()(signature.name);
  for (const auto& arg : signature.arg_shapes) {
    h = Hash64Combine(h, std::hash<int>()(static_cast<int>(arg.first)));
    h = Hash64Combine(h, std::hash<int>()(arg.second.size()));
    for (int dim : arg.second) {
      h = Hash64Combine(h, std::hash<int>()(dim));
    }
  }
  for (const auto& arg : signature.arg_values) {
    h = Hash64Combine(
        h, Hash64(arg.tensor_data().data(), arg.tensor_data().size()));
  }
  return h;
}

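// Builds the cache key for a compilation: the canonicalized function name and
// attributes plus, for each argument, either its constant value or its
// type and shape.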
xla::StatusOr<XlaCompilationCache::Signature>
XlaCompilationCache::BuildSignature(
    const NameAttrList& function,
    absl::Span<const XlaCompiler::Argument> args) {
  Signature signature;
  signature.name = Canonicalize(function.name(), AttrSlice(&function.attr()));

  for (const XlaCompiler::Argument& arg : args) {
    switch (arg.kind) {
      case XlaCompiler::Argument::kConstant:
      case XlaCompiler::Argument::kConstantResource:
        signature.arg_values.push_back(arg.constant_value);
        break;
      case XlaCompiler::Argument::kParameter:
      case XlaCompiler::Argument::kResource:
        signature.arg_shapes.emplace_back(arg.type,
                                          arg.DimensionSizesAsInlinedVector());
        break;
      default:
        return errors::InvalidArgument(
            "Unhandled argument kind in XlaCompilationCache: ",
            arg.HumanString());
    }
  }
  return std::move(signature);
}

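// Compiles the computation produced by the XlaCompiler into an
// xla::LocalExecutable for the configured client and device.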
Status XlaCompilationCache::BuildExecutable(
    const XlaCompiler::Options& options,
    const XlaCompiler::CompilationResult& result,
    std::unique_ptr<xla::LocalExecutable>* executable) {
  VLOG(2) << "Compiling to local executable";

  std::vector<const xla::Shape*> argument_layouts(
      result.xla_input_shapes.size());
  for (int i = 0, end = result.xla_input_shapes.size(); i < end; ++i) {
    argument_layouts[i] = &result.xla_input_shapes[i];
  }
  xla::ExecutableBuildOptions build_options;
  build_options.set_device_ordinal(options.device_ordinal != -1
                                       ? options.device_ordinal
                                       : client_->default_device_ordinal());
  build_options.set_result_layout(result.xla_output_shape);
  build_options.set_device_allocator(options.device_allocator);
  build_options.set_alias_passthrough_params(options.alias_passthrough_params);
  build_options.mutable_debug_options()->set_xla_detailed_logging(
      options.detailed_logging);
  TF_ASSIGN_OR_RETURN(
      auto executables,
      client_->Compile(*result.computation, argument_layouts, build_options));
  TF_RET_CHECK(executables.size() == 1);
  *executable = std::move(executables[0]);
  return Status::OK();
}

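// Compiles (or fetches from the cache) an executable for `function` with the
// given arguments. Under CompileMode::kLazy, compilation is deferred until the
// default request-count threshold is reached.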
Status XlaCompilationCache::Compile(
    const XlaCompiler::Options& options, const NameAttrList& function,
    absl::Span<const XlaCompiler::Argument> args,
    const XlaCompiler::CompileOptions& compile_options,
    CompileMode compile_mode,
    const XlaCompiler::CompilationResult** out_compilation_result,
    xla::LocalExecutable** out_executable) {
  absl::optional<int64> compile_threshold;
  if (compile_mode == CompileMode::kLazy) {
    compile_threshold = kDefaultCompilationThreshold;
  }
  auto compile_fn = [&](XlaCompiler* compiler,
                        XlaCompiler::CompilationResult* result) {
    return compiler->CompileFunction(compile_options, function, args, result);
  };
  return CompileImpl(options, function, args, compile_fn,
                     /*compile_threshold=*/compile_threshold,
                     out_compilation_result, out_executable);
}

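// Returns true if a cluster has been recompiled so many times relative to its
// execution count that further recompilation is unlikely to pay off.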
static bool ShouldBeMegamorphic(int64 compile_count, int64 execution_count) {
  const int64 kCompileThreshold = 10;
  const int64 kMinExecutionsPerCompile = 50;

  // This heuristic is trying to capture the following property: have we sunk a
  // certain minimum amount of compile time into the cluster that didn't quite
  // "pay off"?
  return compile_count > kCompileThreshold &&
         execution_count < kMinExecutionsPerCompile * compile_count;
}

// Creates a simple graph using the specified op as the only op apart from the
// arg and retval nodes.
static xla::StatusOr<std::unique_ptr<Graph>> CreateGraph(
    const NodeDef& node_def, absl::Span<const XlaCompiler::Argument> args,
    absl::Span<const DataType> result_types) {
  // TODO(b/74182462): We implement this by creating a new dummy Graph including
  // _Arg nodes, and let CompileGraph walk it. This could be optimized.
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  Status status;
  // First create the actual node we care about computing.
  Node* main_node = graph->AddNode(node_def, &status);
  TF_RETURN_IF_ERROR(status);

  // Create dummy _Arg nodes. Link these to `node` and also via a control
  // dependency edge to the _SOURCE node.
  for (int64 i = 0, end = args.size(); i < end; ++i) {
    Node* node;
    string arg_name = absl::StrCat("_arg", i);
    Status status =
        NodeBuilder(arg_name, FunctionLibraryDefinition::kArgOp)
            .ControlInput(graph->source_node())
            .Attr("T", args[i].kind == XlaCompiler::Argument::kResource
                           ? DT_RESOURCE
                           : args[i].type)
            .Attr("index", i)
            .Finalize(graph.get(), &node);
    TF_RETURN_IF_ERROR(status);
    graph->AddEdge(node, 0, main_node, i);
  }

  // Similarly with return values, create dummy _Retval nodes fed by `node`.
  for (int64 i = 0, end = result_types.size(); i < end; ++i) {
    Node* node;
    string retval_name = absl::StrCat("_retval", i);
    Status status = NodeBuilder(retval_name, FunctionLibraryDefinition::kRetOp)
                        .Input(main_node, i)
                        .Attr("T", result_types[i])
                        .Attr("index", i)
                        .Finalize(graph.get(), &node);
    TF_RETURN_IF_ERROR(status);
  }
  FixupSourceAndSinkEdges(graph.get());
  return graph;
}

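// Compiles a single TensorFlow op (the kernel currently executing in `ctx`) by
// wrapping it in a dummy graph of _Arg and _Retval nodes and compiling that
// graph, either via the MLIR bridge or the regular XlaCompiler path.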
Status XlaCompilationCache::CompileSingleOp(
    const XlaCompiler::Options& options,
    absl::Span<const XlaCompiler::Argument> args, OpKernelContext* ctx,
    const XlaCompiler::CompileOptions& compile_options,
    const XlaCompiler::CompilationResult** out_compilation_result,
    xla::LocalExecutable** out_executable) {
  const NodeDef& def = ctx->op_kernel().def();
  NameAttrList name;
  name.set_name(def.op());
  *name.mutable_attr() = def.attr();
  // Remove the "_class" attribute from the attribute set used to create the
  // compilation cache key. This attribute is information for the colocator
  // and causes false uniqueness between nodes.
  name.mutable_attr()->erase("_class");
  auto compile_op = [&](XlaCompiler* compiler,
                        XlaCompiler::CompilationResult* result) {
    std::vector<DataType> result_dtypes(ctx->num_outputs());
    for (int i = 0, end = result_dtypes.size(); i < end; ++i) {
      result_dtypes[i] = ctx->expected_output_dtype(i);
    }

    const NodeDef& node_def = ctx->op_kernel().def();
    TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(node_def, args, result_dtypes));

    const ConfigProto* config = ctx->function_library()->config_proto();
    // TODO(b/171039585): Support tf.VarIsInitializedOp using MLIR.
    bool use_mlir = config &&
                    GetMlirBridgeRolloutPolicy(
                        *graph, *config, /*uses_uninitialized_resource_args=*/
                        AnyUninitializedResourceArg(args)) ==
                        MlirBridgeRolloutPolicy::kEnabledByUser &&
                    node_def.op() != "VarIsInitializedOp";
    if (!use_mlir) {
      return compiler->CompileGraph(compile_options, node_def.name(),
                                    std::move(graph), args, result);
    }

    VLOG(1) << "Using MLIR bridge";
    GraphDebugInfo debug_info;
    std::vector<std::string> control_rets;
    if (result_dtypes.empty()) {
      control_rets.push_back(node_def.name());
    }
    return CompileGraphToXlaHlo(
        *graph, mlir::SpanToArrayRef<XlaCompiler::Argument>(args), control_rets,
        options.device_type.type_string(), compile_options.use_tuple_arg,
        *options.flib_def, debug_info, options.shape_representation_fn, result);
  };
  return CompileImpl(options, name, args, compile_op,
                     /*compile_threshold=*/absl::nullopt,
                     out_compilation_result, out_executable);
}

namespace {
// Print something that users can search for to definitively ascertain that XLA
// was used for their TF model.
//
// Prints only once to avoid spamming LOG(INFO).
void LogOnceXlaCompiledFirstCluster() {
  static absl::once_flag log_once;
  absl::call_once(log_once, [] {
    LOG(INFO) << "Compiled cluster using XLA! This line is logged at most "
                 "once for the lifetime of the process.";
  });
}
}  // namespace

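// Shared implementation behind Compile and CompileSingleOp: looks up (or
// creates) the cache entry for the signature, decides whether to compile now
// based on the lazy-compilation threshold and megamorphic tracking, and if so
// runs `compile_fn` and builds the executable.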
Status XlaCompilationCache::CompileImpl(
    const XlaCompiler::Options& options, const NameAttrList& function,
    absl::Span<const XlaCompiler::Argument> args,
    const std::function<Status(XlaCompiler* compiler,
                               XlaCompiler::CompilationResult*)>& compile_fn,
    absl::optional<int64> compile_threshold,
    const XlaCompiler::CompilationResult** out_compilation_result,
    xla::LocalExecutable** out_executable) {
  if (FailOnXlaCompilation()) {
    return errors::Internal("XLA compilation disabled");
  }

  DCHECK_NE(out_executable, nullptr);
  VLOG(2) << "XlaCompilationCache::Compile " << DebugString();

  if (VLOG_IS_ON(2)) {
    VLOG(2) << "num_inputs=" << args.size();
    for (int i = 0, end = args.size(); i < end; i++) {
      VLOG(3) << i << ": " << args[i].HumanString();
    }
  }

  TF_ASSIGN_OR_RETURN(Signature signature, BuildSignature(function, args));
  VLOG(2) << "Signature: " << signature.HumanString();

  // The outer lock protects the existence of the cache entry. It does not
  // protect the contents of the cache entry.
  Entry* entry;
  {
    mutex_lock lock(compile_cache_mu_);
    // Find or create a cache entry.
    std::unique_ptr<Entry>& e = cache_[signature];
    if (!e) {
      e.reset(new Entry);
    }
    entry = e.get();
  }

  // We always compile a cluster the very first time it is executed. This is an
  // optimistic guess that pays off for statically shaped TensorFlow graphs
  // (since they get the benefit of XLA right away without waiting for warmup)
  // and doesn't hurt much for dynamically shaped TensorFlow graphs (we "pay" at
  // most one cluster-compilation's worth of compile time).
  bool is_first_execution;

  // We avoid compiling clusters that have "gone megamorphic" i.e. have an
  // excessive amount of shape dynamism.
  bool is_megamorphic;

  {
    mutex_lock lock(cluster_compile_stats_mu_);
    auto it =
        cluster_compile_stats_.emplace(function.name(), ClusterCompileStats{})
            .first;
    is_first_execution = it->second.execution_count++ == 0;

    // The is_megamorphic bit is "sticky". We assume clusters that have been
    // observed to be megamorphic once stay megamorphic forever.
    if (!it->second.is_megamorphic &&
        ShouldBeMegamorphic(/*compile_count=*/it->second.compile_count,
                            /*execution_count=*/it->second.execution_count)) {
      VLOG(1) << "Marking " << function.name()
              << " as megamorphic, compile_count=" << it->second.compile_count
              << " execution_count=" << it->second.execution_count;
      it->second.is_megamorphic = true;
    }

    is_megamorphic = it->second.is_megamorphic;
  }

  // Acquire the cache entry lock and compile, if necessary.
  // TODO(phawkins): this locking will need to be restructured when we implement
  // cache eviction.
  mutex_lock entry_lock(entry->mu);
  int64 current_request_count = ++entry->request_count;
  VLOG(2) << "Compilation cache entry hit: " << entry->compiled
          << " signature: " << signature.HumanString() << " with request count "
          << current_request_count << " and compile threshold "
          << compile_threshold.value_or(0);
  if (!entry->compiled) {
    XLA_SCOPED_LOGGING_TIMER("Compilation of XLA executable");
    const bool should_compile = [&] {
      if (!compile_threshold.has_value()) {
        // Lazy compilation is disabled.
        return true;
      }

      if (is_megamorphic) {
        BroadcastOptimizationRemark(XlaOptimizationRemark::MEGAMORPHIC_FUNCTION,
                                    function.name())
            .IgnoreError();
        VLOG(3) << "Not compiling cluster " << function.name()
                << " because it is megamorphic.";
        return false;
      }

      if (is_first_execution) {
        return true;
      }

      bool reached_compile_threshold =
          current_request_count >= *compile_threshold;
      if (!reached_compile_threshold) {
        VLOG(3)
            << "Not compiling cluster " << function.name()
            << " because it has not reached compile threshold; threshold is "
            << *compile_threshold << " execution count "
            << current_request_count << ".";
      }
      return reached_compile_threshold;
    }();

    if (!should_compile) {
      VLOG(2) << "Not compiling for signature: " << signature.HumanString();
      *out_compilation_result = nullptr;
      *out_executable = nullptr;
      return Status::OK();
    }

    tensorflow::Env* env = tensorflow::Env::Default();
    const uint64 compile_start_us = env->NowMicros();
    // Do the actual JIT compilation without holding the lock (it can take
    // a long time.)

    XlaCompiler compiler(options);
    entry->compiled = true;

    entry->compilation_status =
        compile_fn(&compiler, &entry->compilation_result);
    TF_RETURN_IF_ERROR(entry->compilation_status);
    CHECK_EQ(entry->executable.get(), nullptr);
    entry->compilation_status =
        BuildExecutable(options, entry->compilation_result, &entry->executable);

    const uint64 compile_end_us = env->NowMicros();
    const uint64 compile_time_us = compile_end_us - compile_start_us;
    metrics::UpdateXlaCompilationTime(compile_time_us);
    {
      mutex_lock lock(cluster_compile_stats_mu_);
      auto it = cluster_compile_stats_.find(function.name());
      it->second.compile_count++;
      it->second.cumulative_compile_time_us += compile_time_us;
      LogOnceXlaCompiledFirstCluster();
      VLOG(1) << "compiled " << function.name() << " "
              << it->second.compile_count
              << " times, compile time: " << compile_time_us
              << " us, cumulative: " << it->second.cumulative_compile_time_us
              << " us ("
              << tensorflow::strings::HumanReadableElapsedTime(compile_time_us /
                                                               1.0e6)
              << " / "
              << tensorflow::strings::HumanReadableElapsedTime(
                     it->second.cumulative_compile_time_us / 1.0e6)
              << ")";

      XlaJitCompilationActivity jit_compilation_activity;
      jit_compilation_activity.set_cluster_name(function.name());
      jit_compilation_activity.set_compile_count(it->second.compile_count);
      jit_compilation_activity.set_compile_time_us(compile_time_us);
      jit_compilation_activity.set_cumulative_compile_time_us(
          it->second.cumulative_compile_time_us);

      TF_RETURN_IF_ERROR(
          BroadcastXlaActivity(std::move(jit_compilation_activity)));
    }
  }
  TF_RETURN_IF_ERROR(entry->compilation_status);
  *out_compilation_result = &entry->compilation_result;
  *out_executable = entry->executable.get();
  return Status::OK();
}

}  // namespace tensorflow