• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/tpu/kernels/tpu_pod_state.h"
16 
17 #include "tensorflow/c/tf_status.h"
18 #include "tensorflow/c/tf_status_helper.h"
19 #include "tensorflow/core/tpu/tpu_api.h"
20 
21 #if defined(LIBTPU_ON_GCE)
22 #include "tensorflow/core/tpu/kernels/tpu_util.h"
23 #else
#include "tensorflow/core/tpu/kernels/tpu_util.h"  // copybara
25 #endif
26 
27 namespace tensorflow {
28 const char kTpuPodStateResourceName[] = "tpu_pod_state";
29 
30 namespace {
31 
32 // Attempt to delete resource_name from resource_manager's default_container.
33 // Returns OK if the deletion succeeded, or if the resource was not found. Else
34 // return the deletion error.
35 template <class ResourceT>
DeleteIfExists(ResourceMgr * resource_manager,const char * resource_name)36 Status DeleteIfExists(ResourceMgr* resource_manager,
37                       const char* resource_name) {
38   VLOG(1) << "Removing resource " << resource_name << " if it exists";
39   Status status = resource_manager->Delete<ResourceT>(
40       resource_manager->default_container(), resource_name);
41   if (status.ok()) {
42     VLOG(1) << "Removed existing resource " << resource_name;
43     return Status::OK();
44   }
45   if (status.code() == error::NOT_FOUND) {
46     VLOG(1) << "No resource " << resource_name << " to remove";
47     return Status::OK();
48   }
49   VLOG(1) << "Error removing resource " << resource_name << " : " << status;
50   return status;
51 }
52 
53 xla::StatusOr<std::unique_ptr<TpuCompilationCacheService>>
ConstructCacheService(ResourceMgr * rmgr,int serving_port,tpu::TpuCompilationCacheInterface * compilation_cache)54 ConstructCacheService(ResourceMgr* rmgr, int serving_port,
55                       tpu::TpuCompilationCacheInterface* compilation_cache) {
56   xla::StatusOr<std::unique_ptr<::grpc::ServerBuilder>> server_builder;
57 #if defined(LIBTPU_ON_GCE)
58   server_builder = tpu::CreateServerBuilder(serving_port);
59 #else
60   server_builder = tpu::CreateServerBuilderGoogle(serving_port);
61 #endif
62   TF_RETURN_IF_ERROR(server_builder.status());
63 
64   auto cache_service = absl::make_unique<TpuCompilationCacheService>(
65       server_builder.ValueOrDie().get(), compilation_cache);
66   cache_service->SetMemoryQuota(1ul << 31);  // 2GB
67   cache_service->Start();
68   return cache_service;
69 }
70 }  // namespace
71 
GetServerAddressAndPort(std::string * server_address,int * serving_port)72 Status GetServerAddressAndPort(std::string* server_address, int* serving_port) {
73   TF_Status* status = TF_NewStatus();
74   char* server_address_output = nullptr;
75   auto cleanup = xla::MakeCleanup([&status, &server_address_output]() {
76     TF_DeleteStatus(status);
77     tpu::OpsApiFn()->TpuConfigurationApi_FreeCharArrayFn(server_address_output);
78   });
79   size_t server_address_output_size;
80   *serving_port = -1;
81 
82   TpuConfigurationApi_GetServerAddressAndPort_Params params;
83   params.struct_size = TpuConfigurationApi_GetServerAddressAndPort_Params_SIZE;
84   params.priv = nullptr;
85   params.server_address_output_size = &server_address_output_size;
86   params.server_address_output = &server_address_output;
87   params.port_output = serving_port;
88   params.status = status;
89 
90   tpu::OpsApiFn()->TpuConfigurationApi_GetServerAddressAndPortFn(&params);
91   TF_RETURN_IF_ERROR(StatusFromTF_Status(status));
92   *server_address =
93       std::string(server_address_output, server_address_output_size);
94   CHECK_NE(*serving_port, -1);
95   return Status::OK();
96 }
97 
TpuPodState(int service_port,std::unique_ptr<TpuCompilationCacheService> cache_service)98 TpuPodState::TpuPodState(
99     int service_port, std::unique_ptr<TpuCompilationCacheService> cache_service)
100     : cache_service_(std::move(cache_service)), service_port_(service_port) {}
101 
~TpuPodState()102 TpuPodState::~TpuPodState() {
103   if (cache_service_) {
104     VLOG(1) << "Shutting down Compilation Cache Service.";
105     if (cache_service_->Shutdown(20)) {
106       if (service_port_ >= 0) {
107         tpu::OpsApiFn()->TpuNetUtil_RecycleUnusedPortFn(service_port_);
108       }
109     } else {
110       LOG(ERROR)
111           << "Failed to shutdown Compilation Cache Service within timeout.";
112     }
113   }
114   VLOG(1) << "Shutting down Compilation Cache Service done.";
115 }
116 
// Human-readable description reported by the resource manager.
string TpuPodState::DebugString() const {
  return "Wrapper for distributed TPU state";
}
120 
GetTPUPodState(const ResourceMgr * rmgr,TpuPodState ** pod_state)121 Status GetTPUPodState(const ResourceMgr* rmgr, TpuPodState** pod_state) {
122   if (!rmgr) {
123     return errors::Internal("No resource manager.");
124   }
125   if (!rmgr->Lookup(rmgr->default_container(), kTpuPodStateResourceName,
126                     pod_state)
127            .ok()) {
128     return errors::FailedPrecondition(
129         "The TPU system has not been initialized.");
130   }
131   return Status::OK();
132 }
133 
HasTPUPodState(const ResourceMgr * rmgr)134 bool HasTPUPodState(const ResourceMgr* rmgr) {
135   TpuPodState* pod_state;
136   if (!rmgr->Lookup(rmgr->default_container(), kTpuPodStateResourceName,
137                     &pod_state)
138            .ok()) {
139     return false;
140   }
141   pod_state->Unref();
142   return true;
143 }
144 
ConstructTpuPodState(ResourceMgr * rmgr,const std::vector<int32_t> & num_devices_per_host,tpu::TpuCompilationCacheInterface * compilation_cache,std::string * host_config_proto)145 Status ConstructTpuPodState(
146     ResourceMgr* rmgr, const std::vector<int32_t>& num_devices_per_host,
147     tpu::TpuCompilationCacheInterface* compilation_cache,
148     std::string* host_config_proto) {
149   TF_Status* status = TF_NewStatus();
150   auto status_cleanup =
151       xla::MakeCleanup([&status]() { TF_DeleteStatus(status); });
152 
153   int serving_port;
154   std::string server_address;
155   TF_RETURN_IF_ERROR(GetServerAddressAndPort(&server_address, &serving_port));
156 
157   char* host_config_output = nullptr;
158   auto host_config_cleanup = xla::MakeCleanup([&host_config_output]() {
159     tpu::OpsApiFn()->TpuConfigurationApi_FreeCharArrayFn(host_config_output);
160   });
161   size_t host_config_output_size;
162 
163   ConfigureDistributedTpuOp_DoWork_Params params;
164   params.struct_size = ConfigureDistributedTpuOp_DoWork_Params_SIZE;
165   params.priv = nullptr;
166   params.num_cores_per_host_size = num_devices_per_host.size();
167   params.num_cores_per_host = num_devices_per_host.data();
168   params.server_address_size = server_address.size();
169   params.server_address = server_address.data();
170   params.host_config_output_size = &host_config_output_size;
171   params.host_config_output = &host_config_output;
172   params.status = status;
173 
174   tpu::OpsApiFn()->ConfigureDistributedTpuOp_DoWorkFn(&params);
175   TF_RETURN_IF_ERROR(StatusFromTF_Status(status));
176   *host_config_proto = std::string(host_config_output, host_config_output_size);
177 
178   TF_ASSIGN_OR_RETURN(
179       std::unique_ptr<TpuCompilationCacheService> cache_service,
180       ConstructCacheService(rmgr, serving_port, compilation_cache));
181 
182   // Delete TpuPodState if it exists, and recreate below.
183   TF_RETURN_IF_ERROR(
184       DeleteIfExists<TpuPodState>(rmgr, kTpuPodStateResourceName));
185   return rmgr->Create(rmgr->default_container(), kTpuPodStateResourceName,
186                       new TpuPodState(serving_port, std::move(cache_service)));
187 }
188 }  // namespace tensorflow
189