// NOTE: code-viewer navigation chrome ("Home", "Line#", "Scopes", "Navigate",
// "Raw", "Download") removed from scraped source; not part of the program.
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Classes for compiling XLA computations and managing handles that refer to
17 // them.
18 
19 #include <string>
20 #include <vector>
21 
22 #include "tensorflow/compiler/tf2xla/shape_util.h"
23 #include "tensorflow/compiler/xla/client/client_library.h"
24 #include "tensorflow/compiler/xla/client/compile_only_client.h"
25 #include "tensorflow/compiler/xla/client/xla_computation.h"
26 #include "tensorflow/compiler/xla/debug_options_flags.h"
27 #include "tensorflow/compiler/xla/service/compiler.h"
28 #include "tensorflow/compiler/xla/service/dump.h"
29 #include "tensorflow/compiler/xla/service/hlo.pb.h"
30 #include "tensorflow/compiler/xla/status_macros.h"
31 #include "tensorflow/compiler/xla/statusor.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33 #include "tensorflow/compiler/xrt/xrt.pb.h"
34 #include "tensorflow/compiler/xrt/xrt_metrics.h"
35 #include "tensorflow/compiler/xrt/xrt_util.h"
36 #include "tensorflow/core/framework/op_kernel.h"
37 #include "tensorflow/core/framework/resource_mgr.h"
38 #include "tensorflow/core/framework/tensor.h"
39 #include "tensorflow/core/framework/tensor_shape.h"
40 #include "tensorflow/core/framework/types.pb.h"
41 #include "tensorflow/core/lib/core/refcount.h"
42 #include "tensorflow/core/lib/core/status.h"
43 #include "tensorflow/core/lib/gtl/cleanup.h"
44 #include "tensorflow/core/lib/monitoring/timed.h"
45 #include "tensorflow/core/lib/strings/proto_serialization.h"
46 #include "tensorflow/core/lib/strings/strcat.h"
47 #include "tensorflow/core/platform/casts.h"
48 #include "tensorflow/core/platform/types.h"
49 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
50 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
51 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
52 #include "tensorflow/core/tpu/kernels/tpu_compile_op.h"
53 #include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h"
54 #include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
55 #include "tensorflow/core/tpu/kernels/tpu_op_util.h"
56 #include "tensorflow/core/tpu/kernels/tpu_program_group.h"
57 #include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
58 #include "tensorflow/core/tpu/tpu_api.h"
59 #include "tensorflow/core/tpu/tpu_configuration.h"
60 #include "tensorflow/core/tpu/tpu_defs.h"
61 #include "tensorflow/stream_executor/stream_executor.h"
62 
63 namespace tensorflow {
64 
// Op kernel that compiles an xrt::XLAComputation proto into a TPU program,
// caching the result in the process-wide TPU compilation cache, and returns a
// handle (uid) that later XRT ops use to look the executable up.
class XRTCompileOp : public OpKernel {
 public:
  explicit XRTCompileOp(OpKernelConstruction* ctx);
  ~XRTCompileOp() override;
  // Not copyable: kernels are owned and invoked by the runtime.
  XRTCompileOp(const XRTCompileOp&) = delete;
  XRTCompileOp& operator=(const XRTCompileOp&) = delete;

  // Parses the computation from input 0, compiles it (or fetches it from the
  // cache) and emits the cache uid (output 0) plus the per-core program
  // shapes (output 1).
  void Compute(OpKernelContext* ctx) override;

 private:
  // Delegates the actual compilation to TpuProgramGroup::CompileAndBuild.
  Status Compile(const XLA_TpuMeshState* xla_mesh_state,
                 const xrt::XLAComputation& computation_proto,
                 tensorflow::tpu::TpuProgramGroupInterface* tpu_program_group);
};
79 
// Nothing to configure at construction time; all work happens in Compute().
XRTCompileOp::XRTCompileOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
81 
Compile(const XLA_TpuMeshState * xla_mesh_state,const xrt::XLAComputation & computation_proto,tensorflow::tpu::TpuProgramGroupInterface * tpu_program_group)82 Status XRTCompileOp::Compile(
83     const XLA_TpuMeshState* xla_mesh_state,
84     const xrt::XLAComputation& computation_proto,
85     tensorflow::tpu::TpuProgramGroupInterface* tpu_program_group) {
86   return tensorflow::tpu::TpuProgramGroup::CompileAndBuild(
87       computation_proto, xla_mesh_state, tpu_program_group);
88 }
89 
CompilationCacheKey(const xrt::XLAComputation & computation,tensorflow::tpu::TpuMeshStateInterface * mesh_state,int num_replicas,int num_cores_per_replica)90 tpu::TpuCompilationCacheKey CompilationCacheKey(
91     const xrt::XLAComputation& computation,
92     tensorflow::tpu::TpuMeshStateInterface* mesh_state, int num_replicas,
93     int num_cores_per_replica) {
94   string computation_serialized;
95   CHECK(SerializeToStringDeterministic(computation, &computation_serialized));
96   tpu::TPUCompileMetadataProto metadata;
97   metadata.set_num_replicas(num_replicas);
98   metadata.set_num_cores_per_replica(num_cores_per_replica);
99   const tpu::TpuCompilationCacheKey key = CreateCompilationCacheKey(
100       "compile", 0, tensorflow::Fingerprint64(computation_serialized), {}, {},
101       metadata, *mesh_state);
102   return key;
103 }
104 
ExitCountdown(Env * env,std::shared_ptr<std::atomic<bool>> done)105 void ExitCountdown(Env* env, std::shared_ptr<std::atomic<bool>> done) {
106   const int kSleepSeconds = 300;
107   LOG(INFO) << "TpuCompileOp was cancelled. Sleeping for " << kSleepSeconds
108             << " seconds to give time for TPUCompileOp to finished.";
109   env->SleepForMicroseconds(kSleepSeconds * 1000000);
110   if (done->load()) {
111     // If the TpuCompileOp has finished, then terminate peacefully.
112     return;
113   }
114 
115   LOG(ERROR) << "Aborting process due to cancelled TpuCompileOp. This "
116              << "termination is to ensure a consistent state.";
117   std::exit(42);
118 }
119 
// Compiles (or retrieves from the compilation cache) the TPU program for the
// serialized xrt::XLAComputation in input 0. Emits the cache uid as output 0
// and the per-core serialized xla::ProgramShapeProto vector as output 1.
// Registers a cancellation callback that aborts the process if the RPC is
// cancelled mid-compilation (see ExitCountdown).
void XRTCompileOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "XRTCompileOp::Compute";
  auto timed = monitoring::MakeTimed(xrt_metrics::GetCompileCell());

  // `done` flips to true in the cleanup below once this Compute finishes;
  // the cancellation path consults it before deciding to abort the process.
  std::shared_ptr<std::atomic<bool>> done(new std::atomic<bool>(false));
  CancellationToken token =
      ctx->cancellation_manager()->get_cancellation_token();
  const bool already_cancelled =
      !ctx->cancellation_manager()->RegisterCallback(token, [ctx, done]() {
        if (tpu::OpsApiFn()
                ->TpuCompile_ShouldTpuCompileOpIgnoreCancellationFn()) {
          return;
        }

        // Sleep and exit in another thread so the cancellation manager can
        // continue running callbacks.
        Env* env = ctx->env();
        env->SchedClosure([env, done]() { ExitCountdown(env, done); });
      });

  // If the RPC was cancelled before we registered the cancellation callback,
  // don't compile the TPU program.
  OP_REQUIRES(ctx, !already_cancelled,
              errors::Cancelled("RPC cancelled, not compiling TPU program"));

  // We only want to abort the process if a cancellation actually occurs during
  // compilation; we must deregister the callback in the success case. It
  // doesn't hurt to also deregister the callback in the failure case; the
  // CancellationManager ensures that already-registered callbacks will be run
  // once cancellation has started.
  auto cancellation_cleanup = xla::MakeCleanup([ctx, token, done] {
    ctx->cancellation_manager()->DeregisterCallback(token);
    done->store(true);
  });

  VLOG(1) << "Retrieving pod state";
  // Retrieve the topology from the resource manager
  ResourceMgr* rm = GetTPUConfigResourceMgr();
  tensorflow::tpu::TpuMeshStateInterface* mesh_state;
  OP_REQUIRES_OK(ctx,
                 rm->Lookup(rm->default_container(),
                            tensorflow::tpu::kTpuMeshStateInterfaceResourceName,
                            &mesh_state));
  core::ScopedUnref mesh_state_unref(mesh_state);

  // Input 0 is the serialized xrt::XLAComputation proto as a string scalar.
  const Tensor& computation_input = ctx->input(0);
  OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(computation_input.shape()),
              errors::Internal("computation input should be a string scalar"));

  xrt::XLAComputation computation_proto;
  OP_REQUIRES(
      ctx,
      computation_proto.ParseFromString(computation_input.scalar<tstring>()()),
      errors::InvalidArgument(
          "Unable to parse computation input to XLAComputation"));

  // Unset replica/core counts default to 1.
  const xrt::XLAComputationConfig& config = computation_proto.config();
  int num_replicas = config.num_replicas() ? config.num_replicas() : 1;
  CHECK_GT(num_replicas, 0);
  int num_cores_per_replica =
      config.num_cores_per_replica() ? config.num_cores_per_replica() : 1;

  const tpu::TpuCompilationCacheKey key = CompilationCacheKey(
      computation_proto, mesh_state, num_replicas, num_cores_per_replica);

  // Process-wide cache of Tpu executables.
  tpu::TpuCompilationCacheInterface* cache;
  OP_REQUIRES_OK(ctx, rm->Lookup<tpu::TpuCompilationCacheInterface>(
                          rm->default_container(),
                          tpu::kCompilationCacheResourceName, &cache));
  core::ScopedUnref cache_unref(cache);

  int64_t uid;
  std::vector<string> proto_key;
  std::vector<string> shard_key;
  std::vector<bool> may_modify_variables;
  absl::Span<const xla::HloProto* const> hlo_metadata;
  // Only invokes the compile lambda on a cache miss; on a hit, `uid` and
  // `hlo_metadata` are filled from the cached entry.
  OP_REQUIRES_OK(
      ctx, cache->CompileIfKeyAbsent(
               key, /*session_metadata=*/nullptr,
               /*per_step_ref_holder=*/nullptr, &uid, &proto_key, &shard_key,
               &may_modify_variables, &hlo_metadata,
               [&](tpu::TpuProgramGroupInterface* tpu_program_group) {
                 VLOG(1) << "Compiling TPU executable";
                 return Compile(mesh_state->data(), computation_proto,
                                tpu_program_group);
               }));

  // Output 0: the cache handle for the compiled program.
  Tensor output(DT_INT64, TensorShape({}));
  output.scalar<int64>()() = uid;
  ctx->set_output(0, output);

  // Output 1: one serialized ProgramShapeProto per core.
  Tensor program_shape_output(DT_STRING, TensorShape({num_cores_per_replica}));
  for (int64_t i = 0; i < num_cores_per_replica; ++i) {
    xla::ProgramShapeProto program_shape =
        hlo_metadata[i]->hlo_module().host_program_shape();
    program_shape_output.vec<tstring>()(i) = program_shape.SerializeAsString();
  }
  ctx->set_output(1, program_shape_output);
}
220 
// Defaulted out of line so it is emitted in this translation unit.
XRTCompileOp::~XRTCompileOp() = default;
222 
// Op kernel that releases references held by the process-wide TPU compilation
// cache for a batch of compilation handles previously returned by XRTCompile.
class XRTReleaseCompilationRefOp : public OpKernel {
 public:
  explicit XRTReleaseCompilationRefOp(OpKernelConstruction* ctx);
  ~XRTReleaseCompilationRefOp() override;
  // Not copyable: kernels are owned and invoked by the runtime.
  XRTReleaseCompilationRefOp(const XRTReleaseCompilationRefOp&) = delete;
  XRTReleaseCompilationRefOp& operator=(const XRTReleaseCompilationRefOp&) =
      delete;

