1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Classes for compiling XLA computations and managing handles that refer to
17 // them.
18
19 #include <string>
20 #include <vector>
21
22 #include "tensorflow/compiler/tf2xla/shape_util.h"
23 #include "tensorflow/compiler/xla/client/client_library.h"
24 #include "tensorflow/compiler/xla/client/compile_only_client.h"
25 #include "tensorflow/compiler/xla/client/xla_computation.h"
26 #include "tensorflow/compiler/xla/debug_options_flags.h"
27 #include "tensorflow/compiler/xla/service/compiler.h"
28 #include "tensorflow/compiler/xla/service/dump.h"
29 #include "tensorflow/compiler/xla/service/hlo.pb.h"
30 #include "tensorflow/compiler/xla/status_macros.h"
31 #include "tensorflow/compiler/xla/statusor.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33 #include "tensorflow/compiler/xrt/xrt.pb.h"
34 #include "tensorflow/compiler/xrt/xrt_metrics.h"
35 #include "tensorflow/compiler/xrt/xrt_util.h"
36 #include "tensorflow/core/framework/op_kernel.h"
37 #include "tensorflow/core/framework/resource_mgr.h"
38 #include "tensorflow/core/framework/tensor.h"
39 #include "tensorflow/core/framework/tensor_shape.h"
40 #include "tensorflow/core/framework/types.pb.h"
41 #include "tensorflow/core/lib/core/refcount.h"
42 #include "tensorflow/core/lib/core/status.h"
43 #include "tensorflow/core/lib/gtl/cleanup.h"
44 #include "tensorflow/core/lib/monitoring/timed.h"
45 #include "tensorflow/core/lib/strings/proto_serialization.h"
46 #include "tensorflow/core/lib/strings/strcat.h"
47 #include "tensorflow/core/platform/casts.h"
48 #include "tensorflow/core/platform/types.h"
49 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
50 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
51 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
52 #include "tensorflow/core/tpu/kernels/tpu_compile_op.h"
53 #include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h"
54 #include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
55 #include "tensorflow/core/tpu/kernels/tpu_op_util.h"
56 #include "tensorflow/core/tpu/kernels/tpu_program_group.h"
57 #include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
58 #include "tensorflow/core/tpu/tpu_api.h"
59 #include "tensorflow/core/tpu/tpu_configuration.h"
60 #include "tensorflow/core/tpu/tpu_defs.h"
61 #include "tensorflow/stream_executor/stream_executor.h"
62
63 namespace tensorflow {
64
// Op kernel that compiles a serialized xrt::XLAComputation into a TPU
// executable via the process-wide TPU compilation cache and returns a handle
// (cache uid) plus the per-core program shapes.
class XRTCompileOp : public OpKernel {
 public:
  explicit XRTCompileOp(OpKernelConstruction* ctx);
  ~XRTCompileOp() override;
  // Not copyable/assignable: kernel instances are owned by the framework.
  XRTCompileOp(const XRTCompileOp&) = delete;
  XRTCompileOp& operator=(const XRTCompileOp&) = delete;

  void Compute(OpKernelContext* ctx) override;

 private:
  // Compiles `computation_proto` for the given TPU mesh, filling in
  // `tpu_program_group`; used as the compile-miss callback in Compute().
  Status Compile(const XLA_TpuMeshState* xla_mesh_state,
                 const xrt::XLAComputation& computation_proto,
                 tensorflow::tpu::TpuProgramGroupInterface* tpu_program_group);
};
79
XRTCompileOp(OpKernelConstruction * ctx)80 XRTCompileOp::XRTCompileOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
81
Compile(const XLA_TpuMeshState * xla_mesh_state,const xrt::XLAComputation & computation_proto,tensorflow::tpu::TpuProgramGroupInterface * tpu_program_group)82 Status XRTCompileOp::Compile(
83 const XLA_TpuMeshState* xla_mesh_state,
84 const xrt::XLAComputation& computation_proto,
85 tensorflow::tpu::TpuProgramGroupInterface* tpu_program_group) {
86 return tensorflow::tpu::TpuProgramGroup::CompileAndBuild(
87 computation_proto, xla_mesh_state, tpu_program_group);
88 }
89
CompilationCacheKey(const xrt::XLAComputation & computation,tensorflow::tpu::TpuMeshStateInterface * mesh_state,int num_replicas,int num_cores_per_replica)90 tpu::TpuCompilationCacheKey CompilationCacheKey(
91 const xrt::XLAComputation& computation,
92 tensorflow::tpu::TpuMeshStateInterface* mesh_state, int num_replicas,
93 int num_cores_per_replica) {
94 string computation_serialized;
95 CHECK(SerializeToStringDeterministic(computation, &computation_serialized));
96 tpu::TPUCompileMetadataProto metadata;
97 metadata.set_num_replicas(num_replicas);
98 metadata.set_num_cores_per_replica(num_cores_per_replica);
99 const tpu::TpuCompilationCacheKey key = CreateCompilationCacheKey(
100 "compile", 0, tensorflow::Fingerprint64(computation_serialized), {}, {},
101 metadata, *mesh_state);
102 return key;
103 }
104
ExitCountdown(Env * env,std::shared_ptr<std::atomic<bool>> done)105 void ExitCountdown(Env* env, std::shared_ptr<std::atomic<bool>> done) {
106 const int kSleepSeconds = 300;
107 LOG(INFO) << "TpuCompileOp was cancelled. Sleeping for " << kSleepSeconds
108 << " seconds to give time for TPUCompileOp to finished.";
109 env->SleepForMicroseconds(kSleepSeconds * 1000000);
110 if (done->load()) {
111 // If the TpuCompileOp has finished, then terminate peacefully.
112 return;
113 }
114
115 LOG(ERROR) << "Aborting process due to cancelled TpuCompileOp. This "
116 << "termination is to ensure a consistent state.";
117 std::exit(42);
118 }
119
// Parses the xrt::XLAComputation from input 0, compiles it (or looks it up)
// in the process-wide TPU compilation cache, and emits:
//   output 0: int64 scalar — the cache uid used as the compilation handle;
//   output 1: string vector of length num_cores_per_replica — serialized
//             per-core host program shapes.
void XRTCompileOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "XRTCompileOp::Compute";
  auto timed = monitoring::MakeTimed(xrt_metrics::GetCompileCell());

  // Cancellation protocol: a TPU compile cannot be safely interrupted, so if
  // the RPC is cancelled mid-compile a watchdog (ExitCountdown) is scheduled
  // that waits a grace period and then aborts the process. `done` is flipped
  // to true when compilation finishes so the watchdog can stand down.
  std::shared_ptr<std::atomic<bool>> done(new std::atomic<bool>(false));
  CancellationToken token =
      ctx->cancellation_manager()->get_cancellation_token();
  const bool already_cancelled =
      !ctx->cancellation_manager()->RegisterCallback(token, [ctx, done]() {
        // Deployment-level switch to opt out of the abort-on-cancel behavior.
        if (tpu::OpsApiFn()
                ->TpuCompile_ShouldTpuCompileOpIgnoreCancellationFn()) {
          return;
        }

        // Sleep and exit in another thread so the cancellation manager can
        // continue running callbacks.
        Env* env = ctx->env();
        env->SchedClosure([env, done]() { ExitCountdown(env, done); });
      });

  // If the RPC was cancelled before we registered the cancellation callback,
  // don't compile the TPU program.
  OP_REQUIRES(ctx, !already_cancelled,
              errors::Cancelled("RPC cancelled, not compiling TPU program"));

  // We only want to abort the process if a cancellation actually occurs during
  // compilation; we must deregister the callback in the success case. It
  // doesn't hurt to also deregister the callback in the failure case; the
  // CancellationManager ensures that already-registered callbacks will be run
  // once cancellation has started.
  auto cancellation_cleanup = xla::MakeCleanup([ctx, token, done] {
    ctx->cancellation_manager()->DeregisterCallback(token);
    done->store(true);
  });

  VLOG(1) << "Retrieving pod state";
  // Retrieve the topology from the resource manager
  ResourceMgr* rm = GetTPUConfigResourceMgr();
  tensorflow::tpu::TpuMeshStateInterface* mesh_state;
  OP_REQUIRES_OK(ctx,
                 rm->Lookup(rm->default_container(),
                            tensorflow::tpu::kTpuMeshStateInterfaceResourceName,
                            &mesh_state));
  core::ScopedUnref mesh_state_unref(mesh_state);

  // Input 0 must be a string scalar holding a serialized xrt::XLAComputation.
  const Tensor& computation_input = ctx->input(0);
  OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(computation_input.shape()),
              errors::Internal("computation input should be a string scalar"));

  xrt::XLAComputation computation_proto;
  OP_REQUIRES(
      ctx,
      computation_proto.ParseFromString(computation_input.scalar<tstring>()()),
      errors::InvalidArgument(
          "Unable to parse computation input to XLAComputation"));

  // Unset (zero) replication fields in the config default to 1.
  const xrt::XLAComputationConfig& config = computation_proto.config();
  int num_replicas = config.num_replicas() ? config.num_replicas() : 1;
  CHECK_GT(num_replicas, 0);
  int num_cores_per_replica =
      config.num_cores_per_replica() ? config.num_cores_per_replica() : 1;

  const tpu::TpuCompilationCacheKey key = CompilationCacheKey(
      computation_proto, mesh_state, num_replicas, num_cores_per_replica);

  // Process-wide cache of Tpu executables.
  tpu::TpuCompilationCacheInterface* cache;
  OP_REQUIRES_OK(ctx, rm->Lookup<tpu::TpuCompilationCacheInterface>(
                          rm->default_container(),
                          tpu::kCompilationCacheResourceName, &cache));
  core::ScopedUnref cache_unref(cache);

  // Compile only on a cache miss; in either case `uid` and `hlo_metadata`
  // describe the (possibly pre-existing) cache entry afterwards.
  int64_t uid;
  std::vector<string> proto_key;
  std::vector<string> shard_key;
  std::vector<bool> may_modify_variables;
  absl::Span<const xla::HloProto* const> hlo_metadata;
  OP_REQUIRES_OK(
      ctx, cache->CompileIfKeyAbsent(
               key, /*session_metadata=*/nullptr,
               /*per_step_ref_holder=*/nullptr, &uid, &proto_key, &shard_key,
               &may_modify_variables, &hlo_metadata,
               [&](tpu::TpuProgramGroupInterface* tpu_program_group) {
                 VLOG(1) << "Compiling TPU executable";
                 return Compile(mesh_state->data(), computation_proto,
                                tpu_program_group);
               }));

  // Output 0: the cache uid, returned to the client as the handle.
  Tensor output(DT_INT64, TensorShape({}));
  output.scalar<int64>()() = uid;
  ctx->set_output(0, output);

  // Output 1: serialized host program shape per core.
  // NOTE(review): assumes hlo_metadata has at least num_cores_per_replica
  // entries — confirm CompileIfKeyAbsent guarantees this.
  Tensor program_shape_output(DT_STRING, TensorShape({num_cores_per_replica}));
  for (int64_t i = 0; i < num_cores_per_replica; ++i) {
    xla::ProgramShapeProto program_shape =
        hlo_metadata[i]->hlo_module().host_program_shape();
    program_shape_output.vec<tstring>()(i) = program_shape.SerializeAsString();
  }
  ctx->set_output(1, program_shape_output);
}
220
221 XRTCompileOp::~XRTCompileOp() = default;
222
// Op kernel that releases compilation-cache references previously returned by
// XRTCompile, given a tensor of compilation handles (cache uids).
class XRTReleaseCompilationRefOp : public OpKernel {
 public:
  explicit XRTReleaseCompilationRefOp(OpKernelConstruction* ctx);
  ~XRTReleaseCompilationRefOp() override;
  // Not copyable/assignable: kernel instances are owned by the framework.
  XRTReleaseCompilationRefOp(const XRTReleaseCompilationRefOp&) = delete;
  XRTReleaseCompilationRefOp& operator=(const XRTReleaseCompilationRefOp&) =
      delete;

