/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/xla_compilation_cache.h"

#include <numeric>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {

constexpr int64 XlaCompilationCache::kDefaultCompilationThreshold;

XlaCompilationCache::XlaCompilationCache(xla::LocalClient* client,
                                         DeviceType device_type)
    : client_(client), device_type_(std::move(device_type)) {}

XlaCompilationCache::~XlaCompilationCache() {
  // Ensure that any use of our programs has completed by waiting for all
  // stream executors to complete.
  for (auto* executor : client_->backend().stream_executors()) {
    bool ok = executor->SynchronizeAllActivity();
    if (!ok) {
      LOG(ERROR) << "Error synchronizing activity while waiting for all "
                    "programs to complete";
    }
  }
  // TODO(b/110813685): Think about the program ownership model. Programs are
  // currently owned by the compilation cache which means we must wait for
  // program completion in the destructor. There are multiple compilation
  // caches around, which complicates things a little. Perhaps having programs
  // be shared_ptrs (an invasive change) would make the model easier to reason
  // about?
}

string XlaCompilationCache::DebugString() const {
  return "XLA JIT compilation cache";
}

// Compute a string signature which encodes the shapes of the
// arguments in the supplied list.
string XlaCompilationCache::Signature::HumanString() const {
  string result = name;
  for (const auto& a : arg_shapes) {
    absl::StrAppend(&result, ",", DataTypeString(a.first));
    absl::StrAppend(&result, " [", absl::StrJoin(a.second, ","), "]");
  }

  for (const auto& v : arg_values) {
    absl::StrAppend(&result, "; ", v.DebugString());
  }
  return result;
}

bool XlaCompilationCache::Signature::operator==(const Signature& other) const {
  if (name != other.name) return false;
  if (arg_shapes != other.arg_shapes) return false;

  if (arg_values.size() != other.arg_values.size()) return false;
  for (int i = 0; i < arg_values.size(); ++i) {
    if (arg_values[i].dtype() != other.arg_values[i].dtype() ||
        arg_values[i].shape() != other.arg_values[i].shape() ||
        arg_values[i].tensor_data() != other.arg_values[i].tensor_data()) {
      return false;
    }
  }
  return true;
}

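// Hash function for Signatures: combines the canonicalized function name, the
// dtype and dimension sizes of each shaped argument, and the raw bytes of each
// constant argument value.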
uint64 XlaCompilationCache::Signature::Hash::operator()(
    const XlaCompilationCache::Signature& signature) const {
  uint64 h = std::hash<string>()(signature.name);
  for (const auto& arg : signature.arg_shapes) {
    h = Hash64Combine(h, std::hash<int>()(static_cast<int>(arg.first)));
    h = Hash64Combine(h, std::hash<int>()(arg.second.size()));
    for (int dim : arg.second) {
      h = Hash64Combine(h, std::hash<int>()(dim));
    }
  }
  for (const auto& arg : signature.arg_values) {
    h = Hash64Combine(
        h, Hash64(arg.tensor_data().data(), arg.tensor_data().size()));
  }
  return h;
}

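// Builds a Signature for `function` applied to `args`: constant arguments
// contribute their values, while parameter and resource arguments contribute
// their dtype and dimension sizes. Any other argument kind is an error.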
xla::StatusOr<XlaCompilationCache::Signature>
XlaCompilationCache::BuildSignature(
    const NameAttrList& function,
    absl::Span<const XlaCompiler::Argument> args) {
  Signature signature;
  signature.name = Canonicalize(function.name(), AttrSlice(&function.attr()));
  for (const XlaCompiler::Argument& arg : args) {
    switch (arg.kind) {
      case XlaCompiler::Argument::kConstant:
        signature.arg_values.push_back(arg.constant_value);
        break;
      case XlaCompiler::Argument::kParameter:
      case XlaCompiler::Argument::kResource:
        signature.arg_shapes.emplace_back(arg.type, arg.DimensionSizes());
        break;
      default:
        return errors::InvalidArgument(
            "Unhandled argument kind in XlaCompilationCache: ",
            arg.HumanString());
    }
  }
  return std::move(signature);
}

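// Compiles `result.computation` into a LocalExecutable, using the input
// layouts from the compilation result and the device ordinal and allocator
// from `options`.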
Status XlaCompilationCache::BuildExecutable(
    const XlaCompiler::Options& options,
    const XlaCompiler::CompilationResult& result,
    std::unique_ptr<xla::LocalExecutable>* executable) {
  VLOG(2) << "Compiling to local executable";

  std::vector<const xla::Shape*> argument_layouts(
      result.xla_input_shapes.size());
  for (int i = 0; i < result.xla_input_shapes.size(); ++i) {
    argument_layouts[i] = &result.xla_input_shapes[i];
  }
  xla::ExecutableBuildOptions build_options;
  build_options.set_device_ordinal(options.device_ordinal != -1
                                       ? options.device_ordinal
                                       : client_->default_device_ordinal());
  build_options.set_result_layout(result.xla_output_shape);
  build_options.set_device_allocator(options.device_allocator);

  auto compile_result =
      client_->Compile(*result.computation, argument_layouts, build_options);
  if (!compile_result.ok()) {
    return compile_result.status();
  }
  *executable = std::move(compile_result.ValueOrDie());
  return Status::OK();
}

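// Compiles `function` for the given arguments, reusing a cached executable
// when the signature matches a previous compilation. Under CompileMode::kLazy
// the default compilation threshold applies before a cluster is compiled.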
Status XlaCompilationCache::Compile(
    const XlaCompiler::Options& options, const NameAttrList& function,
    absl::Span<const XlaCompiler::Argument> args,
    const XlaCompiler::CompileOptions& compile_options,
    CompileMode compile_mode,
    const XlaCompiler::CompilationResult** out_compilation_result,
    xla::LocalExecutable** out_executable) {
  absl::optional<int64> compile_threshold;
  if (compile_mode == CompileMode::kLazy) {
    compile_threshold = kDefaultCompilationThreshold;
  }
  auto compile_fn = [&](XlaCompiler* compiler,
                        XlaCompiler::CompilationResult* result) {
    return compiler->CompileFunction(compile_options, function, args, result);
  };
  return CompileImpl(options, function, args, compile_fn,
                     /*compile_threshold=*/compile_threshold,
                     out_compilation_result, out_executable);
}

static bool IsMegamorphic(int64 compile_count, int64 execution_count) {
  const int64 kCompileThreshold = 10;
  const int64 kMinExecutionsPerCompile = 50;

  // This heuristic is trying to capture the following property: have we sunk a
  // certain minimum amount of compile time into the cluster that didn't quite
  // "pay off"?
  return compile_count > kCompileThreshold &&
         execution_count < kMinExecutionsPerCompile * compile_count;
}

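// Compiles a single-op kernel. The cache key is built from the op's NodeDef
// (minus the colocation "_class" attribute) and the argument signature; the
// lazy-compilation threshold is not applied for single ops.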
Status XlaCompilationCache::CompileSingleOp(
    const XlaCompiler::Options& options,
    absl::Span<const XlaCompiler::Argument> args, OpKernelContext* ctx,
    const XlaCompiler::CompileOptions& compile_options,
    const XlaCompiler::CompilationResult** out_compilation_result,
    xla::LocalExecutable** out_executable) {
  const NodeDef& def = ctx->op_kernel().def();
  NameAttrList name;
  name.set_name(def.op());
  *name.mutable_attr() = def.attr();
  // Remove the "_class" attribute from the attribute set used to create the
  // compilation cache key. This attribute is information for the colocator
  // and causes false uniqueness between nodes.
  name.mutable_attr()->erase("_class");
  auto compile_op = [&](XlaCompiler* compiler,
                        XlaCompiler::CompilationResult* result) {
    std::vector<DataType> result_dtypes(ctx->num_outputs());
    for (int i = 0; i < result_dtypes.size(); ++i) {
      result_dtypes[i] = ctx->expected_output_dtype(i);
    }
    return compiler->CompileSingleOp(compile_options, ctx->op_kernel().def(),
                                     args, result_dtypes, result);
  };
  return CompileImpl(options, name, args, compile_op,
                     /*compile_threshold=*/absl::nullopt,
                     out_compilation_result, out_executable);
}

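// Common implementation behind Compile and CompileSingleOp: looks up or
// creates the cache entry for the argument signature, decides whether to
// compile based on the lazy-compilation threshold and the megamorphism
// heuristic, and returns the cached compilation result and executable.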
Status XlaCompilationCache::CompileImpl(
    const XlaCompiler::Options& options, const NameAttrList& function,
    absl::Span<const XlaCompiler::Argument> args,
    const std::function<Status(XlaCompiler* compiler,
                               XlaCompiler::CompilationResult*)>& compile_fn,
    absl::optional<int64> compile_threshold,
    const XlaCompiler::CompilationResult** out_compilation_result,
    xla::LocalExecutable** out_executable) {
  DCHECK_NE(out_executable, nullptr);
  VLOG(2) << "XlaCompilationCache::Compile " << DebugString();

  if (VLOG_IS_ON(2)) {
    VLOG(2) << "num_inputs=" << args.size();
    for (int i = 0; i < args.size(); i++) {
      VLOG(2) << i << ": " << args[i].HumanString();
    }
  }

  TF_ASSIGN_OR_RETURN(Signature signature, BuildSignature(function, args));
  VLOG(2) << "Signature: " << signature.HumanString();

  // The outer lock protects the existence of the cache entry. It does not
  // protect the contents of the cache entry.
  Entry* entry;
  {
    mutex_lock lock(compile_cache_mu_);
    // Find or create a cache entry.
    std::unique_ptr<Entry>& e = cache_[signature];
    if (!e) {
      e.reset(new Entry);
    }
    entry = e.get();
  }

  // We always compile a cluster the very first time it is executed. This is an
  // optimistic guess that pays off for statically shaped TensorFlow graphs
  // (since they get the benefit of XLA right away without waiting for warmup)
  // and doesn't hurt much for dynamically shaped TensorFlow graphs (we "pay" at
  // most one cluster-compilation's worth of compile time).
  bool is_first_execution;

  // We avoid compiling clusters that have "gone megamorphic" i.e. have an
  // excessive amount of shape dynamism.
  bool is_megamorphic;

  {
    mutex_lock lock(cluster_compile_stats_mu_);
    auto it =
        cluster_compile_stats_.emplace(function.name(), ClusterCompileStats{})
            .first;
    is_first_execution = it->second.execution_count++ == 0;

    // The is_megamorphic bit is "sticky". We assume clusters that have been
    // observed to be megamorphic once stay megamorphic forever.
    it->second.is_megamorphic |=
        IsMegamorphic(/*compile_count=*/it->second.compile_count,
                      /*execution_count=*/it->second.execution_count);
    is_megamorphic = it->second.is_megamorphic;
  }

  // Acquire the cache entry lock and compile, if necessary.
  // TODO(phawkins): this locking will need to be restructured when we implement
  // cache eviction.
  mutex_lock entry_lock(entry->mu);
  int64 current_request_count = ++entry->request_count;
  VLOG(2) << "Compilation cache entry hit: " << entry->compiled
          << " signature: " << signature.HumanString() << " with request count "
          << current_request_count << " and compile threshold "
          << compile_threshold.value_or(0);
  if (!entry->compiled) {
    const bool should_compile = [&] {
      if (!compile_threshold.has_value()) {
        // Lazy compilation is disabled.
        return true;
      }

      if (is_megamorphic) {
        VLOG(3) << "Not compiling cluster " << function.name()
                << " because it is megamorphic.";
        return false;
      }

      if (is_first_execution) {
        return true;
      }

      bool reached_compile_threshold =
          current_request_count >= *compile_threshold;
      if (!reached_compile_threshold) {
        VLOG(3)
            << "Not compiling cluster " << function.name()
            << " because it has not reached compile threshold; threshold is "
            << *compile_threshold << " execution count "
            << current_request_count << ".";
      }
      return reached_compile_threshold;
    }();

    if (!should_compile) {
      VLOG(2) << "Not compiling for signature: " << signature.HumanString();
      *out_compilation_result = nullptr;
      *out_executable = nullptr;
      return Status::OK();
    }

    tensorflow::Env* env = tensorflow::Env::Default();
    const uint64 compile_start_us = env->NowMicros();
    // Do the actual JIT compilation without holding the lock (it can take
    // a long time.)

    XlaCompiler compiler(options);
    entry->compiled = true;

    entry->compilation_status =
        compile_fn(&compiler, &entry->compilation_result);
    TF_RETURN_IF_ERROR(entry->compilation_status);
    CHECK_EQ(entry->executable.get(), nullptr);
    entry->compilation_status =
        BuildExecutable(options, entry->compilation_result, &entry->executable);

    const uint64 compile_end_us = env->NowMicros();
    const uint64 compile_time_us = compile_end_us - compile_start_us;
    {
      mutex_lock lock(cluster_compile_stats_mu_);
      auto it = cluster_compile_stats_.find(function.name());
      it->second.compile_count++;
      it->second.cumulative_compile_time_us += compile_time_us;
      VLOG(1) << "compiled " << function.name() << " "
              << it->second.compile_count
              << " times, compile time: " << compile_time_us
              << " us, cumulative: " << it->second.cumulative_compile_time_us
              << " us ("
              << tensorflow::strings::HumanReadableElapsedTime(compile_time_us /
                                                               1.0e6)
              << " / "
              << tensorflow::strings::HumanReadableElapsedTime(
                     it->second.cumulative_compile_time_us / 1.0e6)
              << ")";
    }
  }
  TF_RETURN_IF_ERROR(entry->compilation_status);
  *out_compilation_result = &entry->compilation_result;
  *out_executable = entry->executable.get();
  return Status::OK();
}

}  // namespace tensorflow