  // Releases every handle in input 0 (an int64 tensor of cache uids).
  void Compute(OpKernelContext* ctx) override;
};
233 
// Stateless kernel: construction only forwards to the OpKernel base.
XRTReleaseCompilationRefOp::XRTReleaseCompilationRefOp(
    OpKernelConstruction* ctx)
    : OpKernel(ctx) {}

XRTReleaseCompilationRefOp::~XRTReleaseCompilationRefOp() = default;
239 
Compute(OpKernelContext * ctx)240 void XRTReleaseCompilationRefOp::Compute(OpKernelContext* ctx) {
241   VLOG(1) << "XRTReleaseCompilationRefOp::Compute";
242   auto timed = monitoring::MakeTimed(xrt_metrics::GetReleaseCompilationCell());
243   ResourceMgr* rm = GetTPUConfigResourceMgr();
244   OP_REQUIRES(ctx, rm != nullptr, errors::Internal("No resource manager."));
245 
246   // Process-wide cache of Tpu executables.
247   tpu::TpuCompilationCacheInterface* cache;
248   OP_REQUIRES_OK(ctx, rm->Lookup<tpu::TpuCompilationCacheInterface>(
249                           rm->default_container(),
250                           tpu::kCompilationCacheResourceName, &cache));
251   core::ScopedUnref cache_unref(cache);
252 
253   const Tensor& keys_tensor = ctx->input(0);
254   auto flat_keys = keys_tensor.flat<int64>();
255   for (int64_t i = 0; i < flat_keys.size(); ++i) {
256     int64_t key = flat_keys(i);
257     OP_REQUIRES_OK(ctx, cache->Release(key));
258     VLOG(2) << "Released computation handle " << key;
259   }
260 }
261 
// Register both kernels on TPU devices; the computation proto and handle
// tensors are pinned to host memory.
REGISTER_KERNEL_BUILDER(Name("XRTCompile")
                            .Device(DEVICE_TPU_NODE)
                            .HostMemory("computation")
                            .HostMemory("handle"),
                        XRTCompileOp);

REGISTER_KERNEL_BUILDER(Name("XRTReleaseCompilationHandle")
                            .Device(DEVICE_TPU_NODE)
                            .HostMemory("handle"),
                        XRTReleaseCompilationRefOp);
272 
273 }  // namespace tensorflow
274