/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_configuration_ops.h"

#include <cstdint>
#include <string>
#include <vector>

#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_factory.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
#include "tensorflow/core/tpu/kernels/tpu_pod_state.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"

namespace tensorflow {
namespace {
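// Looks up the TpuMeshStateInterface resource in rmgr's default container.
// Returns a FailedPrecondition error if the TPU system has not been
// initialized yet; on success the caller owns a reference to *state.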
Status GetTpuMeshStateInterface(const ResourceMgr* rmgr,
                                tpu::TpuMeshStateInterface** state) {
  if (!rmgr->Lookup(rmgr->default_container(),
                    tpu::kTpuMeshStateInterfaceResourceName, state)
           .ok()) {
    return errors::FailedPrecondition(
        "The TPU system has not been initialized.");
  }
  return Status::OK();
}

// Attempts to delete resource_name from resource_manager's default_container.
// Returns OK if the deletion succeeded or the resource was not found;
// otherwise returns the deletion error.
template <class ResourceT>
Status DeleteIfExists(ResourceMgr* resource_manager,
                      const char* resource_name) {
  VLOG(1) << "Removing resource " << resource_name << " if it exists";
  Status status = resource_manager->Delete<ResourceT>(
      resource_manager->default_container(), resource_name);
  if (status.ok()) {
    VLOG(1) << "Removed existing resource " << resource_name;
    return Status::OK();
  }
  if (status.code() == error::NOT_FOUND) {
    VLOG(1) << "No resource " << resource_name << " to remove";
    return Status::OK();
  }
  VLOG(1) << "Error removing resource " << resource_name << " : " << status;
  return status;
}
}  // namespace

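// Returns the TPU compilation cache from rmgr's default container, creating
// it via the registered factory function if it does not yet exist. The caller
// owns a reference to *compilation_cache.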
Status CreateTpuCompilationCache(
    ResourceMgr* rmgr, tpu::TpuCompilationCacheInterface** compilation_cache) {
  return rmgr->LookupOrCreate<tpu::TpuCompilationCacheInterface>(
      rmgr->default_container(), tpu::kCompilationCacheResourceName,
      compilation_cache, [&](tpu::TpuCompilationCacheInterface** new_cache) {
        *new_cache = tpu::GetCompilationCacheCreateFn()();
        return Status::OK();
      });
}

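// Reads one scalar int32 input per host, each holding that host's TPU chip
// count, and returns the per-host device counts. Fails if any host reports a
// different count than host 0.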
xla::StatusOr<std::vector<int32_t>> ConstructDevicesPerHost(
    OpKernelContext* ctx) {
  std::vector<int32_t> num_devices_per_host;
  int chips_per_host = -1;
  for (int i = 0; i < ctx->num_inputs(); ++i) {
    const Tensor& input_tensor = ctx->input(i);
    if (!TensorShapeUtils::IsScalar(input_tensor.shape())) {
      return errors::InvalidArgument("Input ", i,
                                     " should be a scalar but has ",
                                     input_tensor.dims(), " dimensions");
    }
    if (chips_per_host == -1) {
      chips_per_host = input_tensor.scalar<int32_t>()();
    } else {
      if (chips_per_host != input_tensor.scalar<int32_t>()()) {
        return errors::Internal("Host ", i, " has ",
                                input_tensor.scalar<int32_t>()(),
                                " TPU chips but host 0 has ", chips_per_host);
      }
    }
    num_devices_per_host.push_back(input_tensor.scalar<int32_t>()());
  }
  return num_devices_per_host;
}

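// Configures the distributed TPU system: validates the per-host chip counts,
// creates the compilation cache and pod state, resets the mesh state, and
// outputs a serialized host config string that the per-host initialization
// consumes.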
void ConfigureDistributedTpuOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "ConfigureDistributedTpuOp";
  XLA_SCOPED_LOGGING_TIMER("ConfigureDistributedTpuOp");

  xla::StatusOr<std::vector<int32_t>> num_devices_per_host =
      ConstructDevicesPerHost(ctx);
  OP_REQUIRES_OK(ctx, num_devices_per_host.status());
  ResourceMgr* rmgr = GetTPUConfigResourceMgr();

  // Create the subgraph compilation cache and put it in the local resource
  // manager.
  tpu::TpuCompilationCacheInterface* compilation_cache;
  OP_REQUIRES_OK(ctx, CreateTpuCompilationCache(rmgr, &compilation_cache));
  core::ScopedUnref compilation_cache_ref(compilation_cache);

  std::string host_config_output;
  OP_REQUIRES_OK(
      ctx, ConstructTpuPodState(rmgr, *num_devices_per_host, compilation_cache,
                                &host_config_output));

  OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
                          rmgr, tpu::kTpuMeshStateInterfaceResourceName));

  auto* tpu_mesh = tpu::TpuMeshStateInterface::Create();
  OP_REQUIRES_OK(
      ctx, rmgr->Create(rmgr->default_container(),
                        tpu::kTpuMeshStateInterfaceResourceName, tpu_mesh));

  Tensor* ctx_output;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));
  ctx_output->scalar<tstring>()() = std::move(host_config_output);

  VLOG(1) << "ConfigureDistributedTpuOp done";
}

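// Blocks until the distributed TPU system is ready. The op's vector inputs
// map each host's device ordinals to global TPU core ids; the mapping is
// forwarded to the C API, and the serialized TPU topology is returned as the
// op's scalar output.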
void WaitForDistributedTpuOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "WaitForDistributedTpuOp";
  XLA_SCOPED_LOGGING_TIMER("WaitForDistributedTpuOp");

  int64_t num_devices_per_host = -1;
  size_t num_hosts = ctx->num_inputs();

  for (int i = 0; i < ctx->num_inputs(); ++i) {
    const Tensor& host_ordinal_to_global_device_id_tensor = ctx->input(i);
    OP_REQUIRES(
        ctx, host_ordinal_to_global_device_id_tensor.dims() == 1,
        errors::InvalidArgument("Input ", i, " should be a vector but has ",
                                host_ordinal_to_global_device_id_tensor.dims(),
                                " dimensions"));
  }

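  // Copy each host's ordinal-to-global-device-id vector and collect raw
  // pointers to the rows, so the mapping can be passed to the C API as an
  // int32_t** array.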
  std::vector<std::vector<int32_t>> mapping;
  std::vector<int32_t*> mapping_arg;

  mapping.resize(ctx->num_inputs());

  for (int i = 0; i < ctx->num_inputs(); ++i) {
    const Tensor& host_ordinal_to_global_device_id_tensor = ctx->input(i);
    const auto host_ordinal_to_global_device_id =
        host_ordinal_to_global_device_id_tensor.flat<int>();
    if (num_devices_per_host == -1) {
      num_devices_per_host =
          host_ordinal_to_global_device_id_tensor.dim_size(0);
    } else {
      OP_REQUIRES(ctx,
                  num_devices_per_host ==
                      host_ordinal_to_global_device_id_tensor.dim_size(0),
                  errors::Internal(
                      "Host ", i, " has ",
                      host_ordinal_to_global_device_id_tensor.dim_size(0),
                      " TPU devices but host 0 has ", num_devices_per_host));
    }
    for (int j = 0; j < host_ordinal_to_global_device_id_tensor.dim_size(0);
         ++j) {
      int32_t global_device_id = host_ordinal_to_global_device_id(j);
      mapping[i].push_back(global_device_id);
    }
    mapping_arg.push_back(mapping[i].data());
  }

  tpu::TpuMeshStateInterface* mesh_state;
  auto* rmgr = GetTPUConfigResourceMgr();
  OP_REQUIRES_OK(ctx, GetTpuMeshStateInterface(rmgr, &mesh_state));
  core::ScopedUnref mesh_state_unref(mesh_state);

  // TODO(b/166858751): this code to check if `TpuPodState` exists is ported
  // from a legacy library that may have gone stale. A candidate for cleanup.
  TpuPodState* pod_state;
  OP_REQUIRES_OK(ctx, GetTPUPodState(rmgr, &pod_state));
  core::ScopedUnref pod_state_unref(pod_state);

  size_t tpu_topology_output_size;
  char* tpu_topology_output = nullptr;
  TF_Status* status = TF_NewStatus();
  auto cleanup = xla::MakeCleanup([&status, &tpu_topology_output]() {
    TF_DeleteStatus(status);
    tpu::OpsApiFn()->TpuConfigurationApi_FreeCharArrayFn(tpu_topology_output);
  });

  auto* mesh_common_state = mesh_state->mesh_common_state();

  WaitForDistributedTpuOp_DoWork_Params params;
  params.struct_size = WaitForDistributedTpuOp_DoWork_Params_SIZE;
  params.priv = nullptr;
  params.num_hosts = num_hosts;
  params.num_cores_per_host = num_devices_per_host;
  params.host_ordinal_to_global_core_id_map =
      const_cast<const int32_t**>(mapping_arg.data());
  params.tpu_mesh_common_state = mesh_common_state;
  params.tpu_topology_output_size = &tpu_topology_output_size;
  params.tpu_topology_output = &tpu_topology_output;
  params.status = status;

  tpu::OpsApiFn()->WaitForDistributedTpuOp_DoWorkFn(&params);

  OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));

  Tensor* ctx_output;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));
  ctx_output->scalar<tstring>()() =
      std::string(tpu_topology_output, tpu_topology_output_size);

  VLOG(1) << "WaitForDistributedTpuOp done";
}

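// Tears down the distributed TPU system by removing the mesh state, pod
// state, and compilation cache resources from the config resource manager.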
void ShutdownDistributedTpuOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "ShutdownDistributedTpuOp";
  XLA_SCOPED_LOGGING_TIMER("ShutdownDistributedTpuOp");

  auto* rmgr = GetTPUConfigResourceMgr();
  OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
                          rmgr, tpu::kTpuMeshStateInterfaceResourceName));

  OP_REQUIRES_OK(ctx,
                 DeleteIfExists<TpuPodState>(rmgr, kTpuPodStateResourceName));
  OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuCompilationCacheInterface>(
                          rmgr, tpu::kCompilationCacheResourceName));

  VLOG(1) << "ShutdownDistributedTpuOp done";
}

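// Initializes this host from the serialized host config produced by
// ConfigureDistributedTpuOp: installs a local or RPC-backed compilation cache
// lookup and outputs the global ids of the TPU devices attached to the host.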
void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "InitializeHostForDistributedTpuOp";
  XLA_SCOPED_LOGGING_TIMER("InitializeHostForDistributedTpuOp");

  auto* rmgr = GetTPUConfigResourceMgr();
  auto tpu_host_config = ctx->input(0).scalar<tstring>()();

  bool is_master_worker =
      tpu::OpsApiFn()->TpuConfigurationApi_HasTPUPodStateFn();
  if (!is_master_worker) {
    // Reset the mesh interface if we are not the master.
    OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
                            rmgr, tpu::kTpuMeshStateInterfaceResourceName));
    auto* mesh_state_interface = tpu::TpuMeshStateInterface::Create();
    OP_REQUIRES_OK(ctx, rmgr->Create(rmgr->default_container(),
                                     tpu::kTpuMeshStateInterfaceResourceName,
                                     mesh_state_interface));
  }

  VLOG(1) << "Removing existing proto compilation cache lookup if it exists";
  OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuCompilationCacheLookup>(
                          rmgr, tpu::kCompiledProtoCacheResourceName));

  if (enable_whole_mesh_compilations_) {
    // In whole-mesh compilation mode, create the compilation cache if it is
    // missing.
    tpu::TpuCompilationCacheInterface* compilation_cache;
    OP_REQUIRES_OK(ctx, CreateTpuCompilationCache(rmgr, &compilation_cache));
    compilation_cache->Unref();
  }

  tpu::TpuCompilationCacheInterface* local_compilation_cache;
  Status s = rmgr->Lookup(rmgr->default_container(),
                          tpu::kCompilationCacheResourceName,
                          &local_compilation_cache);
  if (!s.ok()) {
    local_compilation_cache = nullptr;
  }

  TF_Status* status = TF_NewStatus();
  size_t device_id_output_size;
  int32_t* device_id_output = nullptr;
  auto cleanup = xla::MakeCleanup([&status, &device_id_output]() {
    TF_DeleteStatus(status);
    tpu::OpsApiFn()->TpuConfigurationApi_FreeInt32ArrayFn(device_id_output);
  });

  InitializeHostForDistributedTpuOp_DoWork_Params params;
  params.struct_size = InitializeHostForDistributedTpuOp_DoWork_Params_SIZE;
  params.priv = nullptr;
  params.tpu_host_config_size = tpu_host_config.size();
  params.tpu_host_config = tpu_host_config.data();
  params.enable_whole_mesh_compilations = enable_whole_mesh_compilations_;
  params.is_master_worker = is_master_worker;
  params.core_id_output_size = &device_id_output_size;
  params.core_id_output = &device_id_output;
  params.status = status;

  tpu::OpsApiFn()->InitializeHostForDistributedTpuOp_DoWorkFn(&params);
  OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));

  if (local_compilation_cache != nullptr) {
    local_compilation_cache->Unref();

    tpu::TpuCompilationCacheLookup* proto_lookup;
    proto_lookup =
        new tpu::TpuCompilationCacheLocalLookup(local_compilation_cache);
    OP_REQUIRES_OK(
        ctx, rmgr->Create(rmgr->default_container(),
                          tpu::kCompiledProtoCacheResourceName, proto_lookup));
  } else {
    int64_t cache_size_bytes;
    tpu::OpsApiFn()->TpuConfigurationApi_RemoteCompilationCacheSizeInBytesFn(
        &cache_size_bytes);

    char* server_address_output = nullptr;
    auto cleanup_server_address = xla::MakeCleanup([&server_address_output]() {
      tpu::OpsApiFn()->TpuConfigurationApi_FreeCharArrayFn(
          server_address_output);
    });
    size_t server_address_output_size;

    TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params params;
    params.struct_size =
        TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params_SIZE;
    params.priv = nullptr;
    params.tpu_host_config_size = tpu_host_config.size();
    params.tpu_host_config = tpu_host_config.data();
    params.server_address_output_size = &server_address_output_size;
    params.server_address_output = &server_address_output;
    params.status = status;

    tpu::OpsApiFn()
        ->TpuConfigurationApi_CompilationCacheServerAddressFromConfigFn(
            &params);
    OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));

    std::string server_address(server_address_output,
                               server_address_output_size);
    tpu::TpuCompilationCacheLookup* proto_lookup =
        new tpu::TpuCompilationCacheRpcLookup(server_address, cache_size_bytes);
    OP_REQUIRES_OK(
        ctx, rmgr->Create(rmgr->default_container(),
                          tpu::kCompiledProtoCacheResourceName, proto_lookup));
  }

  Tensor* ctx_output;
  OP_REQUIRES_OK(
      ctx, ctx->allocate_output(
               0, TensorShape({static_cast<long long>(device_id_output_size)}),
               &ctx_output));

  for (size_t i = 0; i < device_id_output_size; ++i) {
    ctx_output->flat<int32_t>()(i) = device_id_output[i];
  }

  VLOG(1) << "InitializeHostForDistributedTpuOp done";
}

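// Forwards the serialized TPU topology from the op's input to the TPU runtime
// via the C API.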
void SetGlobalTPUArrayOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "SetGlobalTPUArrayOp";
  XLA_SCOPED_LOGGING_TIMER("SetGlobalTPUArrayOp");

  auto tpu_topology = ctx->input(0).scalar<tstring>()();
  TF_Status* status = TF_NewStatus();

  tpu::OpsApiFn()->SetGlobalTPUArrayOp_DoWorkFn(tpu_topology.size(),
                                                tpu_topology.data(), status);

  // Convert to a Status and free the TF_Status before OP_REQUIRES_OK can
  // return early, so the TF_Status does not leak on the error path.
  Status work_status = StatusFromTF_Status(status);
  TF_DeleteStatus(status);
  OP_REQUIRES_OK(ctx, work_status);

  VLOG(1) << "SetGlobalTPUArrayOp done";
}

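// Disconnects this host's TPU chips from the distributed TPU system and
// outputs the number of chips that were disconnected.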
void DisconnectDistributedTpuChipsOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "DisconnectDistributedTpuChipsOp";
  XLA_SCOPED_LOGGING_TIMER("DisconnectDistributedTpuChipsOp");

  TF_Status* status = TF_NewStatus();
  int32_t number_of_chips_output = 0;

  tpu::OpsApiFn()->DisconnectDistributedTpuChipsOp_DoWorkFn(
      &number_of_chips_output, status);

  Tensor* ctx_output;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));
  ctx_output->scalar<int32_t>()() = number_of_chips_output;

  // Free the TF_Status before OP_REQUIRES_OK can return early, so it does not
  // leak on the error path.
  Status work_status = StatusFromTF_Status(status);
  TF_DeleteStatus(status);
  OP_REQUIRES_OK(ctx, work_status);

  VLOG(1) << "DisconnectDistributedTpuChipsOp done";
}

// These ops execute on the TPU_SYSTEM device only.
REGISTER_KERNEL_BUILDER(Name("_ConfigureDistributedTPU")
                            .Device(DEVICE_TPU_SYSTEM)
                            .HostMemory("output"),
                        ConfigureDistributedTpuOp);
REGISTER_KERNEL_BUILDER(Name("_WaitForDistributedTPU")
                            .Device(DEVICE_TPU_SYSTEM)
                            .HostMemory("inputs")
                            .HostMemory("topology"),
                        WaitForDistributedTpuOp);
REGISTER_KERNEL_BUILDER(
    Name("_ShutdownDistributedTPU").Device(DEVICE_TPU_SYSTEM),
    ShutdownDistributedTpuOp);
REGISTER_KERNEL_BUILDER(Name("_InitializeHostForDistributedTPU")
                            .Device(DEVICE_TPU_SYSTEM)
                            .HostMemory("input")
                            .HostMemory("tpu_ids"),
                        InitializeHostForDistributedTpuOp);
REGISTER_KERNEL_BUILDER(
    Name("_SetGlobalTPUArray").Device(DEVICE_TPU_SYSTEM).HostMemory("topology"),
    SetGlobalTPUArrayOp);
REGISTER_KERNEL_BUILDER(Name("_DisconnectHostFromDistributedTPUSystem")
                            .Device(DEVICE_TPU_SYSTEM)
                            .HostMemory("number_of_tpu_chips"),
                        DisconnectDistributedTpuChipsOp);

}  // namespace tensorflow