• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/tpu/kernels/tpu_configuration_ops.h"
16 
17 #include <cstdint>
18 
19 #include "tensorflow/c/tf_status.h"
20 #include "tensorflow/c/tf_status_helper.h"
21 #include "tensorflow/compiler/xla/util.h"
22 #include "tensorflow/core/framework/op_kernel.h"
23 #include "tensorflow/core/framework/tensor.pb.h"
24 #include "tensorflow/core/framework/tensor_shape.h"
25 #include "tensorflow/core/platform/refcount.h"
26 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_factory.h"
27 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
28 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h"
29 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
30 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_lookup.h"
31 #include "tensorflow/core/tpu/kernels/tpu_fingerprint_lookup.h"
32 #include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h"
33 #include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
34 #include "tensorflow/core/tpu/kernels/tpu_pod_state.h"
35 #include "tensorflow/core/tpu/tpu_api.h"
36 #include "tensorflow/core/tpu/tpu_configuration.h"
37 #include "tensorflow/core/tpu/tpu_defs.h"
38 #include "tensorflow/core/tpu/tpu_ops_c_api.h"
39 #include "tensorflow/stream_executor/tpu/proto_helper.h"
40 
41 namespace tensorflow {
42 namespace {
GetTpuMeshStateInterface(const ResourceMgr * rmgr,tpu::TpuMeshStateInterface ** state)43 Status GetTpuMeshStateInterface(const ResourceMgr* rmgr,
44                                 tpu::TpuMeshStateInterface** state) {
45   if (!rmgr->Lookup(rmgr->default_container(),
46                     tpu::kTpuMeshStateInterfaceResourceName, state)
47            .ok()) {
48     return errors::FailedPrecondition(
49         "The TPU system has not been initialized.");
50   }
51   return Status::OK();
52 }
53 
CreateTpuFingerprintLookup(ResourceMgr * rmgr)54 Status CreateTpuFingerprintLookup(ResourceMgr* rmgr) {
55   VLOG(1) << "CreateTpuFingerprintLookup";
56   tpu::TpuFingerprintLookup* fingerprint_lookup;
57   TF_RETURN_IF_ERROR(rmgr->LookupOrCreate<tpu::TpuFingerprintLookup>(
58       rmgr->default_container(), tpu::kFingerprintLookupResourceName,
59       &fingerprint_lookup, [&](tpu::TpuFingerprintLookup** new_lookup) {
60         *new_lookup = tpu::TpuFingerprintLookup::Create();
61         return Status::OK();
62       }));
63 
64   core::ScopedUnref fingerprint_lookup_ref(fingerprint_lookup);
65   return Status::OK();
66 }
67 
68 // Attempt to delete resource_name from resource_manager's default_container.
69 // Returns OK if the deletion succeeded, or if the resource was not found. Else
70 // return the deletion error.
71 template <class ResourceT>
DeleteIfExists(ResourceMgr * resource_manager,const char * resource_name)72 Status DeleteIfExists(ResourceMgr* resource_manager,
73                       const char* resource_name) {
74   VLOG(1) << "Removing resource " << resource_name << " if it exists";
75   Status status = resource_manager->Delete<ResourceT>(
76       resource_manager->default_container(), resource_name);
77   if (status.ok()) {
78     VLOG(1) << "Removed existing resource " << resource_name;
79     return Status::OK();
80   }
81   if (status.code() == error::NOT_FOUND) {
82     VLOG(1) << "No resource " << resource_name << " to remove";
83     return Status::OK();
84   }
85   VLOG(1) << "Error removing resource " << resource_name << " : " << status;
86   return status;
87 }
88 }  // namespace
89 
CreateTpuCompilationCache(ResourceMgr * rmgr,tpu::TpuCompilationCacheInterface ** compilation_cache)90 Status CreateTpuCompilationCache(
91     ResourceMgr* rmgr, tpu::TpuCompilationCacheInterface** compilation_cache) {
92   return rmgr->LookupOrCreate<tpu::TpuCompilationCacheInterface>(
93       rmgr->default_container(), tpu::kCompilationCacheResourceName,
94       compilation_cache, [&](tpu::TpuCompilationCacheInterface** new_cache) {
95         *new_cache = tpu::GetCompilationCacheCreateFn()();
96         return Status::OK();
97       });
98 }
99 
ConstructDevicesPerHost(OpKernelContext * ctx)100 xla::StatusOr<std::vector<int32_t>> ConstructDevicesPerHost(
101     OpKernelContext* ctx) {
102   std::vector<int32_t> num_devices_per_host;
103   int chips_per_host = -1;
104   for (int i = 0; i < ctx->num_inputs(); ++i) {
105     const Tensor& input_tensor = ctx->input(i);
106     if (!TensorShapeUtils::IsScalar(input_tensor.shape())) {
107       return errors::InvalidArgument("Input ", i,
108                                      " should be a scalar but has ",
109                                      input_tensor.dims(), " dimensions");
110     }
111     if (chips_per_host == -1) {
112       chips_per_host = input_tensor.scalar<int32_t>()();
113     } else {
114       if (chips_per_host != input_tensor.scalar<int32>()()) {
115         return errors::Internal("Host ", i, " has ",
116                                 input_tensor.scalar<int32>()(),
117                                 " TPU chips but host 0 has ", chips_per_host);
118       }
119     }
120     num_devices_per_host.push_back(input_tensor.scalar<int32_t>()());
121   }
122   return num_devices_per_host;
123 }
124 
// Configures the distributed TPU system. Validates the per-host device
// counts from the op's inputs, then (re)creates the process-wide resources
// used by compilation — the compilation cache, the TPU pod state, the mesh
// state interface, and the fingerprint lookup — in the TPU config resource
// manager. Output 0 is the serialized host config string produced by
// ConstructTpuPodState.
void ConfigureDistributedTpuOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "ConfigureDistributedTpuOp";
  XLA_SCOPED_LOGGING_TIMER("ConfigureDistributedTpuOp");

  // One scalar int32 input per host; all hosts must report the same TPU chip
  // count (enforced by ConstructDevicesPerHost).
  xla::StatusOr<std::vector<int32_t>> num_devices_per_host =
      ConstructDevicesPerHost(ctx);
  OP_REQUIRES_OK(ctx, num_devices_per_host.status());
  ResourceMgr* rmgr = GetTPUConfigResourceMgr();

  // Create the subgraph compilation cache and put it in the local resource
  // manager.
  tpu::TpuCompilationCacheInterface* compilation_cache;
  OP_REQUIRES_OK(ctx, CreateTpuCompilationCache(rmgr, &compilation_cache));
  core::ScopedUnref compilation_cache_ref(compilation_cache);

  std::string host_config_output;
  OP_REQUIRES_OK(
      ctx, ConstructTpuPodState(rmgr, *num_devices_per_host, compilation_cache,
                                &host_config_output));

  // Replace any mesh state left over from a previous configuration before
  // creating a fresh one.
  OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
                          rmgr, tpu::kTpuMeshStateInterfaceResourceName));

  auto* tpu_mesh = tpu::TpuMeshStateInterface::Create();
  OP_REQUIRES_OK(
      ctx, rmgr->Create(rmgr->default_container(),
                        tpu::kTpuMeshStateInterfaceResourceName, tpu_mesh));

  // Output 0: the serialized host config string.
  Tensor* ctx_output;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));
  ctx_output->scalar<tstring>()() = std::move(host_config_output);

  OP_REQUIRES_OK(ctx, CreateTpuFingerprintLookup(rmgr));
  VLOG(1) << "ConfigureDistributedTpuOp done";
}
160 
// Waits for the distributed TPU system to come up. Input i is a 1-D int32
// tensor mapping host i's device ordinals to global device ids; all hosts
// must have the same device count. The mapping is passed to the TPU C API,
// and the serialized topology it returns is emitted as scalar string
// output 0.
void WaitForDistributedTpuOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "WaitForDistributedTpuOp";
  XLA_SCOPED_LOGGING_TIMER("WaitForDistributedTpuOp");

  // -1 wraps to SIZE_MAX here and serves as the "not yet set" sentinel for
  // the comparison in the loop below.
  size_t num_devices_per_host = -1;
  size_t num_hosts = ctx->num_inputs();

  // First pass: shape validation only.
  for (int i = 0; i < ctx->num_inputs(); ++i) {
    const Tensor& host_ordinal_to_global_device_id_tensor = ctx->input(i);
    OP_REQUIRES(
        ctx, host_ordinal_to_global_device_id_tensor.dims() == 1,
        errors::InvalidArgument("Input ", i, " should be a vector but has ",
                                host_ordinal_to_global_device_id_tensor.dims(),
                                " dimensions"));
  }

  // mapping[i] owns host i's row of global device ids; mapping_arg holds the
  // raw row pointers in the int32** form the C API expects. `mapping` must
  // stay alive (and un-reallocated) until the DoWork call below returns.
  std::vector<std::vector<int32_t>> mapping;
  std::vector<int32_t*> mapping_arg;

  mapping.resize(ctx->num_inputs());

  for (int i = 0; i < ctx->num_inputs(); ++i) {
    const Tensor& host_ordinal_to_global_device_id_tensor = ctx->input(i);
    const auto host_ordinal_to_global_device_id =
        host_ordinal_to_global_device_id_tensor.flat<int>();
    if (num_devices_per_host == -1) {
      // Host 0 fixes the expected per-host device count.
      num_devices_per_host =
          host_ordinal_to_global_device_id_tensor.dim_size(0);
    } else {
      OP_REQUIRES(ctx,
                  num_devices_per_host ==
                      host_ordinal_to_global_device_id_tensor.dim_size(0),
                  errors::Internal(
                      "Host ", i, " has ",
                      host_ordinal_to_global_device_id_tensor.dim_size(0),
                      " TPU devices but host 0 has ", num_devices_per_host));
    }
    for (int j = 0; j < host_ordinal_to_global_device_id_tensor.dim_size(0);
         ++j) {
      int32_t global_device_id = host_ordinal_to_global_device_id(j);
      mapping[i].push_back(global_device_id);
    }
    mapping_arg.push_back(mapping[i].data());
  }

  tpu::TpuMeshStateInterface* mesh_state;
  auto* rmgr = GetTPUConfigResourceMgr();
  OP_REQUIRES_OK(ctx, GetTpuMeshStateInterface(rmgr, &mesh_state));
  core::ScopedUnref mesh_state_unref(mesh_state);

  // TODO(b/166858751): this code to check if `TpuPodState` exists is ported
  // from a legacy library that may have staled. A candidate for cleanup.
  TpuPodState* pod_state;
  OP_REQUIRES_OK(ctx, GetTPUPodState(rmgr, &pod_state));
  core::ScopedUnref pod_state_unref(pod_state);

  size_t tpu_topology_output_size;
  char* tpu_topology_output = nullptr;
  TF_Status* status = TF_NewStatus();
  // Frees `status` and the C-allocated topology buffer on every exit path,
  // including early returns from the OP_REQUIRES_OK macros below.
  auto cleanup = xla::MakeCleanup([&status, &tpu_topology_output]() {
    TF_DeleteStatus(status);
    tpu::OpsApiFn()->TpuConfigurationApi_FreeCharArrayFn(tpu_topology_output);
  });

  auto* mesh_common_state = mesh_state->mesh_common_state();

  // Populate the C-API parameter struct; struct_size lets the API
  // version-check the struct layout across the ABI boundary.
  WaitForDistributedTpuOp_DoWork_Params params;
  params.struct_size = WaitForDistributedTpuOp_DoWork_Params_SIZE;
  params.priv = nullptr;
  params.num_hosts = num_hosts;
  params.num_cores_per_host = num_devices_per_host;
  params.host_ordinal_to_global_core_id_map =
      const_cast<const int32_t**>(mapping_arg.data());
  params.tpu_mesh_common_state = mesh_common_state;
  params.tpu_topology_output_size = &tpu_topology_output_size;
  params.tpu_topology_output = &tpu_topology_output;
  params.status = status;

  tpu::OpsApiFn()->WaitForDistributedTpuOp_DoWorkFn(&params);

  OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));

  // Output 0: the serialized topology returned by the C API.
  Tensor* ctx_output;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));
  ctx_output->scalar<tstring>()() =
      std::string(tpu_topology_output, tpu_topology_output_size);

  VLOG(1) << "WaitForDistributedTpuOp done";
}
250 
Compute(OpKernelContext * ctx)251 void ShutdownDistributedTpuOp::Compute(OpKernelContext* ctx) {
252   VLOG(1) << "ShutdownDistributedTpuOp";
253   XLA_SCOPED_LOGGING_TIMER("ShutdownDistributedTpuOp");
254 
255   auto* rmgr = GetTPUConfigResourceMgr();
256   OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
257                           rmgr, tpu::kTpuMeshStateInterfaceResourceName));
258 
259   OP_REQUIRES_OK(ctx,
260                  DeleteIfExists<TpuPodState>(rmgr, kTpuPodStateResourceName));
261   OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuCompilationCacheInterface>(
262                           rmgr, tpu::kCompilationCacheResourceName));
263 
264   VLOG(1) << "ShutdownDistributedTpuOp done";
265 }
266 
// Performs per-host TPU initialization. Input 0 is the serialized host
// config string. Installs a compiled-proto cache lookup: backed by the local
// compilation cache when one exists in the resource manager, otherwise by an
// RPC lookup against the server address extracted from the host config.
// Output 0 is the int32 vector of global core ids returned by the C API.
void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "InitializeHostForDistributedTpuOp";
  XLA_SCOPED_LOGGING_TIMER("InitializeHostForDistributedTpuOp");

  auto* rmgr = GetTPUConfigResourceMgr();
  auto tpu_host_config = ctx->input(0).scalar<tstring>()();

  // The master worker is the process that holds TPU pod state (per the C
  // API's HasTPUPodState query).
  bool is_master_worker =
      tpu::OpsApiFn()->TpuConfigurationApi_HasTPUPodStateFn();
  if (!is_master_worker) {
    // Reset the mesh interface if we are not the master.
    OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
                            rmgr, tpu::kTpuMeshStateInterfaceResourceName));
    auto* mesh_state_interface = tpu::TpuMeshStateInterface::Create();
    OP_REQUIRES_OK(ctx, rmgr->Create(rmgr->default_container(),
                                     tpu::kTpuMeshStateInterfaceResourceName,
                                     mesh_state_interface));
  }

  VLOG(1) << "Removing existing proto compilation cache lookup if it exists";
  OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuCompilationCacheLookup>(
                          rmgr, tpu::kCompiledProtoCacheResourceName));

  if (enable_whole_mesh_compilations_) {
    // If this is a whole mesh compilation mode, create the compilation cache,
    // if missing.
    tpu::TpuCompilationCacheInterface* compilation_cache;
    OP_REQUIRES_OK(ctx, CreateTpuCompilationCache(rmgr, &compilation_cache));
    compilation_cache->Unref();
  }

  // A failed lookup just means no local cache; fall through with nullptr and
  // use the RPC-backed lookup below.
  tpu::TpuCompilationCacheInterface* local_compilation_cache;
  Status s = rmgr->Lookup(rmgr->default_container(),
                          tpu::kCompilationCacheResourceName,
                          &local_compilation_cache);
  if (!s.ok()) {
    local_compilation_cache = nullptr;
  }

  TF_Status* status = TF_NewStatus();
  size_t device_id_output_size;
  int32_t* device_id_output = nullptr;
  // Frees `status` and the C-allocated id array on every exit path,
  // including early returns from the OP_REQUIRES_OK macros below.
  auto cleanup = xla::MakeCleanup([&status, &device_id_output]() {
    TF_DeleteStatus(status);
    tpu::OpsApiFn()->TpuConfigurationApi_FreeInt32ArrayFn(device_id_output);
  });

  // Populate the C-API parameter struct; struct_size lets the API
  // version-check the struct layout across the ABI boundary.
  InitializeHostForDistributedTpuOp_DoWork_Params params;
  params.struct_size = InitializeHostForDistributedTpuOp_DoWork_Params_SIZE;
  params.priv = nullptr;
  params.tpu_host_config_size = tpu_host_config.size();
  params.tpu_host_config = tpu_host_config.data();
  params.enable_whole_mesh_compilations = enable_whole_mesh_compilations_;
  params.is_master_worker = is_master_worker;
  params.core_id_output_size = &device_id_output_size;
  params.core_id_output = &device_id_output;
  params.status = status;

  tpu::OpsApiFn()->InitializeHostForDistributedTpuOp_DoWorkFn(&params);
  OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));

  if (local_compilation_cache != nullptr) {
    // Drops the reference taken by Lookup above; presumably the local lookup
    // object holds its own reference to the cache — verify against
    // TpuCompilationCacheLocalLookup's constructor.
    local_compilation_cache->Unref();

    tpu::TpuCompilationCacheLookup* proto_lookup;
    proto_lookup =
        new tpu::TpuCompilationCacheLocalLookup(local_compilation_cache);
    // `rmgr` takes ownership of `proto_lookup`.
    OP_REQUIRES_OK(
        ctx, rmgr->Create(rmgr->default_container(),
                          tpu::kCompiledProtoCacheResourceName, proto_lookup));
  } else {
    // No local cache: resolve compiled protos over RPC. The cache size and
    // server address both come from the C API / host config.
    int64_t cache_size_bytes;
    tpu::OpsApiFn()->TpuConfigurationApi_RemoteCompilationCacheSizeInBytesFn(
        &cache_size_bytes);

    char* server_address_output = nullptr;
    auto cleanup_server_address = xla::MakeCleanup([&server_address_output]() {
      tpu::OpsApiFn()->TpuConfigurationApi_FreeCharArrayFn(
          server_address_output);
    });
    size_t server_address_output_size;

    // NOTE: this `params` intentionally shadows the outer one; it is a
    // different C-API struct type.
    TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params params;
    params.struct_size =
        TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params_SIZE;
    params.priv = nullptr;
    params.tpu_host_config_size = tpu_host_config.size();
    params.tpu_host_config = tpu_host_config.data();
    params.server_address_output_size = &server_address_output_size;
    params.server_address_output = &server_address_output;
    params.status = status;

    tpu::OpsApiFn()
        ->TpuConfigurationApi_CompilationCacheServerAddressFromConfigFn(
            &params);
    OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));

    std::string server_address(server_address_output,
                               server_address_output_size);
    tpu::TpuCompilationCacheLookup* proto_lookup =
        new tpu::TpuCompilationCacheRpcLookup(server_address, cache_size_bytes);
    OP_REQUIRES_OK(
        ctx, rmgr->Create(rmgr->default_container(),
                          tpu::kCompiledProtoCacheResourceName, proto_lookup));
  }

  // Output 0: one int32 global core id per entry returned by the C API.
  Tensor* ctx_output;
  OP_REQUIRES_OK(
      ctx, ctx->allocate_output(
               0, TensorShape({static_cast<long long>(device_id_output_size)}),
               &ctx_output));

  for (size_t i = 0; i < device_id_output_size; ++i) {
    ctx_output->flat<int32>()(i) = device_id_output[i];
  }

  VLOG(1) << "InitializeHostForDistributedTpuOp done";
}
385 
Compute(OpKernelContext * ctx)386 void SetGlobalTPUArrayOp::Compute(OpKernelContext* ctx) {
387   VLOG(1) << "SetGlobalTPUArrayOp";
388   XLA_SCOPED_LOGGING_TIMER("SetGlobalTPUArrayOp");
389 
390   auto tpu_topology = ctx->input(0).scalar<tstring>()();
391   TF_Status* status = TF_NewStatus();
392 
393   tpu::OpsApiFn()->SetGlobalTPUArrayOp_DoWorkFn(tpu_topology.size(),
394                                                 tpu_topology.data(), status);
395 
396   OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
397   TF_DeleteStatus(status);
398 
399   VLOG(1) << "SetGlobalTPUArrayOp done";
400 }
401 
Compute(OpKernelContext * ctx)402 void DisconnectDistributedTpuChipsOp::Compute(OpKernelContext* ctx) {
403   VLOG(1) << "DisconnectDistributedTpuChipsOp";
404   XLA_SCOPED_LOGGING_TIMER("DisconnectDistributedTpuChipsOp");
405 
406   TF_Status* status = TF_NewStatus();
407   int32_t number_of_chips_output = 0;
408 
409   tpu::OpsApiFn()->DisconnectDistributedTpuChipsOp_DoWorkFn(
410       &number_of_chips_output, status);
411 
412   Tensor* ctx_output;
413   OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));
414   ctx_output->scalar<int32_t>()() = number_of_chips_output;
415 
416   OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));
417   TF_DeleteStatus(status);
418 
419   VLOG(1) << "DisconnectDistributedTpuChipsOp done";
420 }
421 
// These ops execute on the TPU_SYSTEM device only. The HostMemory
// annotations pin the listed string/int32 tensors to host memory, since
// they are produced and consumed by host-side configuration code.
REGISTER_KERNEL_BUILDER(Name("_ConfigureDistributedTPU")
                            .Device(DEVICE_TPU_SYSTEM)
                            .HostMemory("output"),
                        ConfigureDistributedTpuOp);
REGISTER_KERNEL_BUILDER(Name("_WaitForDistributedTPU")
                            .Device(DEVICE_TPU_SYSTEM)
                            .HostMemory("inputs")
                            .HostMemory("topology"),
                        WaitForDistributedTpuOp);
REGISTER_KERNEL_BUILDER(
    Name("_ShutdownDistributedTPU").Device(DEVICE_TPU_SYSTEM),
    ShutdownDistributedTpuOp);
REGISTER_KERNEL_BUILDER(Name("_InitializeHostForDistributedTPU")
                            .Device(DEVICE_TPU_SYSTEM)
                            .HostMemory("input")
                            .HostMemory("tpu_ids"),
                        InitializeHostForDistributedTpuOp);
REGISTER_KERNEL_BUILDER(
    Name("_SetGlobalTPUArray").Device(DEVICE_TPU_SYSTEM).HostMemory("topology"),
    SetGlobalTPUArrayOp);
REGISTER_KERNEL_BUILDER(Name("_DisconnectHostFromDistributedTPUSystem")
                            .Device(DEVICE_TPU_SYSTEM)
                            .HostMemory("number_of_tpu_chips"),
                        DisconnectDistributedTpuChipsOp);
447 
448 }  // namespace tensorflow
449