/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_compile_op_common.h"

#include <string>

#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/tpu/compilation_result.pb.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/protobuf/tpu/dynamic_padding.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry_unloader.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_metrics.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_options.h"
#include "tensorflow/core/tpu/kernels/tpu_fingerprint_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
#include "tensorflow/core/tpu/kernels/tpu_op_util.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_util.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/core/tpu/tpu_compile_interface.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"

namespace tensorflow {
namespace tpu {

CompileOpImplFactory* CompileOpImplFactory::factory_ = nullptr;

/* static */
CompileOpImplFactory* CompileOpImplFactory::Get() { return factory_; }

/* static */
void CompileOpImplFactory::Register(CompileOpImplFactory* factory) {
  CHECK_EQ(factory_, nullptr)
      << "CompileOpImplFactory can only be registered "
         "once and there can only be one factory active and used.";
  factory_ = factory;
}
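
// For reference, a minimal sketch of how a backend would register its
// implementation factory at static-initialization time (hypothetical
// factory name; the actual registration lives in platform-specific code):
//
//   class MyCompileOpImplFactory : public CompileOpImplFactory { ... };
//
//   static bool factory_registered = []() {
//     CompileOpImplFactory::Register(new MyCompileOpImplFactory());
//     return true;
//   }();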

/* static */ void TpuCompileOpKernelCommon::ExitCountdown(
    Env* env, std::shared_ptr<std::atomic<bool>> done) {
  const int kSleepSeconds = 300;
  LOG(INFO) << "TpuCompileOp was cancelled. Sleeping for " << kSleepSeconds
            << " seconds to give time for TPUCompileOp to finish.";
  env->SleepForMicroseconds(kSleepSeconds * 1000000);
  if (done->load()) {
    // If the TpuCompileOp has finished, then terminate peacefully.
    return;
  }

  LOG(ERROR) << "Aborting process due to cancelled TpuCompileOp. This "
             << "termination is to ensure a consistent state.";
  std::exit(42);
}

/* static */ Status TpuCompileOpKernelCommon::GetDynamicShapes(
    OpKernelContext* ctx, std::vector<TensorShape>* shapes) {
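  // Each "dynamic_shapes" input is expected to be a rank-1 shape tensor;
  // ShapeTensorToTensorShape converts each one into a TensorShape (one per
  // dynamically-shaped argument).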
  OpInputList dynamic_shapes;
  TF_RETURN_IF_ERROR(ctx->input_list("dynamic_shapes", &dynamic_shapes));

  shapes->resize(dynamic_shapes.size());
  for (int i = 0; i < dynamic_shapes.size(); ++i) {
    TF_RETURN_IF_ERROR(
        tpu::ShapeTensorToTensorShape(dynamic_shapes[i], &(*shapes)[i]));
  }
  return Status::OK();
}

void TpuCompileOpKernelCommon::Compute(OpKernelContext* ctx) {
  VLOG(1) << "Cloud TPU: TpuCompileOpKernelCommon::Compute";
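
  // `done` is shared with the cancellation callback and the countdown thread:
  // it flips to true once this op finishes, letting ExitCountdown return
  // peacefully instead of aborting the process.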
  std::shared_ptr<std::atomic<bool>> done(new std::atomic<bool>(false));

  CancellationToken token =
      ctx->cancellation_manager()->get_cancellation_token();
  const bool already_cancelled =
      !ctx->cancellation_manager()->RegisterCallback(token, [ctx, done]() {
        if (OpsApiFn()->TpuCompile_ShouldTpuCompileOpIgnoreCancellationFn()) {
          return;
        }

        // Sleep and exit in another thread so the cancellation manager can
        // continue running callbacks.
        Env* env = ctx->env();
        env->SchedClosure([env, done]() { ExitCountdown(env, done); });
      });

  // If the RPC was cancelled before we registered the cancellation callback,
  // don't compile the TPU program.
  OP_REQUIRES(ctx, !already_cancelled,
              errors::Cancelled("RPC cancelled, not compiling TPU program"));

  // We only want to abort the process if a cancellation actually occurs during
  // compilation; we must deregister the callback in the success case. It
  // doesn't hurt to also deregister the callback in the failure case; the
  // CancellationManager ensures that already-registered callbacks will be run
  // once cancellation has started.
  auto cancellation_cleanup = xla::MakeCleanup([ctx, token, done] {
    ctx->cancellation_manager()->DeregisterCallback(token);
    done->store(true);
  });

  Status compile_status = ComputeInternal(ctx);
  string status_payload;
  // Construct payload if compile_status is not ok and there's no payload for
  // compilation yet.
  if (!compile_status.ok() &&
      compile_status.GetPayload(TpuCompileInterface::kTpuCompileErrorPayloadKey)
          .empty()) {
    tpu::CompilationResultProto proto;
    proto.set_status_code(compile_status.code());
    proto.set_status_error_message(compile_status.error_message());
    status_payload = proto.SerializeAsString();
  }
  OP_REQUIRES_OK_OR_SET_PAYLOAD(ctx,
                                TpuCompileInterface::kTpuCompileErrorPayloadKey,
                                status_payload, compile_status);
}

Status TpuCompileOpKernelCommon::CompileLocallyAndFillHostCache(
    FunctionLibraryRuntime* flib_runtime,
    const SessionMetadata* session_metadata,
    const TpuMeshStateInterface* mesh_state,
    const std::vector<TensorShape>& dynamic_shapes,
    const OpInputList& guaranteed_constants, const TpuCompilationCacheKey& key,
    TpuProgramGroupInterface* tpu_program_group) {
  absl::Time start_time = absl::Now();
  std::vector<TensorShape> arg_shapes;
  TF_RETURN_IF_ERROR(
      ComputeArgumentShapes(metadata_, dynamic_shapes, &arg_shapes));
  Status compile_status;
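  // Two compilation paths follow: the MLIR-based TF/XLA bridge (with its
  // rollout state taken from the session's ConfigProto) or the older
  // function-to-HLO path.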
  if (use_mlir_) {
    const ConfigProto* config = flib_runtime->config_proto();
    ConfigProto::Experimental::MlirBridgeRollout rollout_state =
        GetMlirBridgeRolloutState(config ? absl::make_optional(*config)
                                         : absl::nullopt);
    compile_status = Compile(MlirToHloArgs{mlir_module_, rollout_state},
                             mesh_state->data(), arg_shapes, tpu_program_group);
  } else {
    compile_status =
        Compile(FunctionToHloArgs{&function_,
                                  flib_runtime->GetFunctionLibraryDefinition(),
                                  flib_runtime->graph_def_version(),
                                  {&guaranteed_constants}},
                mesh_state->data(), arg_shapes, tpu_program_group);
  }

  absl::Time end_time = absl::Now();
  auto duration = end_time - start_time;

  const std::string session_name = SessionNameFromMetadata(session_metadata);
  LOG(INFO) << "Compilation of " << key.prefix << " with session name "
            << session_name << " took " << duration << " and "
            << (compile_status.ok() ? "succeeded" : "failed");
  tpu_program_group->LogProgramMemorySummary();
  metrics::UpdateXlaCompilationTime(absl::ToInt64Microseconds(duration));
  TpuCompilationMetrics::IncrementCompilationCount(session_name);

  TF_RETURN_IF_ERROR(tpu_program_group->LogCompilationStats(key, duration));

  return compile_status;
}

Status TpuCompileOpKernelCommon::ComputeInternal(OpKernelContext* ctx) {
  VLOG(1) << "Retrieving mesh state";
  // Retrieve the topology from the resource manager.
  ResourceMgr* rm = GetTPUConfigResourceMgr();

  TpuMeshStateInterface* mesh_state;
  TF_RETURN_IF_ERROR(rm->Lookup(rm->default_container(),
                                kTpuMeshStateInterfaceResourceName,
                                &mesh_state));
  core::ScopedUnref mesh_state_unref(mesh_state);

  std::vector<TensorShape> dynamic_shapes;
  TF_RETURN_IF_ERROR(GetDynamicShapes(ctx, &dynamic_shapes));

  OpInputList guaranteed_constants;
  // TODO(ycao): Decide whether/how to support guaranteed constants in
  // MLIR-based TF-Compiler Bridge.
  if (!use_mlir_) {
    TF_RETURN_IF_ERROR(
        ctx->input_list("guaranteed_constants", &guaranteed_constants));
  }

  const TpuCompilationCacheKey key = CreateCompilationCacheKey(
      function_.name(), metadata_.function_library_fingerprint(),
      mlir_module_fingerprint_, guaranteed_constants, dynamic_shapes, metadata_,
      *mesh_state);

  // Process-wide cache of TPU executables.
  TpuCompilationCacheInterface* cache;
  TF_RETURN_IF_ERROR(rm->Lookup<TpuCompilationCacheInterface>(
      rm->default_container(), kCompilationCacheResourceName, &cache));
  core::ScopedUnref cache_unref(cache);

  // Per-step object that ensures that compilation cache entries aren't
  // evicted until the step completes. This mechanism ensures that the
  // downstream TPUExecute Ops in this step will be able to look up the
  // compiled executable even if it is marked for eviction before the step
  // ends.
  //
  // We can't use GetTPUConfigResourceMgr here because it may return the
  // global ResourceMgr, which is not associated with any device, and
  // GraphMgr's ScopedStepContainer only searches ResourceMgrs associated
  // with devices when deleting resources at step boundaries.
  CompilationRefHolder* ref_holder;
  if (ctx->step_container() == nullptr) {
    return errors::FailedPrecondition(
        "TPUCompileOp requires a step container.");
  }
  TF_RETURN_IF_ERROR(
      ctx->step_container()->LookupOrCreate<CompilationRefHolder>(
          ctx->resource_manager(), "ref_holder", &ref_holder,
          [cache](CompilationRefHolder** h) {
            *h = cache->MakePerStepRefHolder();
            return Status::OK();
          }));
  core::ScopedUnref ref_holder_unref(ref_holder);
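
  // Out-params filled in by CompileIfKeyAbsent: the cache entry uid, per-core
  // program and sharding keys, per-program may-modify-variables bits, and the
  // programs' HLO metadata.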
  int64_t uid;
  std::vector<std::string> proto_key;
  std::vector<std::string> sharding_key;
  std::vector<bool> may_modify_variables;
  absl::Span<const xla::HloProto* const> hlo_metadatas;
  Status status = cache->CompileIfKeyAbsent(
      key, ctx->session_metadata(), ref_holder, &uid, &proto_key, &sharding_key,
      &may_modify_variables, &hlo_metadatas,
      [&](TpuProgramGroupInterface* tpu_program_group) {
        VLOG(1) << "Cloud TPU: Compiling TPU program";
        // When this compile function is invoked, we know that the host-memory
        // cache (TpuCompilationCache) saw a cache miss. There are two
        // codepaths:
        // 1. If persistent cache is disabled, compile locally and populate
        //    host-memory cache.
        // 2. If persistent cache is enabled, we do an additional lookup on
        //    the persistent cache.
        //    - If persistent cache also sees a cache miss, trigger
        //      compilation. Then, populate both persistent cache and
        //      host-memory cache.
        //    - If persistent cache sees a cache hit, retrieve cache entry from
        //      persistent cache to populate host-memory cache without
        //      recompilation. If retrieval failed, compile locally as a
        //      fallback and use the local compilation result to populate
        //      host-memory cache.
        if (persistent_cache_ == nullptr) {
          VLOG(1) << "Persistent compilation cache not enabled. Compiling "
                     "TPU executable locally and populating host-memory cache.";
          return CompileLocallyAndFillHostCache(
              ctx->function_library(), ctx->session_metadata(), mesh_state,
              dynamic_shapes, guaranteed_constants, key, tpu_program_group);
        }
        return LookupPersistentCompilationCacheAndFillCaches(
            ctx->function_library(), ctx->session_metadata(), mesh_state,
            dynamic_shapes, guaranteed_constants, persistent_cache_.get(), key,
            tpu_program_group);
      });

  // `ref_holder` is provided to CompileIfKeyAbsent to ensure that the cache
  // entry does not get evicted before TpuExecuteOp runs it and discards
  // `ref_holder`. If the TpuCompilationCacheEntryUnloader is destroyed
  // because the user closes the session while program executions are still
  // in flight, it discards the cache's reference to the entry, but the entry
  // is not removed until `ref_holder` drops its last reference. This ensures
  // that the guarantees of `ref_holder` are not violated when this flag is
  // true.
  if (unload_cache_entry_on_session_close_) {
    // Place `unloader` in TPU_SYSTEM device resource manager. Note that
    // - TPUConfigResourceMgr returned by GetTPUConfigResourceMgr() is a special
    //   process-global ResourceMgr. There is only one TPUConfigResourceMgr, and
    //   it is never destroyed.
    // - TPU_SYSTEM device resource manager is a normal device ResourceMgr for
    //   TPU_SYSTEM device. If DirectSession or isolate_session_state are used,
    //   there's one TPU_SYSTEM ResourceMgr for each session, and the
    //   ResourceMgrs will be destroyed when their corresponding session is
    //   closed. Otherwise there's one TPU_SYSTEM ResourceMgr that's only
    //   destroyed when the master-session is destroyed, not when the worker
    //   sessions are destroyed.
    TpuCompilationCacheEntryUnloader* unloader;
    TF_RETURN_IF_ERROR(
        ctx->resource_manager()
            ->LookupOrCreate<TpuCompilationCacheEntryUnloader>(
                ctx->resource_manager()->default_container(),
                kCompilationCacheUnloaderResourceName, &unloader,
                [cache](TpuCompilationCacheEntryUnloader** new_unloader) {
                  *new_unloader = new TpuCompilationCacheEntryUnloader(cache);
                  return Status::OK();
                }));
    // Note that LookupOrCreate puts two refcounts on unloader.
    core::ScopedUnref unloader_unref(unloader);
    unloader->AddCacheEntryUid(uid);
  }

  int64_t num_cores_with_compiled_programs = proto_key.size();
  if (proto_key.size() == 1) {
    // SPMD produces 1 program for all cores.
    num_cores_with_compiled_programs = metadata_.num_cores_per_replica();
  }
  if (status.ok() &&
      num_cores_with_compiled_programs +
              (may_modify_variables.size() * static_cast<int>(!use_mlir_)) !=
          ctx->num_outputs() - 1) {
    status = errors::Internal(
        "Number of cores with compiled programs (",
        num_cores_with_compiled_programs, ") + variable states (",
        may_modify_variables.size() * static_cast<int>(!use_mlir_),
        ") + compilation status output != number of compile op outputs (",
        ctx->num_outputs(), ")");
  }

  // TODO(jpienaar): status is not just due to the compilation. At this
  // point we should be failing the execution of the op in some cases and
  // returning a compilation error in others. For now, uniformly return an
  // error and fail in _TPUExecute if status failed here.

  // TODO(misard) the frame id will be wrong if this is ever called from
  // within a function. Consider whether to use the same hack as is
  // present in the rendezvous manager where the function call frame is
  // cast to a uint64, or do something better all around.
  std::string rendezvous_key_base = strings::StrCat(
      "host_compute_rendezvous:", ctx->op_kernel().name(), ":",
      ctx->frame_iter().frame_id, ":", ctx->frame_iter().iter_id, ":");

  // Return compilation status.
  {
    Tensor output(DT_STRING, TensorShape({}));
    tpu::CompilationResultProto proto;
    proto.set_status_code(status.code());
    if (!status.ok()) {
      proto.set_status_error_message(
          absl::StrCat("Compilation failure: ", status.error_message()));
    }
    if (return_hlo_protos_) {
      // Return the HloProtos as part of compilation status.
      for (const xla::HloProto* hlo_metadata : hlo_metadatas) {
        xla::HloProto* hlo_proto = proto.add_hlo_protos();
        *hlo_proto = *hlo_metadata;
      }
    }
    SerializeToTString(proto, &output.scalar<tstring>()());
    ctx->set_output(0, output);
    status.SetPayload(TpuCompileInterface::kTpuCompileErrorPayloadKey,
                      output.scalar<tstring>()());
  }
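
  // Below, outputs 1..num_cores_with_compiled_programs carry {program key,
  // rendezvous key base, sharding key} per core; for the non-MLIR path they
  // are followed by one may-modify-variables scalar per program.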
  if (status.ok()) {
    for (int i = 0; i < num_cores_with_compiled_programs; ++i) {
      Tensor output(DT_STRING, TensorShape({3}));
      if (proto_key.size() == 1) {
        output.vec<tstring>()(0) = proto_key[0];
      } else {
        output.vec<tstring>()(0) = proto_key[i];
      }
      output.vec<tstring>()(1) = rendezvous_key_base;
      if (sharding_key.empty()) {
        output.vec<tstring>()(2) = "";
      } else if (sharding_key.size() == 1) {
        output.vec<tstring>()(2) = sharding_key[0];
      } else {
        TF_RET_CHECK(sharding_key.size() == num_cores_with_compiled_programs);
        output.vec<tstring>()(2) = sharding_key[i];
      }
      ctx->set_output(i + 1, output);
    }
    if (!use_mlir_) {
      // If any of the programs may modify a variable, then return that all
      // do, as the only state currently tracked here is whether a model is
      // read-only or not.
      bool may_modify = false;
      for (bool m : may_modify_variables) {
        may_modify = may_modify || m;
      }
      for (int i = 0; i < may_modify_variables.size(); ++i) {
        Tensor output(DT_BOOL, TensorShape({}));
        output.scalar<bool>()() = may_modify;
        ctx->set_output(i + num_cores_with_compiled_programs + 1, output);
      }
    }
    VLOG(1) << "Cloud TPU: Compilation succeeded";
  } else {
    // Return error in the invalid case.
    for (int i = 0; i < num_computations_; ++i) {
      Tensor output(DT_STRING, TensorShape({3}));
      output.vec<tstring>()(0) = "<<NO PROGRAM AS COMPILATION FAILED>>";
      output.vec<tstring>()(1) = "<<NO RENDEZVOUS KEY AS COMPILATION FAILED>>";
      output.vec<tstring>()(2) = "<<NO SHARDING KEY AS COMPILATION FAILED>>";
      ctx->set_output(i + 1, output);
    }
    if (!use_mlir_) {
      // The TPUCompileMLIR op does not have a MayModifyVariable output.
      for (int i = 0; i < num_computations_; ++i) {
        Tensor output(false);
        ctx->set_output(i + num_computations_ + 1, output);
      }
    }
  }
  return status;
}

Status TpuCompileOpKernelCommon::RegisterXLAFingerprints(
    const std::vector<TensorShape>& arg_shapes,
    TpuProgramGroupInterface* tpu_program_group, uint64 fingerprint) {
  // TODO(chiachenc): Only one program is supported for now.
  if (tpu_program_group->program_count() != 1) {
    LOG(INFO) << "Found " << tpu_program_group->program_count()
              << " programs. Skip fingerprint registration.";
  } else {
    ResourceMgr* rm = GetTPUConfigResourceMgr();
    tpu::TpuFingerprintLookup* fingerprint_lookup;
    TF_RETURN_IF_ERROR(rm->LookupOrCreate<tpu::TpuFingerprintLookup>(
        rm->default_container(), tpu::kFingerprintLookupResourceName,
        &fingerprint_lookup, [&](tpu::TpuFingerprintLookup** new_lookup) {
          *new_lookup = tpu::TpuFingerprintLookup::Create();
          return Status::OK();
        }));
    uint64 tf_fingerprint =
        tpu::CreateFingerprintWithNameAndShapes(fingerprint, arg_shapes);
    std::string xla_fingerprint = tpu_program_group->fingerprint(0);
    VLOG(1) << "Registering TF fingerprint: " << tf_fingerprint
            << " with XLA fingerprint: " << xla_fingerprint;
    fingerprint_lookup->RegisterIntermediateAndValuePair(
        tf_fingerprint, std::move(xla_fingerprint));
  }
  return Status::OK();
}
}  // namespace tpu
}  // namespace tensorflow