  void Compute(OpKernelContext* ctx) override;
};
233
// Nothing to configure at construction time; all work happens in Compute().
XRTReleaseCompilationRefOp::XRTReleaseCompilationRefOp(
    OpKernelConstruction* ctx)
    : OpKernel(ctx) {}
237
238 XRTReleaseCompilationRefOp::~XRTReleaseCompilationRefOp() = default;
239
Compute(OpKernelContext * ctx)240 void XRTReleaseCompilationRefOp::Compute(OpKernelContext* ctx) {
241 VLOG(1) << "XRTReleaseCompilationRefOp::Compute";
242 auto timed = monitoring::MakeTimed(xrt_metrics::GetReleaseCompilationCell());
243 ResourceMgr* rm = GetTPUConfigResourceMgr();
244 OP_REQUIRES(ctx, rm != nullptr, errors::Internal("No resource manager."));
245
246 // Process-wide cache of Tpu executables.
247 tpu::TpuCompilationCacheInterface* cache;
248 OP_REQUIRES_OK(ctx, rm->Lookup<tpu::TpuCompilationCacheInterface>(
249 rm->default_container(),
250 tpu::kCompilationCacheResourceName, &cache));
251 core::ScopedUnref cache_unref(cache);
252
253 const Tensor& keys_tensor = ctx->input(0);
254 auto flat_keys = keys_tensor.flat<int64>();
255 for (int64_t i = 0; i < flat_keys.size(); ++i) {
256 int64_t key = flat_keys(i);
257 OP_REQUIRES_OK(ctx, cache->Release(key));
258 VLOG(2) << "Released computation handle " << key;
259 }
260 }
261
// Register both kernels on the TPU device. The computation proto and the
// handle tensors are kept in host memory.
REGISTER_KERNEL_BUILDER(Name("XRTCompile")
                            .Device(DEVICE_TPU_NODE)
                            .HostMemory("computation")
                            .HostMemory("handle"),
                        XRTCompileOp);

REGISTER_KERNEL_BUILDER(Name("XRTReleaseCompilationHandle")
                            .Device(DEVICE_TPU_NODE)
                            .HostMemory("handle"),
                        XRTReleaseCompilationRefOp);
272
273 } // namespace tensorflow
274