/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_configuration_ops.h"

#include <cstdint>

#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_factory.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
#include "tensorflow/core/tpu/kernels/tpu_pod_state.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_ops_c_api.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"

namespace tensorflow {
namespace {
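// Returns the TpuMeshStateInterface resource from `rmgr`'s default container,
// or FailedPrecondition if the TPU system has not been initialized yet.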
Status GetTpuMeshStateInterface(const ResourceMgr* rmgr,
                                tpu::TpuMeshStateInterface** state) {
  if (!rmgr->Lookup(rmgr->default_container(),
                    tpu::kTpuMeshStateInterfaceResourceName, state)
           .ok()) {
    return errors::FailedPrecondition(
        "The TPU system has not been initialized.");
  }
  return Status::OK();
}

// Attempts to delete `resource_name` from `resource_manager`'s
// default_container. Returns OK if the deletion succeeded or the resource was
// not found; otherwise returns the deletion error.
template <class ResourceT>
Status DeleteIfExists(ResourceMgr* resource_manager,
                      const char* resource_name) {
  VLOG(1) << "Removing resource " << resource_name << " if it exists";
  Status status = resource_manager->Delete<ResourceT>(
      resource_manager->default_container(), resource_name);
  if (status.ok()) {
    VLOG(1) << "Removed existing resource " << resource_name;
    return Status::OK();
  }
  if (status.code() == error::NOT_FOUND) {
    VLOG(1) << "No resource " << resource_name << " to remove";
    return Status::OK();
  }
  VLOG(1) << "Error removing resource " << resource_name << " : " << status;
  return status;
}
}  // namespace

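// Looks up the TpuCompilationCacheInterface resource in `rmgr`'s default
// container, creating it with the registered factory function if it does not
// already exist. The caller is responsible for unreffing `*compilation_cache`.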
Status CreateTpuCompilationCache(
    ResourceMgr* rmgr, tpu::TpuCompilationCacheInterface** compilation_cache) {
  return rmgr->LookupOrCreate<tpu::TpuCompilationCacheInterface>(
      rmgr->default_container(), tpu::kCompilationCacheResourceName,
      compilation_cache, [&](tpu::TpuCompilationCacheInterface** new_cache) {
        *new_cache = tpu::GetCompilationCacheCreateFn()();
        return Status::OK();
      });
}

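// Reads one scalar input per host and returns the number of TPU chips attached
// to each host. Fails if any input is not a scalar or if the hosts report
// different chip counts.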
xla::StatusOr<std::vector<int32_t>> ConstructDevicesPerHost(
    OpKernelContext* ctx) {
  std::vector<int32_t> num_devices_per_host;
  int chips_per_host = -1;
  for (int i = 0; i < ctx->num_inputs(); ++i) {
    const Tensor& input_tensor = ctx->input(i);
    if (!TensorShapeUtils::IsScalar(input_tensor.shape())) {
      return errors::InvalidArgument("Input ", i,
                                     " should be a scalar but has ",
                                     input_tensor.dims(), " dimensions");
    }
    if (chips_per_host == -1) {
      chips_per_host = input_tensor.scalar<int32_t>()();
    } else {
      if (chips_per_host != input_tensor.scalar<int32_t>()()) {
        return errors::Internal("Host ", i, " has ",
                                input_tensor.scalar<int32_t>()(),
                                " TPU chips but host 0 has ", chips_per_host);
      }
    }
    num_devices_per_host.push_back(input_tensor.scalar<int32_t>()());
  }
  return num_devices_per_host;
}

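// Initializes the distributed TPU system: creates the compilation cache and
// the TpuPodState, replaces any existing mesh state resource, and outputs the
// serialized host config as a scalar string.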
void ConfigureDistributedTpuOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "ConfigureDistributedTpuOp";
  XLA_SCOPED_LOGGING_TIMER("ConfigureDistributedTpuOp");

  xla::StatusOr<std::vector<int32_t>> num_devices_per_host =
      ConstructDevicesPerHost(ctx);
  OP_REQUIRES_OK(ctx, num_devices_per_host.status());
  ResourceMgr* rmgr = GetTPUConfigResourceMgr();

  // Create the subgraph compilation cache and put it in the local resource
  // manager.
  tpu::TpuCompilationCacheInterface* compilation_cache;
  OP_REQUIRES_OK(ctx, CreateTpuCompilationCache(rmgr, &compilation_cache));
  core::ScopedUnref compilation_cache_ref(compilation_cache);

  std::string host_config_output;
  OP_REQUIRES_OK(
      ctx, ConstructTpuPodState(rmgr, *num_devices_per_host, compilation_cache,
                                &host_config_output));

  OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
                          rmgr, tpu::kTpuMeshStateInterfaceResourceName));

  auto* tpu_mesh = tpu::TpuMeshStateInterface::Create();
  OP_REQUIRES_OK(
      ctx, rmgr->Create(rmgr->default_container(),
                        tpu::kTpuMeshStateInterfaceResourceName, tpu_mesh));

  Tensor* ctx_output;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));
  ctx_output->scalar<tstring>()() = std::move(host_config_output);

  VLOG(1) << "ConfigureDistributedTpuOp done";
}

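// Waits for the distributed TPU system to come up. Each input is a vector
// mapping one host's device ordinals to global device IDs; the combined
// mapping is handed to the TPU runtime through the C API, and the resulting
// serialized TPU topology is returned as a scalar string output.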
void WaitForDistributedTpuOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "WaitForDistributedTpuOp";
  XLA_SCOPED_LOGGING_TIMER("WaitForDistributedTpuOp");

  size_t num_devices_per_host = -1;
  size_t num_hosts = ctx->num_inputs();

  for (int i = 0; i < ctx->num_inputs(); ++i) {
    const Tensor& host_ordinal_to_global_device_id_tensor = ctx->input(i);
    OP_REQUIRES(
        ctx, host_ordinal_to_global_device_id_tensor.dims() == 1,
        errors::InvalidArgument("Input ", i, " should be a vector but has ",
                                host_ordinal_to_global_device_id_tensor.dims(),
                                " dimensions"));
  }

  std::vector<std::vector<int32_t>> mapping;
  std::vector<int32_t*> mapping_arg;

  mapping.resize(ctx->num_inputs());

  for (int i = 0; i < ctx->num_inputs(); ++i) {
    const Tensor& host_ordinal_to_global_device_id_tensor = ctx->input(i);
    const auto host_ordinal_to_global_device_id =
        host_ordinal_to_global_device_id_tensor.flat<int>();
    if (num_devices_per_host == -1) {
      num_devices_per_host =
          host_ordinal_to_global_device_id_tensor.dim_size(0);
    } else {
      OP_REQUIRES(ctx,
                  num_devices_per_host ==
                      host_ordinal_to_global_device_id_tensor.dim_size(0),
                  errors::Internal(
                      "Host ", i, " has ",
                      host_ordinal_to_global_device_id_tensor.dim_size(0),
                      " TPU devices but host 0 has ", num_devices_per_host));
    }
    for (int j = 0; j < host_ordinal_to_global_device_id_tensor.dim_size(0);
         ++j) {
      int32_t global_device_id = host_ordinal_to_global_device_id(j);
      mapping[i].push_back(global_device_id);
    }
    mapping_arg.push_back(mapping[i].data());
  }

  tpu::TpuMeshStateInterface* mesh_state;
  auto* rmgr = GetTPUConfigResourceMgr();
  OP_REQUIRES_OK(ctx, GetTpuMeshStateInterface(rmgr, &mesh_state));
  core::ScopedUnref mesh_state_unref(mesh_state);

  // TODO(b/166858751): this check that `TpuPodState` exists was ported from a
  // legacy library and may have gone stale; it is a candidate for cleanup.
  TpuPodState* pod_state;
  OP_REQUIRES_OK(ctx, GetTPUPodState(rmgr, &pod_state));
  core::ScopedUnref pod_state_unref(pod_state);

  size_t tpu_topology_output_size;
  char* tpu_topology_output = nullptr;
  TF_Status* status = TF_NewStatus();
  auto cleanup = xla::MakeCleanup([&status, &tpu_topology_output]() {
    TF_DeleteStatus(status);
    tpu::OpsApiFn()->TpuConfigurationApi_FreeCharArrayFn(tpu_topology_output);
  });

  auto* mesh_common_state = mesh_state->mesh_common_state();

  WaitForDistributedTpuOp_DoWork_Params params;
  params.struct_size = WaitForDistributedTpuOp_DoWork_Params_SIZE;
  params.priv = nullptr;
  params.num_hosts = num_hosts;
  params.num_cores_per_host = num_devices_per_host;
  params.host_ordinal_to_global_core_id_map =
      const_cast<const int32_t**>(mapping_arg.data());
  params.tpu_mesh_common_state = mesh_common_state;
  params.tpu_topology_output_size = &tpu_topology_output_size;
  params.tpu_topology_output = &tpu_topology_output;
  params.status = status;

  tpu::OpsApiFn()->WaitForDistributedTpuOp_DoWorkFn(&params);

  OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));

  Tensor* ctx_output;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));
  ctx_output->scalar<tstring>()() =
      std::string(tpu_topology_output, tpu_topology_output_size);

  VLOG(1) << "WaitForDistributedTpuOp done";
}

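// Tears down the distributed TPU system by deleting the mesh state, pod state,
// and compilation cache resources.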
void ShutdownDistributedTpuOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "ShutdownDistributedTpuOp";
  XLA_SCOPED_LOGGING_TIMER("ShutdownDistributedTpuOp");

  auto* rmgr = GetTPUConfigResourceMgr();
  OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
                          rmgr, tpu::kTpuMeshStateInterfaceResourceName));

  OP_REQUIRES_OK(ctx,
                 DeleteIfExists<TpuPodState>(rmgr, kTpuPodStateResourceName));
  OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuCompilationCacheInterface>(
                          rmgr, tpu::kCompilationCacheResourceName));

  VLOG(1) << "ShutdownDistributedTpuOp done";
}

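// Initializes this host's TPU runtime from the serialized host config input,
// installs either a local or an RPC-backed compilation cache lookup, and
// outputs the host's global TPU core IDs as a vector of int32 values.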
void InitializeHostForDistributedTpuOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "InitializeHostForDistributedTpuOp";
  XLA_SCOPED_LOGGING_TIMER("InitializeHostForDistributedTpuOp");

  auto* rmgr = GetTPUConfigResourceMgr();
  auto tpu_host_config = ctx->input(0).scalar<tstring>()();

  bool is_master_worker =
      tpu::OpsApiFn()->TpuConfigurationApi_HasTPUPodStateFn();
  if (!is_master_worker) {
    // Reset the mesh interface if we are not the master.
    OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuMeshStateInterface>(
                            rmgr, tpu::kTpuMeshStateInterfaceResourceName));
    auto* mesh_state_interface = tpu::TpuMeshStateInterface::Create();
    OP_REQUIRES_OK(ctx, rmgr->Create(rmgr->default_container(),
                                     tpu::kTpuMeshStateInterfaceResourceName,
                                     mesh_state_interface));
  }

  VLOG(1) << "Removing existing proto compilation cache lookup if it exists";
  OP_REQUIRES_OK(ctx, DeleteIfExists<tpu::TpuCompilationCacheLookup>(
                          rmgr, tpu::kCompiledProtoCacheResourceName));

  if (enable_whole_mesh_compilations_) {
    // In whole-mesh compilation mode, create the compilation cache if it does
    // not already exist.
    tpu::TpuCompilationCacheInterface* compilation_cache;
    OP_REQUIRES_OK(ctx, CreateTpuCompilationCache(rmgr, &compilation_cache));
    compilation_cache->Unref();
  }

  tpu::TpuCompilationCacheInterface* local_compilation_cache;
  Status s = rmgr->Lookup(rmgr->default_container(),
                          tpu::kCompilationCacheResourceName,
                          &local_compilation_cache);
  if (!s.ok()) {
    local_compilation_cache = nullptr;
  }

  TF_Status* status = TF_NewStatus();
  size_t device_id_output_size;
  int32_t* device_id_output = nullptr;
  auto cleanup = xla::MakeCleanup([&status, &device_id_output]() {
    TF_DeleteStatus(status);
    tpu::OpsApiFn()->TpuConfigurationApi_FreeInt32ArrayFn(device_id_output);
  });

  InitializeHostForDistributedTpuOp_DoWork_Params params;
  params.struct_size = InitializeHostForDistributedTpuOp_DoWork_Params_SIZE;
  params.priv = nullptr;
  params.tpu_host_config_size = tpu_host_config.size();
  params.tpu_host_config = tpu_host_config.data();
  params.enable_whole_mesh_compilations = enable_whole_mesh_compilations_;
  params.is_master_worker = is_master_worker;
  params.core_id_output_size = &device_id_output_size;
  params.core_id_output = &device_id_output;
  params.status = status;

  tpu::OpsApiFn()->InitializeHostForDistributedTpuOp_DoWorkFn(&params);
  OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));

  if (local_compilation_cache != nullptr) {
    local_compilation_cache->Unref();

    tpu::TpuCompilationCacheLookup* proto_lookup;
    proto_lookup =
        new tpu::TpuCompilationCacheLocalLookup(local_compilation_cache);
    OP_REQUIRES_OK(
        ctx, rmgr->Create(rmgr->default_container(),
                          tpu::kCompiledProtoCacheResourceName, proto_lookup));
  } else {
    int64_t cache_size_bytes;
    tpu::OpsApiFn()->TpuConfigurationApi_RemoteCompilationCacheSizeInBytesFn(
        &cache_size_bytes);

    char* server_address_output = nullptr;
    auto cleanup_server_address = xla::MakeCleanup([&server_address_output]() {
      tpu::OpsApiFn()->TpuConfigurationApi_FreeCharArrayFn(
          server_address_output);
    });
    size_t server_address_output_size;

    TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params params;
    params.struct_size =
        TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params_SIZE;
    params.priv = nullptr;
    params.tpu_host_config_size = tpu_host_config.size();
    params.tpu_host_config = tpu_host_config.data();
    params.server_address_output_size = &server_address_output_size;
    params.server_address_output = &server_address_output;
    params.status = status;

    tpu::OpsApiFn()
        ->TpuConfigurationApi_CompilationCacheServerAddressFromConfigFn(
            &params);
    OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));

    std::string server_address(server_address_output,
                               server_address_output_size);
    tpu::TpuCompilationCacheLookup* proto_lookup =
        new tpu::TpuCompilationCacheRpcLookup(server_address, cache_size_bytes);
    OP_REQUIRES_OK(
        ctx, rmgr->Create(rmgr->default_container(),
                          tpu::kCompiledProtoCacheResourceName, proto_lookup));
  }

  Tensor* ctx_output;
  OP_REQUIRES_OK(
      ctx, ctx->allocate_output(
               0, TensorShape({static_cast<long long>(device_id_output_size)}),
               &ctx_output));

  for (size_t i = 0; i < device_id_output_size; ++i) {
    ctx_output->flat<int32_t>()(i) = device_id_output[i];
  }

  VLOG(1) << "InitializeHostForDistributedTpuOp done";
}

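// Forwards the serialized TPU topology input to the TPU runtime through the
// C API.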
void SetGlobalTPUArrayOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "SetGlobalTPUArrayOp";
  XLA_SCOPED_LOGGING_TIMER("SetGlobalTPUArrayOp");

  auto tpu_topology = ctx->input(0).scalar<tstring>()();
  TF_Status* status = TF_NewStatus();
  // Delete `status` on every exit path, including the early return taken by
  // OP_REQUIRES_OK on failure, which would otherwise leak it.
  auto cleanup = xla::MakeCleanup([&status]() { TF_DeleteStatus(status); });

  tpu::OpsApiFn()->SetGlobalTPUArrayOp_DoWorkFn(tpu_topology.size(),
                                                tpu_topology.data(), status);

  OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));

  VLOG(1) << "SetGlobalTPUArrayOp done";
}

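// Disconnects this host's TPU chips from the distributed TPU system and
// outputs the number of chips disconnected as a scalar int32.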
void DisconnectDistributedTpuChipsOp::Compute(OpKernelContext* ctx) {
  VLOG(1) << "DisconnectDistributedTpuChipsOp";
  XLA_SCOPED_LOGGING_TIMER("DisconnectDistributedTpuChipsOp");

  TF_Status* status = TF_NewStatus();
  // Delete `status` on every exit path, including the early return taken by
  // OP_REQUIRES_OK on failure, which would otherwise leak it.
  auto cleanup = xla::MakeCleanup([&status]() { TF_DeleteStatus(status); });
  int32_t number_of_chips_output = 0;

  tpu::OpsApiFn()->DisconnectDistributedTpuChipsOp_DoWorkFn(
      &number_of_chips_output, status);

  Tensor* ctx_output;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &ctx_output));
  ctx_output->scalar<int32_t>()() = number_of_chips_output;

  OP_REQUIRES_OK(ctx, StatusFromTF_Status(status));

  VLOG(1) << "DisconnectDistributedTpuChipsOp done";
}

// These ops execute on the TPU_SYSTEM device only.
REGISTER_KERNEL_BUILDER(Name("_ConfigureDistributedTPU")
                            .Device(DEVICE_TPU_SYSTEM)
                            .HostMemory("output"),
                        ConfigureDistributedTpuOp);
REGISTER_KERNEL_BUILDER(Name("_WaitForDistributedTPU")
                            .Device(DEVICE_TPU_SYSTEM)
                            .HostMemory("inputs")
                            .HostMemory("topology"),
                        WaitForDistributedTpuOp);
REGISTER_KERNEL_BUILDER(
    Name("_ShutdownDistributedTPU").Device(DEVICE_TPU_SYSTEM),
    ShutdownDistributedTpuOp);
REGISTER_KERNEL_BUILDER(Name("_InitializeHostForDistributedTPU")
                            .Device(DEVICE_TPU_SYSTEM)
                            .HostMemory("input")
                            .HostMemory("tpu_ids"),
                        InitializeHostForDistributedTpuOp);
REGISTER_KERNEL_BUILDER(
    Name("_SetGlobalTPUArray").Device(DEVICE_TPU_SYSTEM).HostMemory("topology"),
    SetGlobalTPUArrayOp);
REGISTER_KERNEL_BUILDER(Name("_DisconnectHostFromDistributedTPUSystem")
                            .Device(DEVICE_TPU_SYSTEM)
                            .HostMemory("number_of_tpu_chips"),
                        DisconnectDistributedTpuChipsOp);

}  // namespace tensorflow