/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_compile_op_common.h"

#include <string>

#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/tpu/compilation_result.pb.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/protobuf/tpu/dynamic_padding.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry_unloader.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_metrics.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_options.h"
#include "tensorflow/core/tpu/kernels/tpu_fingerprint_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
#include "tensorflow/core/tpu/kernels/tpu_op_util.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_util.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/core/tpu/tpu_compile_interface.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"

namespace tensorflow {
namespace tpu {

CompileOpImplFactory* CompileOpImplFactory::factory_ = nullptr;

/* static */
CompileOpImplFactory* CompileOpImplFactory::Get() { return factory_; }

/* static */
void CompileOpImplFactory::Register(CompileOpImplFactory* factory) {
  CHECK_EQ(factory_, nullptr)
      << "CompileOpImplFactory can only be registered "
         "once and there can only be one factory active and used.";
  factory_ = factory;
}
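
// Example (hypothetical): a platform-specific factory would typically
// register itself once during process startup, e.g.:
//
//   static bool registered = [] {
//     CompileOpImplFactory::Register(new MyPlatformCompileOpImplFactory());
//     return true;
//   }();
//
// `MyPlatformCompileOpImplFactory` is an illustrative name, not a class that
// exists in this codebase.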

/* static */ void TpuCompileOpKernelCommon::ExitCountdown(
    Env* env, std::shared_ptr<std::atomic<bool>> done) {
  const int kSleepSeconds = 300;
  LOG(INFO) << "TpuCompileOp was cancelled. Sleeping for " << kSleepSeconds
            << " seconds to give the TpuCompileOp time to finish.";
  env->SleepForMicroseconds(kSleepSeconds * 1000000);
  if (done->load()) {
    // If the TpuCompileOp has finished, then terminate peacefully.
    return;
  }

  LOG(ERROR) << "Aborting process due to cancelled TpuCompileOp. This "
             << "termination is to ensure a consistent state.";
  std::exit(42);
}

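// Converts the op's "dynamic_shapes" input list into TensorShapes, one per
// dynamically shaped argument.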
/* static */ Status TpuCompileOpKernelCommon::GetDynamicShapes(
    OpKernelContext* ctx, std::vector<TensorShape>* shapes) {
  OpInputList dynamic_shapes;
  TF_RETURN_IF_ERROR(ctx->input_list("dynamic_shapes", &dynamic_shapes));

  shapes->resize(dynamic_shapes.size());
  for (int i = 0; i < dynamic_shapes.size(); ++i) {
    TF_RETURN_IF_ERROR(
        tpu::ShapeTensorToTensorShape(dynamic_shapes[i], &(*shapes)[i]));
  }
  return Status::OK();
}

void TpuCompileOpKernelCommon::Compute(OpKernelContext* ctx) {
  VLOG(1) << "Cloud TPU: TpuCompileOpKernelCommon::Compute";

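  // Signals whether ComputeInternal() has finished; the countdown thread
  // spawned on cancellation consults it to decide whether to abort the
  // process.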
  std::shared_ptr<std::atomic<bool>> done(new std::atomic<bool>(false));

  CancellationToken token =
      ctx->cancellation_manager()->get_cancellation_token();
  const bool already_cancelled =
      !ctx->cancellation_manager()->RegisterCallback(token, [ctx, done]() {
        if (OpsApiFn()->TpuCompile_ShouldTpuCompileOpIgnoreCancellationFn()) {
          return;
        }

        // Sleep and exit in another thread so the cancellation manager can
        // continue running callbacks.
        Env* env = ctx->env();
        env->SchedClosure([env, done]() { ExitCountdown(env, done); });
      });

  // If the RPC was cancelled before we registered the cancellation callback,
  // don't compile the TPU program.
  OP_REQUIRES(ctx, !already_cancelled,
              errors::Cancelled("RPC cancelled, not compiling TPU program"));

  // We only want to abort the process if a cancellation actually occurs during
  // compilation; we must deregister the callback in the success case. It
  // doesn't hurt to also deregister the callback in the failure case; the
  // CancellationManager ensures that already-registered callbacks will be run
  // once cancellation has started.
  auto cancellation_cleanup = xla::MakeCleanup([ctx, token, done] {
    ctx->cancellation_manager()->DeregisterCallback(token);
    done->store(true);
  });

  Status compile_status = ComputeInternal(ctx);
  string status_payload;
  // Construct payload if compile_status is not ok and there's no payload for
  // compilation yet.
  if (!compile_status.ok() &&
      compile_status.GetPayload(TpuCompileInterface::kTpuCompileErrorPayloadKey)
          .empty()) {
    tpu::CompilationResultProto proto;
    proto.set_status_code(compile_status.code());
    proto.set_status_error_message(compile_status.error_message());
    status_payload = proto.SerializeAsString();
  }
  OP_REQUIRES_OK_OR_SET_PAYLOAD(ctx,
                                TpuCompileInterface::kTpuCompileErrorPayloadKey,
                                status_payload, compile_status);
}

Status TpuCompileOpKernelCommon::CompileLocallyAndFillHostCache(
    FunctionLibraryRuntime* flib_runtime,
    const SessionMetadata* session_metadata,
    const TpuMeshStateInterface* mesh_state,
    const std::vector<TensorShape>& dynamic_shapes,
    const OpInputList& guaranteed_constants, const TpuCompilationCacheKey& key,
    TpuProgramGroupInterface* tpu_program_group) {
  absl::Time start_time = absl::Now();
  std::vector<TensorShape> arg_shapes;
  TF_RETURN_IF_ERROR(
      ComputeArgumentShapes(metadata_, dynamic_shapes, &arg_shapes));
  Status compile_status;
  if (use_mlir_) {
    const ConfigProto* config = flib_runtime->config_proto();
    ConfigProto::Experimental::MlirBridgeRollout rollout_state =
        GetMlirBridgeRolloutState(config ? absl::make_optional(*config)
                                         : absl::nullopt);
    compile_status = Compile(MlirToHloArgs{mlir_module_, rollout_state},
                             mesh_state->data(), arg_shapes, tpu_program_group);
  } else {
    compile_status =
        Compile(FunctionToHloArgs{&function_,
                                  flib_runtime->GetFunctionLibraryDefinition(),
                                  flib_runtime->graph_def_version(),
                                  {&guaranteed_constants}},
                mesh_state->data(), arg_shapes, tpu_program_group);
  }

  absl::Time end_time = absl::Now();
  auto duration = end_time - start_time;

  const std::string session_name = SessionNameFromMetadata(session_metadata);
  LOG(INFO) << "Compilation of " << key.prefix << " with session name "
            << session_name << " took " << duration << " and "
            << (compile_status.ok() ? "succeeded" : "failed");
  tpu_program_group->LogProgramMemorySummary();
  metrics::UpdateXlaCompilationTime(absl::ToInt64Microseconds(duration));
  TpuCompilationMetrics::IncrementCompilationCount(session_name);

  TF_RETURN_IF_ERROR(tpu_program_group->LogCompilationStats(key, duration));

  return compile_status;
}

Status TpuCompileOpKernelCommon::ComputeInternal(OpKernelContext* ctx) {
  VLOG(1) << "Retrieving mesh state";
  // Retrieve the topology from the resource manager.
  ResourceMgr* rm = GetTPUConfigResourceMgr();

  TpuMeshStateInterface* mesh_state;
  TF_RETURN_IF_ERROR(rm->Lookup(rm->default_container(),
                                kTpuMeshStateInterfaceResourceName,
                                &mesh_state));
  core::ScopedUnref mesh_state_unref(mesh_state);

  std::vector<TensorShape> dynamic_shapes;
  TF_RETURN_IF_ERROR(GetDynamicShapes(ctx, &dynamic_shapes));

  OpInputList guaranteed_constants;
  // TODO(ycao): Decide whether/how to support guaranteed constants in the
  // MLIR-based TF-Compiler Bridge.
  if (!use_mlir_) {
    TF_RETURN_IF_ERROR(
        ctx->input_list("guaranteed_constants", &guaranteed_constants));
  }

  const TpuCompilationCacheKey key = CreateCompilationCacheKey(
      function_.name(), metadata_.function_library_fingerprint(),
      mlir_module_fingerprint_, guaranteed_constants, dynamic_shapes, metadata_,
      *mesh_state);

  // Process-wide cache of TPU executables.
  TpuCompilationCacheInterface* cache;
  TF_RETURN_IF_ERROR(rm->Lookup<TpuCompilationCacheInterface>(
      rm->default_container(), kCompilationCacheResourceName, &cache));
  core::ScopedUnref cache_unref(cache);

  // Per-step object that ensures that compilation cache entries aren't
  // evicted until the step completes. This mechanism ensures that the
  // downstream TPUExecute Ops in this step will be able to look up the
  // compiled executable even if it is marked for eviction before the step
  // ends.
  //
  // We can't use GetTPUConfigResourceMgr here because it may return the
  // global ResourceMgr, which is not associated with any device, and
  // GraphMgr's ScopedStepContainer only searches ResourceMgrs associated
  // with devices when deleting resources at step boundaries.
  CompilationRefHolder* ref_holder;
  if (ctx->step_container() == nullptr) {
    return errors::FailedPrecondition(
        "TPUCompileOp requires a step container.");
  }
  TF_RETURN_IF_ERROR(
      ctx->step_container()->LookupOrCreate<CompilationRefHolder>(
          ctx->resource_manager(), "ref_holder", &ref_holder,
          [cache](CompilationRefHolder** h) {
            *h = cache->MakePerStepRefHolder();
            return Status::OK();
          }));
  core::ScopedUnref ref_holder_unref(ref_holder);

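  // Outputs produced by CompileIfKeyAbsent: a cache entry uid, the proto keys
  // and (optional) sharding keys of the compiled programs, per-program
  // may-modify-variables bits, and the compiled programs' HLO metadata.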
  int64_t uid;
  std::vector<std::string> proto_key;
  std::vector<std::string> sharding_key;
  std::vector<bool> may_modify_variables;
  absl::Span<const xla::HloProto* const> hlo_metadatas;
  Status status = cache->CompileIfKeyAbsent(
      key, ctx->session_metadata(), ref_holder, &uid, &proto_key, &sharding_key,
      &may_modify_variables, &hlo_metadatas,
      [&](TpuProgramGroupInterface* tpu_program_group) {
        VLOG(1) << "Cloud TPU: Compiling TPU program";
        // When this compile function is invoked, we know that host-memory
        // cache TpuCompilationCache saw a cache miss. There are two codepaths:
        // 1. If persistent cache is disabled, compile locally and populate
        //    host-memory cache.
        // 2. If persistent cache is enabled, we do an additional lookup on
        //    the persistent cache.
        //    - If persistent cache also sees a cache miss, trigger
        //      compilation. Then, populate both persistent cache and
        //      host-memory cache.
        //    - If persistent cache sees a cache hit, retrieve cache entry from
        //      persistent cache to populate host-memory cache without
        //      recompilation. If retrieval failed, compile locally as a
        //      fallback and use the local compilation result to populate
        //      host-memory cache.
        if (persistent_cache_ == nullptr) {
          VLOG(1) << "Persistent compilation cache not enabled. Compiling "
                     "TPU executable locally and populating host-memory cache.";
          return CompileLocallyAndFillHostCache(
              ctx->function_library(), ctx->session_metadata(), mesh_state,
              dynamic_shapes, guaranteed_constants, key, tpu_program_group);
        }
        return LookupPersistentCompilationCacheAndFillCaches(
            ctx->function_library(), ctx->session_metadata(), mesh_state,
            dynamic_shapes, guaranteed_constants, persistent_cache_.get(), key,
            tpu_program_group);
      });

  // `ref_holder` is provided to CompileIfKeyAbsent to ensure that the cache
  // entry does not get evicted before TpuExecuteOp runs it and discards
  // `ref_holder`. When the TpuCompilationCacheEntryUnloader gets destroyed,
  // e.g. because the user closes the session while there are in-flight
  // program executions, it discards the cache's reference to the cache entry
  // but does not remove the entry until `ref_holder` discards the last
  // reference to it. This ensures that the guarantees of `ref_holder` are
  // not violated when this flag is true.
  if (unload_cache_entry_on_session_close_) {
    // Place `unloader` in the TPU_SYSTEM device resource manager. Note that
    // - the TPUConfigResourceMgr returned by GetTPUConfigResourceMgr() is a
    //   special process-global ResourceMgr. There is only one
    //   TPUConfigResourceMgr, and it is never destroyed.
    // - the TPU_SYSTEM device resource manager is a normal device ResourceMgr
    //   for the TPU_SYSTEM device. If DirectSession or isolate_session_state
    //   are used, there is one TPU_SYSTEM ResourceMgr for each session, and
    //   the ResourceMgrs will be destroyed when their corresponding session
    //   is closed. Otherwise there is one TPU_SYSTEM ResourceMgr that is only
    //   destroyed when the master session is destroyed, not when the worker
    //   sessions are destroyed.
    TpuCompilationCacheEntryUnloader* unloader;
    TF_RETURN_IF_ERROR(
        ctx->resource_manager()
            ->LookupOrCreate<TpuCompilationCacheEntryUnloader>(
                ctx->resource_manager()->default_container(),
                kCompilationCacheUnloaderResourceName, &unloader,
                [cache](TpuCompilationCacheEntryUnloader** new_unloader) {
                  *new_unloader = new TpuCompilationCacheEntryUnloader(cache);
                  return Status::OK();
                }));
    // Note that LookupOrCreate puts two refcounts on unloader.
    core::ScopedUnref unloader_unref(unloader);
    unloader->AddCacheEntryUid(uid);
  }

  int64_t num_cores_with_compiled_programs = proto_key.size();
  if (proto_key.size() == 1) {
    // SPMD produces 1 program for all cores.
    num_cores_with_compiled_programs = metadata_.num_cores_per_replica();
  }
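  // The op's outputs are: one compilation-status output, one program output
  // per core with a compiled program, and (on the non-MLIR path) one
  // may-modify-variables output per computation. Check that these counts add
  // up to the op's declared number of outputs.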
  if (status.ok() &&
      num_cores_with_compiled_programs +
              (may_modify_variables.size() * static_cast<int>(!use_mlir_)) !=
          ctx->num_outputs() - 1) {
    status = errors::Internal(
        "Number of cores with compiled programs (",
        num_cores_with_compiled_programs, ") + variable states (",
        may_modify_variables.size() * static_cast<int>(!use_mlir_),
        ") + compilation status output != number of compile op outputs (",
        ctx->num_outputs(), ")");
  }

  // TODO(jpienaar): status is not just due to the compilation. At this
  // point we should be failing the execution of the op in some cases and
  // returning a compilation error in others. For now, uniformly return an
  // error and fail in _TPUExecute if status failed here.

  // TODO(misard) the frame id will be wrong if this is ever called from
  // within a function. Consider whether to use the same hack as is
  // present in the rendezvous manager where the function call frame is
  // cast to a uint64, or do something better all around.
  std::string rendezvous_key_base = strings::StrCat(
      "host_compute_rendezvous:", ctx->op_kernel().name(), ":",
      ctx->frame_iter().frame_id, ":", ctx->frame_iter().iter_id, ":");

  // Return compilation status.
  {
    Tensor output(DT_STRING, TensorShape({}));
    tpu::CompilationResultProto proto;
    proto.set_status_code(status.code());
    if (!status.ok()) {
      proto.set_status_error_message(
          absl::StrCat("Compilation failure: ", status.error_message()));
    }
    if (return_hlo_protos_) {
      // Return the HloProtos as part of compilation status.
      for (const xla::HloProto* hlo_metadata : hlo_metadatas) {
        xla::HloProto* hlo_proto = proto.add_hlo_protos();
        *hlo_proto = *hlo_metadata;
      }
    }
    SerializeToTString(proto, &output.scalar<tstring>()());
    ctx->set_output(0, output);
    status.SetPayload(TpuCompileInterface::kTpuCompileErrorPayloadKey,
                      output.scalar<tstring>()());
  }

  if (status.ok()) {
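    // Each per-core program output is a length-3 string vector: [0] the
    // compilation cache proto key, [1] the rendezvous key base for
    // host-compute transfers, and [2] the sharding key (empty if there is
    // none).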
    for (int i = 0; i < num_cores_with_compiled_programs; ++i) {
      Tensor output(DT_STRING, TensorShape({3}));
      if (proto_key.size() == 1) {
        output.vec<tstring>()(0) = proto_key[0];
      } else {
        output.vec<tstring>()(0) = proto_key[i];
      }
      output.vec<tstring>()(1) = rendezvous_key_base;
      if (sharding_key.empty()) {
        output.vec<tstring>()(2) = "";
      } else if (sharding_key.size() == 1) {
        output.vec<tstring>()(2) = sharding_key[0];
      } else {
        TF_RET_CHECK(sharding_key.size() == num_cores_with_compiled_programs);
        output.vec<tstring>()(2) = sharding_key[i];
      }
      ctx->set_output(i + 1, output);
    }
    if (!use_mlir_) {
      // If any of the programs may modify a variable, then report that all
      // do, as the only state currently tracked here is whether a model is
      // read-only or not.
      bool may_modify = false;
      for (bool m : may_modify_variables) {
        may_modify = may_modify || m;
      }
      for (int i = 0; i < may_modify_variables.size(); ++i) {
        Tensor output(DT_BOOL, TensorShape({}));
        output.scalar<bool>()() = may_modify;
        ctx->set_output(i + num_cores_with_compiled_programs + 1, output);
      }
    }
    VLOG(1) << "Cloud TPU: Compilation succeeded";
  } else {
    // Return errors in the failure case.
    for (int i = 0; i < num_computations_; ++i) {
      Tensor output(DT_STRING, TensorShape({3}));
      output.vec<tstring>()(0) = "<<NO PROGRAM AS COMPILATION FAILED>>";
      output.vec<tstring>()(1) = "<<NO RENDEZVOUS KEY AS COMPILATION FAILED>>";
      output.vec<tstring>()(2) = "<<NO SHARDING KEY AS COMPILATION FAILED>>";
      ctx->set_output(i + 1, output);
    }
    if (!use_mlir_) {
      // The TPUCompileMLIR op does not have a MayModifyVariable output.
      for (int i = 0; i < num_computations_; ++i) {
        Tensor output(false);
        ctx->set_output(i + num_computations_ + 1, output);
      }
    }
  }
  return status;
}

Status TpuCompileOpKernelCommon::RegisterXLAFingerprints(
    const std::vector<TensorShape>& arg_shapes,
    TpuProgramGroupInterface* tpu_program_group, uint64 fingerprint) {
  // TODO(chiachenc): support more than one program; only one program is
  // supported for now.
  if (tpu_program_group->program_count() != 1) {
    LOG(INFO) << "Found " << tpu_program_group->program_count()
              << " programs. Skipping fingerprint registration.";
  } else {
    ResourceMgr* rm = GetTPUConfigResourceMgr();
    tpu::TpuFingerprintLookup* fingerprint_lookup;
    TF_RETURN_IF_ERROR(rm->LookupOrCreate<tpu::TpuFingerprintLookup>(
        rm->default_container(), tpu::kFingerprintLookupResourceName,
        &fingerprint_lookup, [&](tpu::TpuFingerprintLookup** new_lookup) {
          *new_lookup = tpu::TpuFingerprintLookup::Create();
          return Status::OK();
        }));
    uint64 tf_fingerprint =
        tpu::CreateFingerprintWithNameAndShapes(fingerprint, arg_shapes);
    std::string xla_fingerprint = tpu_program_group->fingerprint(0);
    VLOG(1) << "Registering TF fingerprint: " << tf_fingerprint
            << " with XLA fingerprint: " << xla_fingerprint;
    fingerprint_lookup->RegisterIntermediateAndValuePair(
        tf_fingerprint, std::move(xla_fingerprint));
  }
  return Status::OK();
}
}  // namespace tpu
}  // namespace tensorflow