• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/xla_compilation_cache.h"
17 
18 #include <numeric>
19 
20 #include "absl/strings/str_cat.h"
21 #include "absl/strings/str_join.h"
22 #include "tensorflow/compiler/tf2xla/shape_util.h"
23 #include "tensorflow/compiler/tf2xla/type_util.h"
24 #include "tensorflow/compiler/tf2xla/xla_context.h"
25 #include "tensorflow/compiler/xla/client/client_library.h"
26 #include "tensorflow/core/common_runtime/device.h"
27 #include "tensorflow/core/common_runtime/function.h"
28 #include "tensorflow/core/common_runtime/graph_optimizer.h"
29 #include "tensorflow/core/framework/attr_value_util.h"
30 #include "tensorflow/core/framework/types.h"
31 #include "tensorflow/core/graph/graph_constructor.h"
32 #include "tensorflow/core/graph/node_builder.h"
33 #include "tensorflow/core/lib/hash/hash.h"
34 #include "tensorflow/core/platform/env.h"
35 #include "tensorflow/core/platform/logging.h"
36 #include "tensorflow/core/public/version.h"
37 #include "tensorflow/core/util/dump_graph.h"
38 
39 namespace tensorflow {
40 
// Namespace-scope definition of the static constexpr class member. Before
// C++17 a definition like this is required whenever the constant is odr-used
// (for example, bound to a reference).
41 constexpr int64 XlaCompilationCache::kDefaultCompilationThreshold;
42 
XlaCompilationCache(xla::LocalClient * client,DeviceType device_type)43 XlaCompilationCache::XlaCompilationCache(xla::LocalClient* client,
44                                          DeviceType device_type)
45     : client_(client), device_type_(std::move(device_type)) {}
46 
~XlaCompilationCache()47 XlaCompilationCache::~XlaCompilationCache() {
48   // Ensure any use of our programs have completed by waiting for all stream
49   // executors to complete.
50   for (auto* executor : client_->backend().stream_executors()) {
51     bool ok = executor->SynchronizeAllActivity();
52     if (!ok) {
53       LOG(ERROR) << "Error synchronizing activity while waiting for all "
54                     "programs to complete";
55     }
56   }
57   // TODO(b/110813685): Think about the program ownership model. Programs are
58   // currently owned by the compilation cache which means we must wait for
59   // program completion in the destructor. There are multiple compilation caches
60   // around, which complicates things a little. Perhaps having programs be
61   // shared_ptrs (an invasive change) would make the model easier to reason
62   // about?
63 }
64 
DebugString() const65 string XlaCompilationCache::DebugString() const {
66   return "XLA JIT compilation cache";
67 }
68 
69 // Compute a string signature which encodes the shapes of the
70 // arguments in the supplied list.
HumanString() const71 string XlaCompilationCache::Signature::HumanString() const {
72   string result = name;
73   for (const auto& a : arg_shapes) {
74     absl::StrAppend(&result, ",", DataTypeString(a.first));
75     absl::StrAppend(&result, " [", absl::StrJoin(a.second, ","), "]");
76   }
77 
78   for (const auto& v : arg_values) {
79     absl::StrAppend(&result, "; ", v.DebugString());
80   }
81   return result;
82 }
83 
operator ==(const Signature & other) const84 bool XlaCompilationCache::Signature::operator==(const Signature& other) const {
85   if (name != other.name) return false;
86   if (arg_shapes != other.arg_shapes) return false;
87 
88   if (arg_values.size() != other.arg_values.size()) return false;
89   for (int i = 0; i < arg_values.size(); ++i) {
90     if (arg_values[i].dtype() != other.arg_values[i].dtype() ||
91         arg_values[i].shape() != other.arg_values[i].shape() ||
92         arg_values[i].tensor_data() != other.arg_values[i].tensor_data()) {
93       return false;
94     }
95   }
96   return true;
97 }
98 
operator ()(const XlaCompilationCache::Signature & signature) const99 uint64 XlaCompilationCache::Signature::Hash::operator()(
100     const XlaCompilationCache::Signature& signature) const {
101   uint64 h = std::hash<string>()(signature.name);
102   for (const auto& arg : signature.arg_shapes) {
103     h = Hash64Combine(h, std::hash<int>()(static_cast<int>(arg.first)));
104     h = Hash64Combine(h, std::hash<int>()(arg.second.size()));
105     for (int dim : arg.second) {
106       h = Hash64Combine(h, std::hash<int>()(dim));
107     }
108   }
109   for (const auto& arg : signature.arg_values) {
110     h = Hash64Combine(
111         h, Hash64(arg.tensor_data().data(), arg.tensor_data().size()));
112   }
113   return h;
114 }
115 
116 xla::StatusOr<XlaCompilationCache::Signature>
BuildSignature(const NameAttrList & function,absl::Span<const XlaCompiler::Argument> args)117 XlaCompilationCache::BuildSignature(
118     const NameAttrList& function,
119     absl::Span<const XlaCompiler::Argument> args) {
120   Signature signature;
121   signature.name = Canonicalize(function.name(), AttrSlice(&function.attr()));
122   for (const XlaCompiler::Argument& arg : args) {
123     switch (arg.kind) {
124       case XlaCompiler::Argument::kConstant:
125         signature.arg_values.push_back(arg.constant_value);
126         break;
127       case XlaCompiler::Argument::kParameter:
128       case XlaCompiler::Argument::kResource:
129         signature.arg_shapes.emplace_back(arg.type, arg.DimensionSizes());
130         break;
131       default:
132         return errors::InvalidArgument(
133             "Unhandled argument kind in XlaCompilationCache: ",
134             arg.HumanString());
135     }
136   }
137   return std::move(signature);
138 }
139 
BuildExecutable(const XlaCompiler::Options & options,const XlaCompiler::CompilationResult & result,std::unique_ptr<xla::LocalExecutable> * executable)140 Status XlaCompilationCache::BuildExecutable(
141     const XlaCompiler::Options& options,
142     const XlaCompiler::CompilationResult& result,
143     std::unique_ptr<xla::LocalExecutable>* executable) {
144   VLOG(2) << "Compiling to local executable";
145 
146   std::vector<const xla::Shape*> argument_layouts(
147       result.xla_input_shapes.size());
148   for (int i = 0; i < result.xla_input_shapes.size(); ++i) {
149     argument_layouts[i] = &result.xla_input_shapes[i];
150   }
151   xla::ExecutableBuildOptions build_options;
152   build_options.set_device_ordinal(options.device_ordinal != -1
153                                        ? options.device_ordinal
154                                        : client_->default_device_ordinal());
155   build_options.set_result_layout(result.xla_output_shape);
156   build_options.set_device_allocator(options.device_allocator);
157 
158   auto compile_result =
159       client_->Compile(*result.computation, argument_layouts, build_options);
160   if (!compile_result.ok()) {
161     return compile_result.status();
162   }
163   *executable = std::move(compile_result.ValueOrDie());
164   return Status::OK();
165 }
166 
Compile(const XlaCompiler::Options & options,const NameAttrList & function,absl::Span<const XlaCompiler::Argument> args,const XlaCompiler::CompileOptions & compile_options,CompileMode compile_mode,const XlaCompiler::CompilationResult ** out_compilation_result,xla::LocalExecutable ** out_executable)167 Status XlaCompilationCache::Compile(
168     const XlaCompiler::Options& options, const NameAttrList& function,
169     absl::Span<const XlaCompiler::Argument> args,
170     const XlaCompiler::CompileOptions& compile_options,
171     CompileMode compile_mode,
172     const XlaCompiler::CompilationResult** out_compilation_result,
173     xla::LocalExecutable** out_executable) {
174   absl::optional<int64> compile_threshold;
175   if (compile_mode == CompileMode::kLazy) {
176     compile_threshold = kDefaultCompilationThreshold;
177   }
178   auto compile_fn = [&](XlaCompiler* compiler,
179                         XlaCompiler::CompilationResult* result) {
180     return compiler->CompileFunction(compile_options, function, args, result);
181   };
182   return CompileImpl(options, function, args, compile_fn,
183                      /*compile_threshold=*/compile_threshold,
184                      out_compilation_result, out_executable);
185 }
186 
IsMegamorphic(int64 compile_count,int64 execution_count)187 static bool IsMegamorphic(int64 compile_count, int64 execution_count) {
188   const int64 kCompileThreshold = 10;
189   const int64 kMinExecutionsPerCompile = 50;
190 
191   // This heuristic is trying to capture the following property: have we sunk a
192   // certain minimum amount of compile time into the cluster that didn't quite
193   // "pay off"?
194   return compile_count > kCompileThreshold &&
195          execution_count < kMinExecutionsPerCompile * compile_count;
196 }
197 
CompileSingleOp(const XlaCompiler::Options & options,absl::Span<const XlaCompiler::Argument> args,OpKernelContext * ctx,const XlaCompiler::CompileOptions & compile_options,const XlaCompiler::CompilationResult ** out_compilation_result,xla::LocalExecutable ** out_executable)198 Status XlaCompilationCache::CompileSingleOp(
199     const XlaCompiler::Options& options,
200     absl::Span<const XlaCompiler::Argument> args, OpKernelContext* ctx,
201     const XlaCompiler::CompileOptions& compile_options,
202     const XlaCompiler::CompilationResult** out_compilation_result,
203     xla::LocalExecutable** out_executable) {
204   const NodeDef& def = ctx->op_kernel().def();
205   NameAttrList name;
206   name.set_name(def.op());
207   *name.mutable_attr() = def.attr();
208   // Remove the "_class" attribute from the attribute set used to create the
209   // compilation cache key. This attribute is information for the colocator
210   // and causes false uniqueness between nodes.
211   name.mutable_attr()->erase("_class");
212   auto compile_op = [&](XlaCompiler* compiler,
213                         XlaCompiler::CompilationResult* result) {
214     std::vector<DataType> result_dtypes(ctx->num_outputs());
215     for (int i = 0; i < result_dtypes.size(); ++i) {
216       result_dtypes[i] = ctx->expected_output_dtype(i);
217     }
218     return compiler->CompileSingleOp(compile_options, ctx->op_kernel().def(),
219                                      args, result_dtypes, result);
220   };
221   return CompileImpl(options, name, args, compile_op,
222                      /*compile_threshold=*/absl::nullopt,
223                      out_compilation_result, out_executable);
224 }
225 
CompileImpl(const XlaCompiler::Options & options,const NameAttrList & function,absl::Span<const XlaCompiler::Argument> args,const std::function<Status (XlaCompiler * compiler,XlaCompiler::CompilationResult *)> & compile_fn,absl::optional<int64> compile_threshold,const XlaCompiler::CompilationResult ** out_compilation_result,xla::LocalExecutable ** out_executable)226 Status XlaCompilationCache::CompileImpl(
227     const XlaCompiler::Options& options, const NameAttrList& function,
228     absl::Span<const XlaCompiler::Argument> args,
229     const std::function<Status(XlaCompiler* compiler,
230                                XlaCompiler::CompilationResult*)>& compile_fn,
231     absl::optional<int64> compile_threshold,
232     const XlaCompiler::CompilationResult** out_compilation_result,
233     xla::LocalExecutable** out_executable) {
234   DCHECK_NE(out_executable, nullptr);
235   VLOG(2) << "XlaCompilationCache::Compile " << DebugString();
236 
237   if (VLOG_IS_ON(2)) {
238     VLOG(2) << "num_inputs=" << args.size();
239     for (int i = 0; i < args.size(); i++) {
240       VLOG(2) << i << ": " << args[i].HumanString();
241     }
242   }
243 
244   TF_ASSIGN_OR_RETURN(Signature signature, BuildSignature(function, args));
245   VLOG(2) << "Signature: " << signature.HumanString();
246 
247   // The outer lock protects the existence of the cache entry. It does not
248   // protect the contents of the cache entry.
249   Entry* entry;
250   {
251     mutex_lock lock(compile_cache_mu_);
252     // Find or create a cache entry.
253     std::unique_ptr<Entry>& e = cache_[signature];
254     if (!e) {
255       e.reset(new Entry);
256     }
257     entry = e.get();
258   }
259 
260   // We always compile a cluster the very first time it is executed.  This is an
261   // optimistic guess that pays off for statically shaped TensorFlow graphs
262   // (since they get the benefit of XLA right away without waiting for warmup)
263   // and doesn't hurt much for dynamically shaped TensorFlow graphs (we "pay" at
264   // most one cluster-compilation's worth of compile time).
265   bool is_first_execution;
266 
267   // We avoid compiling clusters that have "gone megamorphic" i.e. have an
268   // excessive amount of shape dynamism.
269   bool is_megamorphic;
270 
271   {
272     mutex_lock lock(cluster_compile_stats_mu_);
273     auto it =
274         cluster_compile_stats_.emplace(function.name(), ClusterCompileStats{})
275             .first;
276     is_first_execution = it->second.execution_count++ == 0;
277 
278     // The is_megamorphic bit is "sticky".  We assume clusters that have been
279     // observed to be megamorphic once stay megamorphic forever.
280     it->second.is_megamorphic |=
281         IsMegamorphic(/*compile_count=*/it->second.compile_count,
282                       /*execution_count=*/it->second.execution_count);
283     is_megamorphic = it->second.is_megamorphic;
284   }
285 
286   // Acquire the cache entry lock and compile, if necessary.
287   // TODO(phawkins): this locking will need to be restructured when we implement
288   // cache eviction.
289   mutex_lock entry_lock(entry->mu);
290   int64 current_request_count = ++entry->request_count;
291   VLOG(2) << "Compilation cache entry hit: " << entry->compiled
292           << " signature: " << signature.HumanString() << " with request count "
293           << current_request_count << " and compile threshold "
294           << compile_threshold.value_or(0);
295   if (!entry->compiled) {
296     const bool should_compile = [&] {
297       if (!compile_threshold.has_value()) {
298         // Lazy compilation is disabled.
299         return true;
300       }
301 
302       if (is_megamorphic) {
303         VLOG(3) << "Not compiling cluster " << function.name()
304                 << " because it is megamorphic.";
305         return false;
306       }
307 
308       if (is_first_execution) {
309         return true;
310       }
311 
312       bool reached_compile_threshold =
313           current_request_count >= *compile_threshold;
314       if (!reached_compile_threshold) {
315         VLOG(3)
316             << "Not compiling cluster " << function.name()
317             << " because it has not reached compile threshold; threshold is "
318             << *compile_threshold << " execution count "
319             << current_request_count << ".";
320       }
321       return reached_compile_threshold;
322     }();
323 
324     if (!should_compile) {
325       VLOG(2) << "Not compiling for signature: " << signature.HumanString();
326       *out_compilation_result = nullptr;
327       *out_executable = nullptr;
328       return Status::OK();
329     }
330 
331     tensorflow::Env* env = tensorflow::Env::Default();
332     const uint64 compile_start_us = env->NowMicros();
333     // Do the actual JIT compilation without holding the lock (it can take
334     // a long time.)
335 
336     XlaCompiler compiler(options);
337     entry->compiled = true;
338 
339     entry->compilation_status =
340         compile_fn(&compiler, &entry->compilation_result);
341     TF_RETURN_IF_ERROR(entry->compilation_status);
342     CHECK_EQ(entry->executable.get(), nullptr);
343     entry->compilation_status =
344         BuildExecutable(options, entry->compilation_result, &entry->executable);
345 
346     const uint64 compile_end_us = env->NowMicros();
347     const uint64 compile_time_us = compile_end_us - compile_start_us;
348     {
349       mutex_lock lock(cluster_compile_stats_mu_);
350       auto it = cluster_compile_stats_.find(function.name());
351       it->second.compile_count++;
352       it->second.cumulative_compile_time_us += compile_time_us;
353       VLOG(1) << "compiled " << function.name() << " "
354               << it->second.compile_count
355               << " times, compile time: " << compile_time_us
356               << " us, cumulative: " << it->second.cumulative_compile_time_us
357               << " us ("
358               << tensorflow::strings::HumanReadableElapsedTime(compile_time_us /
359                                                                1.0e6)
360               << " / "
361               << tensorflow::strings::HumanReadableElapsedTime(
362                      it->second.cumulative_compile_time_us / 1.0e6)
363               << ")";
364     }
365   }
366   TF_RETURN_IF_ERROR(entry->compilation_status);
367   *out_compilation_result = &entry->compilation_result;
368   *out_executable = entry->executable.get();
369   return Status::OK();
370 }
371 
372 }  // namespace tensorflow
373