• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_service.h"
16 
17 #include <chrono>  // NOLINT
18 
19 #include "grpcpp/support/byte_buffer.h"
20 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
21 #include "tensorflow/core/platform/coding.h"
22 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.h"
23 
24 namespace tensorflow {
25 namespace {
26 using ::tensorflow::tpu::CompilationCacheEntryRef;
27 using ::tensorflow::tpu::TpuCompilationCacheEntry;
28 using ::tensorflow::tpu::TpuCompilationCacheInterface;
29 
30 static constexpr int kGetTpuProgramServingThreads = 32;
31 }  // namespace
32 
TpuCompilationCacheService(::grpc::ServerBuilder * server_builder,TpuCompilationCacheInterface * cache)33 TpuCompilationCacheService::TpuCompilationCacheService(
34     ::grpc::ServerBuilder* server_builder, TpuCompilationCacheInterface* cache)
35     : running_(true),
36       cache_(cache),
37       server_builder_(server_builder),
38       cq_(server_builder_->AddCompletionQueue()),
39       thread_pool_(absl::make_unique<thread::ThreadPool>(
40           Env::Default(), "TpuCompilationCacheService",
41           kGetTpuProgramServingThreads)) {
42   cache_->Ref();
43   server_builder_->RegisterService(&service_);
44 }
45 
~TpuCompilationCacheService()46 TpuCompilationCacheService::~TpuCompilationCacheService() {
47   // This ordering is important. We must first shutdown our CQ and allow the
48   // polling thread and dispatch pool to shutdown before releasing our cache
49   // reference. The gRPC server must be Shutdown() by this point or we will
50   // deadlock here.  The running_ boolean is necessary to avoid adding new
51   // operations to the CQ after is has shutdown.
52   running_ = false;
53   cq_->Shutdown();
54   polling_thread_.reset();
55   thread_pool_.reset();
56   cache_->Unref();
57 }
58 
Start()59 void TpuCompilationCacheService::Start() {
60   server_ = server_builder_->BuildAndStart();
61   ThreadOptions opts;
62   polling_thread_.reset(Env::Default()->StartThread(
63       opts, "TpuCompilationCachePoller", [this]() { HandleRPCsLoop(); }));
64 }
65 
Shutdown(int timeout_sec)66 bool TpuCompilationCacheService::Shutdown(int timeout_sec) {
67   if (server_ != nullptr) {
68     std::chrono::system_clock::time_point timeout =
69         std::chrono::system_clock::now() + std::chrono::seconds(timeout_sec);
70     server_->Shutdown(std::chrono::system_clock::now() +
71                       std::chrono::seconds(timeout_sec));
72     if (std::chrono::system_clock::now() >= timeout) {
73       return false;
74     }
75     return true;
76   } else {
77     return false;
78   }
79 }
80 
SetMemoryQuota(size_t max_bytes)81 void TpuCompilationCacheService::SetMemoryQuota(size_t max_bytes) {
82   ::grpc::ResourceQuota quota;
83   quota.Resize(max_bytes);
84   server_builder_->SetResourceQuota(quota);
85 }
86 
87 // Fetch a cache result for the given request and serialize the result directly
88 // into a ByteBuffer.
GetTpuProgram(GetTpuProgramCall * call)89 void TpuCompilationCacheService::GetTpuProgram(GetTpuProgramCall* call) {
90   std::unique_ptr<CompilationCacheEntryRef> entry;
91 
92   VLOG(1) << "GetTpuProgram: " << call->request.DebugString();
93   Status s;
94   switch (call->request.key_oneof_case()) {
95     case tpu::GetTpuProgramRequest::kKey:
96       s = cache_->Lookup(call->request.key(), &entry);
97       break;
98 
99     case tpu::GetTpuProgramRequest::kUidAndIndex:
100       s = cache_->Lookup(call->request.uid_and_index().uid(),
101                          call->request.uid_and_index().proto_index(), &entry);
102       break;
103 
104     default:
105       s = errors::Internal("Bad GetTpuProgram RPC request oneof case ",
106                            call->request.key_oneof_case());
107       break;
108   }
109   if (!s.ok()) {
110     return call->SendResponse(ToGrpcStatus(s));
111   }
112 
113   s = entry->ToSubEntryRef(call->request.fetch_target());
114   if (!s.ok()) {
115     return call->SendResponse(::grpc::Status(
116         ::grpc::StatusCode::INVALID_ARGUMENT,
117         absl::StrCat(
118             "Error getting the fetching target ",
119             CompilationCacheFetchTarget_Name(call->request.fetch_target())),
120         s.error_message()));
121   }
122 
123   TpuCompilationCacheEntry cache_entry = entry->get();
124   if (cache_entry.tpu_program_group() == nullptr) {
125     // It's possible that the sharding/unsharding entry does not exist, but the
126     // main entry must exist.
127     CHECK_NE(call->request.fetch_target(),
128              tpu::CompilationCacheFetchTarget::MAIN);
129   }
130 
131   xla::StatusOr<std::vector<::grpc::Slice>> buffer_slices =
132       tpu::SerializeCacheEntryToBufferSlices(cache_entry);
133 
134   if (!buffer_slices.ok()) {
135     return call->SendResponse(ToGrpcStatus(buffer_slices.status()));
136   }
137 
138   call->response =
139       ::grpc::ByteBuffer{&buffer_slices.ValueOrDie()[0], buffer_slices->size()};
140   return call->SendResponse(::grpc::Status());
141 }
142 
HandleGetTpuProgram(GetTpuProgramCall * call)143 void TpuCompilationCacheService::HandleGetTpuProgram(GetTpuProgramCall* call) {
144   thread_pool_->Schedule([this, call]() { GetTpuProgram(call); });
145   if (running_) {
146     GetTpuProgramCall::EnqueueRequestForMethod(
147         &service_, cq_.get(),
148         static_cast<int>(ServiceType::MethodId::kGetTpuProgram),
149         &TpuCompilationCacheService::HandleGetTpuProgram,
150         /*supports_cancel=*/false);
151   }
152 }
153 
HandleRPCsLoop()154 void TpuCompilationCacheService::HandleRPCsLoop() {
155   void* tag;
156   bool ok;
157 
158   for (int i = 0; i < 50; ++i) {
159     GetTpuProgramCall::EnqueueRequestForMethod(
160         &service_, cq_.get(),
161         static_cast<int>(ServiceType::MethodId::kGetTpuProgram),
162         &TpuCompilationCacheService::HandleGetTpuProgram,
163         /*supports_cancel=*/false);
164   }
165 
166   while (cq_->Next(&tag, &ok)) {
167     VLOG(2) << "HandleRPCS: " << tag;
168     UntypedCall<TpuCompilationCacheService>::Tag* callback_tag =
169         static_cast<UntypedCall<TpuCompilationCacheService>::Tag*>(tag);
170     callback_tag->OnCompleted(this, ok);
171   }
172 
173   VLOG(2) << "Cache thread shutting down.";
174 }
175 }  // namespace tensorflow
176