/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_compilation_cache.h"

#include <numeric>

#include "tensorflow/compiler/mlir/mlir_bridge_rollout_policy.h"
#include "absl/base/call_once.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_activity.pb.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
#include "tensorflow/compiler/mlir/utils/array_container_utils.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/common_runtime/metrics.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {

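// Out-of-line definition of the in-class constexpr member; required for
// ODR-use before C++17.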
constexpr int64 XlaCompilationCache::kDefaultCompilationThreshold;

XlaCompilationCache::XlaCompilationCache(xla::LocalClient* client,
                                         DeviceType device_type)
    : client_(client), device_type_(std::move(device_type)) {}

XlaCompilationCache::~XlaCompilationCache() {
  // Ensure that any in-flight use of our programs has completed by waiting
  // for all stream executors to finish their pending work.
  for (auto* executor : client_->backend().stream_executors()) {
    bool ok = executor->SynchronizeAllActivity();
    if (!ok) {
      LOG(ERROR) << "Error synchronizing activity while waiting for all "
                    "programs to complete";
    }
  }
  // TODO(b/110813685): Think about the program ownership model. Programs are
  // currently owned by the compilation cache which means we must wait for
  // program completion in the destructor. There are multiple compilation caches
  // around, which complicates things a little. Perhaps having programs be
  // shared_ptrs (an invasive change) would make the model easier to reason
  // about?
}

string XlaCompilationCache::DebugString() const {
  return "XLA JIT compilation cache";
}

// Compute a string signature which encodes the shapes of the
// arguments in the supplied list.
string XlaCompilationCache::Signature::HumanString() const {
  string result = name;
  for (const auto& a : arg_shapes) {
    absl::StrAppend(&result, ",", DataTypeString(a.first));
    absl::StrAppend(&result, " [", absl::StrJoin(a.second, ","), "]");
  }

  for (const auto& v : arg_values) {
    absl::StrAppend(&result, "; ", v.DebugString());
  }
  return result;
}

bool XlaCompilationCache::Signature::operator==(const Signature& other) const {
  if (name != other.name) return false;
  if (arg_shapes != other.arg_shapes) return false;

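  // Constant arguments must match exactly: same dtype, same shape, and
  // bitwise-identical tensor contents.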
  if (arg_values.size() != other.arg_values.size()) return false;
  for (int i = 0, end = arg_values.size(); i < end; ++i) {
    if (arg_values[i].dtype() != other.arg_values[i].dtype() ||
        arg_values[i].shape() != other.arg_values[i].shape() ||
        arg_values[i].tensor_data() != other.arg_values[i].tensor_data()) {
      return false;
    }
  }
  return true;
}

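// The hash must be consistent with operator== above: it combines the
// canonicalized function name, each argument's dtype and dimension sizes, and
// the raw bytes of every constant argument value.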
uint64 XlaCompilationCache::Signature::Hash::operator()(
    const XlaCompilationCache::Signature& signature) const {
  uint64 h = std::hash<string>()(signature.name);
  for (const auto& arg : signature.arg_shapes) {
    h = Hash64Combine(h, std::hash<int>()(static_cast<int>(arg.first)));
    h = Hash64Combine(h, std::hash<int>()(arg.second.size()));
    for (int dim : arg.second) {
      h = Hash64Combine(h, std::hash<int>()(dim));
    }
  }
  for (const auto& arg : signature.arg_values) {
    h = Hash64Combine(
        h, Hash64(arg.tensor_data().data(), arg.tensor_data().size()));
  }
  return h;
}

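// Builds the cache key for a cluster: constant arguments contribute their
// concrete values, while parameters and resources contribute only their dtype
// and dimension sizes, so calls that differ only in parameter values share a
// cache entry.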
xla::StatusOr<XlaCompilationCache::Signature>
XlaCompilationCache::BuildSignature(
    const NameAttrList& function,
    absl::Span<const XlaCompiler::Argument> args) {
  Signature signature;
  signature.name = Canonicalize(function.name(), AttrSlice(&function.attr()));

  for (const XlaCompiler::Argument& arg : args) {
    switch (arg.kind) {
      case XlaCompiler::Argument::kConstant:
      case XlaCompiler::Argument::kConstantResource:
        signature.arg_values.push_back(arg.constant_value);
        break;
      case XlaCompiler::Argument::kParameter:
      case XlaCompiler::Argument::kResource:
        signature.arg_shapes.emplace_back(arg.type,
                                          arg.DimensionSizesAsInlinedVector());
        break;
      default:
        return errors::InvalidArgument(
            "Unhandled argument kind in XlaCompilationCache: ",
            arg.HumanString());
    }
  }
  return std::move(signature);
}

Status XlaCompilationCache::BuildExecutable(
    const XlaCompiler::Options& options,
    const XlaCompiler::CompilationResult& result,
    std::unique_ptr<xla::LocalExecutable>* executable) {
  VLOG(2) << "Compiling to local executable";

  std::vector<const xla::Shape*> argument_layouts(
      result.xla_input_shapes.size());
  for (int i = 0, end = result.xla_input_shapes.size(); i < end; ++i) {
    argument_layouts[i] = &result.xla_input_shapes[i];
  }
  xla::ExecutableBuildOptions build_options;
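  // Fall back to the client's default device when the caller did not specify
  // a device ordinal.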
  build_options.set_device_ordinal(options.device_ordinal != -1
                                       ? options.device_ordinal
                                       : client_->default_device_ordinal());
  build_options.set_result_layout(result.xla_output_shape);
  build_options.set_device_allocator(options.device_allocator);
  build_options.set_alias_passthrough_params(options.alias_passthrough_params);
  build_options.mutable_debug_options()->set_xla_detailed_logging(
      options.detailed_logging);
  TF_ASSIGN_OR_RETURN(
      auto executables,
      client_->Compile(*result.computation, argument_layouts, build_options));
  TF_RET_CHECK(executables.size() == 1);
  *executable = std::move(executables[0]);
  return Status::OK();
}

Status XlaCompilationCache::Compile(
    const XlaCompiler::Options& options, const NameAttrList& function,
    absl::Span<const XlaCompiler::Argument> args,
    const XlaCompiler::CompileOptions& compile_options,
    CompileMode compile_mode,
    const XlaCompiler::CompilationResult** out_compilation_result,
    xla::LocalExecutable** out_executable) {
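  // In lazy mode, compilation is deferred until the cluster has been requested
  // kDefaultCompilationThreshold times (see CompileImpl); otherwise the
  // cluster is compiled on first use.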
  absl::optional<int64> compile_threshold;
  if (compile_mode == CompileMode::kLazy) {
    compile_threshold = kDefaultCompilationThreshold;
  }
  auto compile_fn = [&](XlaCompiler* compiler,
                        XlaCompiler::CompilationResult* result) {
    return compiler->CompileFunction(compile_options, function, args, result);
  };
  return CompileImpl(options, function, args, compile_fn,
                     /*compile_threshold=*/compile_threshold,
                     out_compilation_result, out_executable);
}

static bool ShouldBeMegamorphic(int64 compile_count, int64 execution_count) {
  const int64 kCompileThreshold = 10;
  const int64 kMinExecutionsPerCompile = 50;

  // This heuristic is trying to capture the following property: have we sunk a
  // certain minimum amount of compile time into the cluster that didn't quite
  // "pay off"?
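  // For example, a cluster compiled 12 times but executed only 300 times
  // (fewer than 50 * 12 = 600) is considered megamorphic.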
  return compile_count > kCompileThreshold &&
         execution_count < kMinExecutionsPerCompile * compile_count;
}

// Creates a simple graph using the specified op as the only op apart from the
// arg and retval nodes.
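// For example, a single binary op with two arguments and one result becomes:
//   _arg0 --> op --> _retval0
//   _arg1 --/
// with control edges from the _SOURCE node to the _Arg nodes.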
static xla::StatusOr<std::unique_ptr<Graph>> CreateGraph(
    const NodeDef& node_def, absl::Span<const XlaCompiler::Argument> args,
    absl::Span<const DataType> result_types) {
  // TODO(b/74182462): We implement this by creating a new dummy Graph including
  // _Arg nodes, and let CompileGraph walk it. This could be optimized.
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  Status status;
  // First create the actual node we care about computing.
  Node* main_node = graph->AddNode(node_def, &status);
  TF_RETURN_IF_ERROR(status);

  // Create dummy _Arg nodes. Link these to `node` and also via a control
  // dependency edge to the _SOURCE node.
  for (int64 i = 0, end = args.size(); i < end; ++i) {
    Node* node;
    string arg_name = absl::StrCat("_arg", i);
    Status status =
        NodeBuilder(arg_name, FunctionLibraryDefinition::kArgOp)
            .ControlInput(graph->source_node())
            .Attr("T", args[i].kind == XlaCompiler::Argument::kResource
                           ? DT_RESOURCE
                           : args[i].type)
            .Attr("index", i)
            .Finalize(graph.get(), &node);
    TF_RETURN_IF_ERROR(status);
    graph->AddEdge(node, 0, main_node, i);
  }

  // Similarly with return values, create dummy _Retval nodes fed by `node`.
  for (int64 i = 0, end = result_types.size(); i < end; ++i) {
    Node* node;
    string retval_name = absl::StrCat("_retval", i);
    Status status = NodeBuilder(retval_name, FunctionLibraryDefinition::kRetOp)
                        .Input(main_node, i)
                        .Attr("T", result_types[i])
                        .Attr("index", i)
                        .Finalize(graph.get(), &node);
    TF_RETURN_IF_ERROR(status);
  }
  FixupSourceAndSinkEdges(graph.get());
  return graph;
}

Status XlaCompilationCache::CompileSingleOp(
    const XlaCompiler::Options& options,
    absl::Span<const XlaCompiler::Argument> args, OpKernelContext* ctx,
    const XlaCompiler::CompileOptions& compile_options,
    const XlaCompiler::CompilationResult** out_compilation_result,
    xla::LocalExecutable** out_executable) {
  const NodeDef& def = ctx->op_kernel().def();
  NameAttrList name;
  name.set_name(def.op());
  *name.mutable_attr() = def.attr();
  // Remove the "_class" attribute from the attribute set used to create the
  // compilation cache key. This attribute is information for the colocator
  // and causes false uniqueness between nodes.
  name.mutable_attr()->erase("_class");
  auto compile_op = [&](XlaCompiler* compiler,
                        XlaCompiler::CompilationResult* result) {
    std::vector<DataType> result_dtypes(ctx->num_outputs());
    for (int i = 0, end = result_dtypes.size(); i < end; ++i) {
      result_dtypes[i] = ctx->expected_output_dtype(i);
    }

    const NodeDef& node_def = ctx->op_kernel().def();
    TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(node_def, args, result_dtypes));

    const ConfigProto* config = ctx->function_library()->config_proto();
    // TODO(b/171039585): Support tf.VarIsInitializedOp using MLIR.
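    // Use the MLIR-based tf2xla bridge only when the user has explicitly
    // enabled it and the op is not VarIsInitializedOp (see the TODO above).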
    bool use_mlir = config &&
                    GetMlirBridgeRolloutPolicy(
                        *graph, *config, /*uses_uninitialized_resource_args=*/
                        AnyUninitializedResourceArg(args)) ==
                        MlirBridgeRolloutPolicy::kEnabledByUser &&
                    node_def.op() != "VarIsInitializedOp";
    if (!use_mlir) {
      return compiler->CompileGraph(compile_options, node_def.name(),
                                    std::move(graph), args, result);
    }

    VLOG(1) << "Using MLIR bridge";
    GraphDebugInfo debug_info;
    std::vector<std::string> control_rets;
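    // If the op has no data outputs, pass it as a control return so the
    // lowered function still executes it for its side effects.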
    if (result_dtypes.empty()) {
      control_rets.push_back(node_def.name());
    }
    return CompileGraphToXlaHlo(
        *graph, mlir::SpanToArrayRef<XlaCompiler::Argument>(args), control_rets,
        options.device_type.type_string(), compile_options.use_tuple_arg,
        *options.flib_def, debug_info, options.shape_representation_fn, result);
  };
  return CompileImpl(options, name, args, compile_op,
                     /*compile_threshold=*/absl::nullopt,
                     out_compilation_result, out_executable);
}

namespace {
// Print something that users can search for to definitively ascertain that XLA
// was used for their TF model.
//
// Prints only once to avoid spamming LOG(INFO).
void LogOnceXlaCompiledFirstCluster() {
  static absl::once_flag log_once;
  absl::call_once(log_once, [] {
    LOG(INFO) << "Compiled cluster using XLA!  This line is logged at most "
                 "once for the lifetime of the process.";
  });
}
}  // namespace

Status XlaCompilationCache::CompileImpl(
    const XlaCompiler::Options& options, const NameAttrList& function,
    absl::Span<const XlaCompiler::Argument> args,
    const std::function<Status(XlaCompiler* compiler,
                               XlaCompiler::CompilationResult*)>& compile_fn,
    absl::optional<int64> compile_threshold,
    const XlaCompiler::CompilationResult** out_compilation_result,
    xla::LocalExecutable** out_executable) {
  if (FailOnXlaCompilation()) {
    return errors::Internal("XLA compilation disabled");
  }

  DCHECK_NE(out_executable, nullptr);
  VLOG(2) << "XlaCompilationCache::Compile " << DebugString();

  if (VLOG_IS_ON(2)) {
    VLOG(2) << "num_inputs=" << args.size();
    for (int i = 0, end = args.size(); i < end; i++) {
      VLOG(3) << i << ": " << args[i].HumanString();
    }
  }

  TF_ASSIGN_OR_RETURN(Signature signature, BuildSignature(function, args));
  VLOG(2) << "Signature: " << signature.HumanString();

  // The outer lock protects the existence of the cache entry. It does not
  // protect the contents of the cache entry.
  Entry* entry;
  {
    mutex_lock lock(compile_cache_mu_);
    // Find or create a cache entry.
    std::unique_ptr<Entry>& e = cache_[signature];
    if (!e) {
      e.reset(new Entry);
    }
    entry = e.get();
  }

  // We always compile a cluster the very first time it is executed.  This is an
  // optimistic guess that pays off for statically shaped TensorFlow graphs
  // (since they get the benefit of XLA right away without waiting for warmup)
  // and doesn't hurt much for dynamically shaped TensorFlow graphs (we "pay" at
  // most one cluster-compilation's worth of compile time).
  bool is_first_execution;

  // We avoid compiling clusters that have "gone megamorphic" i.e. have an
  // excessive amount of shape dynamism.
  bool is_megamorphic;

  {
    mutex_lock lock(cluster_compile_stats_mu_);
    auto it =
        cluster_compile_stats_.emplace(function.name(), ClusterCompileStats{})
            .first;
    is_first_execution = it->second.execution_count++ == 0;

    // The is_megamorphic bit is "sticky".  We assume clusters that have been
    // observed to be megamorphic once stay megamorphic forever.
    if (!it->second.is_megamorphic &&
        ShouldBeMegamorphic(/*compile_count=*/it->second.compile_count,
                            /*execution_count=*/it->second.execution_count)) {
      VLOG(1) << "Marking " << function.name()
              << " as megamorphic, compile_count=" << it->second.compile_count
              << " execution_count=" << it->second.execution_count;
      it->second.is_megamorphic = true;
    }

    is_megamorphic = it->second.is_megamorphic;
  }

  // Acquire the cache entry lock and compile, if necessary.
  // TODO(phawkins): this locking will need to be restructured when we implement
  // cache eviction.
  mutex_lock entry_lock(entry->mu);
  int64 current_request_count = ++entry->request_count;
  VLOG(2) << "Compilation cache entry hit: " << entry->compiled
          << " signature: " << signature.HumanString() << " with request count "
          << current_request_count << " and compile threshold "
          << compile_threshold.value_or(0);
  if (!entry->compiled) {
    XLA_SCOPED_LOGGING_TIMER("Compilation of XLA executable");
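    // Decide whether to compile now. When lazy compilation is enabled, skip
    // megamorphic clusters, and otherwise compile on the first execution or
    // once the request count reaches the threshold; with lazy compilation
    // disabled we always compile.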
    const bool should_compile = [&] {
      if (!compile_threshold.has_value()) {
        // Lazy compilation is disabled.
        return true;
      }

      if (is_megamorphic) {
        BroadcastOptimizationRemark(XlaOptimizationRemark::MEGAMORPHIC_FUNCTION,
                                    function.name())
            .IgnoreError();
        VLOG(3) << "Not compiling cluster " << function.name()
                << " because it is megamorphic.";
        return false;
      }

      if (is_first_execution) {
        return true;
      }

      bool reached_compile_threshold =
          current_request_count >= *compile_threshold;
      if (!reached_compile_threshold) {
        VLOG(3)
            << "Not compiling cluster " << function.name()
            << " because it has not reached compile threshold; threshold is "
            << *compile_threshold << " execution count "
            << current_request_count << ".";
      }
      return reached_compile_threshold;
    }();

    if (!should_compile) {
      VLOG(2) << "Not compiling for signature: " << signature.HumanString();
      *out_compilation_result = nullptr;
      *out_executable = nullptr;
      return Status::OK();
    }

    tensorflow::Env* env = tensorflow::Env::Default();
    const uint64 compile_start_us = env->NowMicros();
    // Do the actual JIT compilation without holding the cache-wide lock (only
    // the per-entry lock), since it can take a long time.

    XlaCompiler compiler(options);
    entry->compiled = true;

    entry->compilation_status =
        compile_fn(&compiler, &entry->compilation_result);
    TF_RETURN_IF_ERROR(entry->compilation_status);
    CHECK_EQ(entry->executable.get(), nullptr);
    entry->compilation_status =
        BuildExecutable(options, entry->compilation_result, &entry->executable);

    const uint64 compile_end_us = env->NowMicros();
    const uint64 compile_time_us = compile_end_us - compile_start_us;
    metrics::UpdateXlaCompilationTime(compile_time_us);
    {
      mutex_lock lock(cluster_compile_stats_mu_);
      auto it = cluster_compile_stats_.find(function.name());
      it->second.compile_count++;
      it->second.cumulative_compile_time_us += compile_time_us;
      LogOnceXlaCompiledFirstCluster();
      VLOG(1) << "compiled " << function.name() << " "
              << it->second.compile_count
              << " times, compile time: " << compile_time_us
              << " us, cumulative: " << it->second.cumulative_compile_time_us
              << " us ("
              << tensorflow::strings::HumanReadableElapsedTime(compile_time_us /
                                                               1.0e6)
              << " / "
              << tensorflow::strings::HumanReadableElapsedTime(
                     it->second.cumulative_compile_time_us / 1.0e6)
              << ")";

      XlaJitCompilationActivity jit_compilation_activity;
      jit_compilation_activity.set_cluster_name(function.name());
      jit_compilation_activity.set_compile_count(it->second.compile_count);
      jit_compilation_activity.set_compile_time_us(compile_time_us);
      jit_compilation_activity.set_cumulative_compile_time_us(
          it->second.cumulative_compile_time_us);

      TF_RETURN_IF_ERROR(
          BroadcastXlaActivity(std::move(jit_compilation_activity)));
    }
  }
  TF_RETURN_IF_ERROR(entry->compilation_status);
  *out_compilation_result = &entry->compilation_result;
  *out_executable = entry->executable.get();
  return Status::OK();
}

}  // namespace tensorflow