/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_configuration_ops.h"

#include <cstdint>

#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_factory.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_fingerprint_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
#include "tensorflow/core/tpu/kernels/tpu_pod_state.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"

namespace tensorflow {
namespace {
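// Looks up the TpuMeshStateInterface resource in `rmgr`'s default container.
// Returns a FailedPrecondition error if the TPU system has not been
// initialized (i.e. the resource does not exist yet).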
Status GetTpuMeshStateInterface(const ResourceMgr* rmgr,
                                tpu::TpuMeshStateInterface** state) {
  if (!rmgr->Lookup(rmgr->default_container(),
                    tpu::kTpuMeshStateInterfaceResourceName, state)
           .ok()) {
    return errors::FailedPrecondition(
        "The TPU system has not been initialized.");
  }
  return Status::OK();
}

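// Creates the TpuFingerprintLookup resource in `rmgr`'s default container if
// it does not already exist. The resource is owned by the resource manager;
// the local reference is released before returning.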
Status CreateTpuFingerprintLookup(ResourceMgr* rmgr) {
  VLOG(1) << "CreateTpuFingerprintLookup";
  tpu::TpuFingerprintLookup* fingerprint_lookup;
  TF_RETURN_IF_ERROR(rmgr->LookupOrCreate<tpu::TpuFingerprintLookup>(
      rmgr->default_container(), tpu::kFingerprintLookupResourceName,
      &fingerprint_lookup, [&](tpu::TpuFingerprintLookup** new_lookup) {
        *new_lookup = tpu::TpuFingerprintLookup::Create();
        return Status::OK();
      }));

  core::ScopedUnref fingerprint_lookup_ref(fingerprint_lookup);
  return Status::OK();
}

// Attempts to delete the resource named `resource_name` from
// `resource_manager`'s default container. Returns OK if the deletion
// succeeded or if the resource was not found; otherwise returns the deletion
// error.
template <class ResourceT>
Status DeleteIfExists(ResourceMgr* resource_manager,
                      const char* resource_name) {
  VLOG(1) << "Removing resource " << resource_name << " if it exists";
  Status status = resource_manager->Delete<ResourceT>(
      resource_manager->default_container(), resource_name);
  if (status.ok()) {
    VLOG(1) << "Removed existing resource " << resource_name;
    return Status::OK();
  }
  if (status.code() == error::NOT_FOUND) {
    VLOG(1) << "No resource " << resource_name << " to remove";
    return Status::OK();
  }
  VLOG(1) << "Error removing resource " << resource_name << " : " << status;
  return status;
}
}  // namespace

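// Looks up the TPU compilation cache resource in `rmgr`'s default container,
// creating it on first use via the registered compilation cache factory.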
Status CreateTpuCompilationCache(
    ResourceMgr* rmgr, tpu::TpuCompilationCacheInterface** compilation_cache) {
  return rmgr->LookupOrCreate<tpu::TpuCompilationCacheInterface>(
      rmgr->default_container(), tpu::kCompilationCacheResourceName,
      compilation_cache, [&](tpu::TpuCompilationCacheInterface** new_cache) {
        *new_cache = tpu::GetCompilationCacheCreateFn()();
        return Status::OK();
      });
}

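// Reads one scalar input per host (the number of TPU chips on that host),
// checks that every host reports the same chip count as the first host, and
// returns the per-host device counts.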
xla::StatusOr<std::vector<int32_t>> ConstructDevicesPerHost(
    OpKernelContext* ctx) {
  std::vector<int32_t> num_devices_per_host;
  int chips_per_host = -1;
  for (int i = 0; i < ctx->num_inputs(); ++i) {
    const Tensor& input_tensor = ctx->input(i);
    if (!TensorShapeUtils::IsScalar(input_tensor.shape())) {
      return errors::InvalidArgument("Input ", i,
                                     " should be a scalar but has ",
                                     input_tensor.dims(), " dimensions");
    }
    if (chips_per_host == -1) {
      chips_per_host = input_tensor.scalar<int32_t>()();
    } else {
      if (chips_per_host != input_tensor.scalar<int32_t>()()) {
        return errors::Internal("Host ", i, " has ",
                                input_tensor.scalar<int32_t>()(),
                                " TPU chips but host 0 has ", chips_per_host);
      }
    }
    num_devices_per_host.push_back(input_tensor.scalar<int32_t>()());
  }
  return num_devices_per_host;
}

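// Sets up the distributed TPU system: creates the subgraph compilation cache
// and the TPU pod state, (re)creates the mesh state interface and the
// fingerprint lookup table in the TPU config resource manager, and outputs
// the host config string produced while constructing the pod state.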
void ConfigureDistributedTpuOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "ConfigureDistributedTpuOp";
  XLA_SCOPED_LOGGING_TIMER("ConfigureDistributedTpuOp");

  xla::StatusOr<std::vector<int32_t>> num_devices_per_host =
      ConstructDevicesPerHost(ctx);
  OP_REQUIRES_OK(ctx, num_devices_per_host.status());
  ResourceMgr* rmgr = GetTPUConfigResourceMgr();

  // Create the subgraph compilation cache and put it in the local resource
  // manager.
  tpu::TpuCompilationCacheInterface* compilation_cache;
  OP_REQUIRES_OK(ctx, CreateTpuCompilationCache(rmgr, &compilation_cache));
  core::ScopedUnref compilation_cache_ref(compilation_cache);

  std::string host_config_output;
  OP_REQUIRES_OK(
      ctx, ConstructTpuPodState(rmgr, *num_devices_per_host, compilation_cache,
                                &host_config_output));

  OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
                          rmgr, tpu::kTpuMeshStateInterfaceResourceName));

  auto* tpu_mesh = tpu::TpuMeshStateInterface::Create();
  OP_REQUIRES_OK(
      ctx, rmgr->Create(rmgr->default_container(),
                        tpu::kTpuMeshStateInterfaceResourceName, tpu_mesh));

  Tensor* ctx_output;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));
  ctx_output->scalar<tstring>()() = std::move(host_config_output);

  OP_REQUIRES_OK(ctx, CreateTpuFingerprintLookup(rmgr));
  VLOG(1) << "ConfigureDistributedTpuOp done";
}

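// Waits for the distributed TPU system to come up. Each input is a vector
// mapping a host's local device ordinals to global device ids; the mapping is
// forwarded to the TPU runtime via the C API, and the returned topology is
// emitted as the op's scalar string output.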
void WaitForDistributedTpuOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "WaitForDistributedTpuOp";
  XLA_SCOPED_LOGGING_TIMER("WaitForDistributedTpuOp");

  size_t num_devices_per_host = -1;
  size_t num_hosts = ctx->num_inputs();

  for (int i = 0; i < ctx->num_inputs(); ++i) {
    const Tensor& host_ordinal_to_global_device_id_tensor = ctx->input(i);
    OP_REQUIRES(
        ctx, host_ordinal_to_global_device_id_tensor.dims() == 1,
        errors::InvalidArgument("Input ", i, " should be a vector but has ",
                                host_ordinal_to_global_device_id_tensor.dims(),
                                " dimensions"));
  }

  std::vector<std::vector<int32_t>> mapping;
  std::vector<int32_t*> mapping_arg;

  mapping.resize(ctx->num_inputs());

  for (int i = 0; i < ctx->num_inputs(); ++i) {
    const Tensor& host_ordinal_to_global_device_id_tensor = ctx->input(i);
    const auto host_ordinal_to_global_device_id =
        host_ordinal_to_global_device_id_tensor.flat<int>();
    if (num_devices_per_host == -1) {
      num_devices_per_host =
          host_ordinal_to_global_device_id_tensor.dim_size(0);
    } else {
      OP_REQUIRES(ctx,
                  num_devices_per_host ==
                      host_ordinal_to_global_device_id_tensor.dim_size(0),
                  errors::Internal(
                      "Host ", i, " has ",
                      host_ordinal_to_global_device_id_tensor.dim_size(0),
                      " TPU devices but host 0 has ", num_devices_per_host));
    }
    for (int j = 0; j < host_ordinal_to_global_device_id_tensor.dim_size(0);
         ++j) {
      int32_t global_device_id = host_ordinal_to_global_device_id(j);
      mapping[i].push_back(global_device_id);
    }
    mapping_arg.push_back(mapping[i].data());
  }

  tpu::TpuMeshStateInterface* mesh_state;
  auto* rmgr = GetTPUConfigResourceMgr();
  OP_REQUIRES_OK(ctx, GetTpuMeshStateInterface(rmgr, &mesh_state));
  core::ScopedUnref mesh_state_unref(mesh_state);

  // TODO(b/166858751): this check for the existence of `TpuPodState` is
  // ported from a legacy library and may have become stale; it is a candidate
  // for cleanup.
  TpuPodState* pod_state;
  OP_REQUIRES_OK(ctx, GetTPUPodState(rmgr, &pod_state));
  core::ScopedUnref pod_state_unref(pod_state);

  size_t tpu_topology_output_size;
  char* tpu_topology_output = nullptr;
  TF_Status* status = TF_NewStatus();
  auto cleanup = xla::MakeCleanup([&status, &tpu_topology_output]() {
    TF_DeleteStatus(status);
    tpu::OpsApiFn()->TpuConfigurationApi_FreeCharArrayFn(tpu_topology_output);
  });

  auto* mesh_common_state = mesh_state->mesh_common_state();

  WaitForDistributedTpuOp_DoWork_Params params;
  params.struct_size = WaitForDistributedTpuOp_DoWork_Params_SIZE;
  params.priv = nullptr;
  params.num_hosts = num_hosts;
  params.num_cores_per_host = num_devices_per_host;
  params.host_ordinal_to_global_core_id_map =
      const_cast<const int32_t**>(mapping_arg.data());
  params.tpu_mesh_common_state = mesh_common_state;
  params.tpu_topology_output_size = &tpu_topology_output_size;
  params.tpu_topology_output = &tpu_topology_output;
  params.status = status;

  tpu::OpsApiFn()->WaitForDistributedTpuOp_DoWorkFn(&params);

  OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));

  Tensor* ctx_output;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));
  ctx_output->scalar<tstring>()() =
      std::string(tpu_topology_output, tpu_topology_output_size);

  VLOG(1) << "WaitForDistributedTpuOp done";
}

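// Tears down the distributed TPU system by deleting the mesh state, pod
// state, and compilation cache resources from the TPU config resource
// manager.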
void ShutdownDistributedTpuOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "ShutdownDistributedTpuOp";
  XLA_SCOPED_LOGGING_TIMER("ShutdownDistributedTpuOp");

  auto* rmgr = GetTPUConfigResourceMgr();
  OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
                          rmgr, tpu::kTpuMeshStateInterfaceResourceName));

  OP_REQUIRES_OK(ctx,
                 DeleteIfExists<TpuPodState>(rmgr, kTpuPodStateResourceName));
  OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuCompilationCacheInterface>(
                          rmgr, tpu::kCompilationCacheResourceName));

  VLOG(1) << "ShutdownDistributedTpuOp done";
}

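// Initializes this host's TPU devices from the serialized host config
// produced by ConfigureDistributedTpuOp. Also installs the compiled-proto
// cache lookup: a local lookup if this host holds a compilation cache,
// otherwise an RPC lookup pointed at the cache server address carried in the
// host config. Outputs the global TPU core ids assigned to this host.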
void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "InitializeHostForDistributedTpuOp";
  XLA_SCOPED_LOGGING_TIMER("InitializeHostForDistributedTpuOp");

  auto* rmgr = GetTPUConfigResourceMgr();
  auto tpu_host_config = ctx->input(0).scalar<tstring>()();

  bool is_master_worker =
      tpu::OpsApiFn()->TpuConfigurationApi_HasTPUPodStateFn();
  if (!is_master_worker) {
    // Reset the mesh interface if we are not the master.
    OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
                            rmgr, tpu::kTpuMeshStateInterfaceResourceName));
    auto* mesh_state_interface = tpu::TpuMeshStateInterface::Create();
    OP_REQUIRES_OK(ctx, rmgr->Create(rmgr->default_container(),
                                     tpu::kTpuMeshStateInterfaceResourceName,
                                     mesh_state_interface));
  }

  VLOG(1) << "Removing existing proto compilation cache lookup if it exists";
  OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuCompilationCacheLookup>(
                          rmgr, tpu::kCompiledProtoCacheResourceName));

  if (enable_whole_mesh_compilations_) {
    // In whole-mesh compilation mode, create the compilation cache if it is
    // missing.
    tpu::TpuCompilationCacheInterface* compilation_cache;
    OP_REQUIRES_OK(ctx, CreateTpuCompilationCache(rmgr, &compilation_cache));
    compilation_cache->Unref();
  }

  tpu::TpuCompilationCacheInterface* local_compilation_cache;
  Status s = rmgr->Lookup(rmgr->default_container(),
                          tpu::kCompilationCacheResourceName,
                          &local_compilation_cache);
  if (!s.ok()) {
    local_compilation_cache = nullptr;
  }

  TF_Status* status = TF_NewStatus();
  size_t device_id_output_size;
  int32_t* device_id_output = nullptr;
  auto cleanup = xla::MakeCleanup([&status, &device_id_output]() {
    TF_DeleteStatus(status);
    tpu::OpsApiFn()->TpuConfigurationApi_FreeInt32ArrayFn(device_id_output);
  });

  InitializeHostForDistributedTpuOp_DoWork_Params params;
  params.struct_size = InitializeHostForDistributedTpuOp_DoWork_Params_SIZE;
  params.priv = nullptr;
  params.tpu_host_config_size = tpu_host_config.size();
  params.tpu_host_config = tpu_host_config.data();
  params.enable_whole_mesh_compilations = enable_whole_mesh_compilations_;
  params.is_master_worker = is_master_worker;
  params.core_id_output_size = &device_id_output_size;
  params.core_id_output = &device_id_output;
  params.status = status;

  tpu::OpsApiFn()->InitializeHostForDistributedTpuOp_DoWorkFn(&params);
  OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));

  if (local_compilation_cache != nullptr) {
    local_compilation_cache->Unref();

    tpu::TpuCompilationCacheLookup* proto_lookup;
    proto_lookup =
        new tpu::TpuCompilationCacheLocalLookup(local_compilation_cache);
    OP_REQUIRES_OK(
        ctx, rmgr->Create(rmgr->default_container(),
                          tpu::kCompiledProtoCacheResourceName, proto_lookup));
  } else {
    int64_t cache_size_bytes;
    tpu::OpsApiFn()->TpuConfigurationApi_RemoteCompilationCacheSizeInBytesFn(
        &cache_size_bytes);

    char* server_address_output = nullptr;
    auto cleanup_server_address = xla::MakeCleanup([&server_address_output]() {
      tpu::OpsApiFn()->TpuConfigurationApi_FreeCharArrayFn(
          server_address_output);
    });
    size_t server_address_output_size;

    TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params params;
    params.struct_size =
        TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params_SIZE;
    params.priv = nullptr;
    params.tpu_host_config_size = tpu_host_config.size();
    params.tpu_host_config = tpu_host_config.data();
    params.server_address_output_size = &server_address_output_size;
    params.server_address_output = &server_address_output;
    params.status = status;

    tpu::OpsApiFn()
        ->TpuConfigurationApi_CompilationCacheServerAddressFromConfigFn(
            &params);
    OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));

    std::string server_address(server_address_output,
                               server_address_output_size);
    tpu::TpuCompilationCacheLookup* proto_lookup =
        new tpu::TpuCompilationCacheRpcLookup(server_address,
                                              cache_size_bytes);
    OP_REQUIRES_OK(
        ctx, rmgr->Create(rmgr->default_container(),
                          tpu::kCompiledProtoCacheResourceName, proto_lookup));
  }

  Tensor* ctx_output;
  OP_REQUIRES_OK(
      ctx, ctx->allocate_output(
               0, TensorShape({static_cast<long long>(device_id_output_size)}),
               &ctx_output));

  for (size_t i = 0; i < device_id_output_size; ++i) {
    ctx_output->flat<int32_t>()(i) = device_id_output[i];
  }

  VLOG(1) << "InitializeHostForDistributedTpuOp done";
}

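// Forwards the serialized TPU topology in the op's scalar string input to the
// TPU runtime via the C API.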
void SetGlobalTPUArrayOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "SetGlobalTPUArrayOp";
  XLA_SCOPED_LOGGING_TIMER("SetGlobalTPUArrayOp");

  auto tpu_topology = ctx->input(0).scalar<tstring>()();
  TF_Status* status = TF_NewStatus();

  tpu::OpsApiFn()->SetGlobalTPUArrayOp_DoWorkFn(tpu_topology.size(),
                                                tpu_topology.data(), status);

  OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
  TF_DeleteStatus(status);

  VLOG(1) << "SetGlobalTPUArrayOp done";
}

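// Disconnects this host from the distributed TPU system and outputs the
// number of TPU chips that were disconnected.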
void DisconnectDistributedTpuChipsOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "DisconnectDistributedTpuChipsOp";
  XLA_SCOPED_LOGGING_TIMER("DisconnectDistributedTpuChipsOp");

  TF_Status* status = TF_NewStatus();
  int32_t number_of_chips_output = 0;

  tpu::OpsApiFn()->DisconnectDistributedTpuChipsOp_DoWorkFn(
      &number_of_chips_output, status);

  Tensor* ctx_output;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));
  ctx_output->scalar<int32_t>()() = number_of_chips_output;

  OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
  TF_DeleteStatus(status);

  VLOG(1) << "DisconnectDistributedTpuChipsOp done";
}

// These ops execute on the TPU_SYSTEM device only.
REGISTER_KERNEL_BUILDER(Name("_ConfigureDistributedTPU")
                            .Device(DEVICE_TPU_SYSTEM)
                            .HostMemory("output"),
                        ConfigureDistributedTpuOp);
REGISTER_KERNEL_BUILDER(Name("_WaitForDistributedTPU")
                            .Device(DEVICE_TPU_SYSTEM)
                            .HostMemory("inputs")
                            .HostMemory("topology"),
                        WaitForDistributedTpuOp);
REGISTER_KERNEL_BUILDER(
    Name("_ShutdownDistributedTPU").Device(DEVICE_TPU_SYSTEM),
    ShutdownDistributedTpuOp);
REGISTER_KERNEL_BUILDER(Name("_InitializeHostForDistributedTPU")
                            .Device(DEVICE_TPU_SYSTEM)
                            .HostMemory("input")
                            .HostMemory("tpu_ids"),
                        InitializeHostForDistributedTpuOp);
REGISTER_KERNEL_BUILDER(
    Name("_SetGlobalTPUArray").Device(DEVICE_TPU_SYSTEM).HostMemory("topology"),
    SetGlobalTPUArrayOp);
REGISTER_KERNEL_BUILDER(Name("_DisconnectHostFromDistributedTPUSystem")
                            .Device(DEVICE_TPU_SYSTEM)
                            .HostMemory("number_of_tpu_chips"),
                        DisconnectDistributedTpuChipsOp);

}  // namespace tensorflow