1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_service.h"
16
17 #include <chrono> // NOLINT
18
19 #include "grpcpp/support/byte_buffer.h"
20 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
21 #include "tensorflow/core/platform/coding.h"
22 #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.h"
23
24 namespace tensorflow {
25 namespace {
26 using ::tensorflow::tpu::CompilationCacheEntryRef;
27 using ::tensorflow::tpu::TpuCompilationCacheEntry;
28 using ::tensorflow::tpu::TpuCompilationCacheInterface;
29
30 static constexpr int kGetTpuProgramServingThreads = 32;
31 } // namespace
32
TpuCompilationCacheService(::grpc::ServerBuilder * server_builder,TpuCompilationCacheInterface * cache)33 TpuCompilationCacheService::TpuCompilationCacheService(
34 ::grpc::ServerBuilder* server_builder, TpuCompilationCacheInterface* cache)
35 : running_(true),
36 cache_(cache),
37 server_builder_(server_builder),
38 cq_(server_builder_->AddCompletionQueue()),
39 thread_pool_(absl::make_unique<thread::ThreadPool>(
40 Env::Default(), "TpuCompilationCacheService",
41 kGetTpuProgramServingThreads)) {
42 cache_->Ref();
43 server_builder_->RegisterService(&service_);
44 }
45
~TpuCompilationCacheService()46 TpuCompilationCacheService::~TpuCompilationCacheService() {
47 // This ordering is important. We must first shutdown our CQ and allow the
48 // polling thread and dispatch pool to shutdown before releasing our cache
49 // reference. The gRPC server must be Shutdown() by this point or we will
50 // deadlock here. The running_ boolean is necessary to avoid adding new
51 // operations to the CQ after is has shutdown.
52 running_ = false;
53 cq_->Shutdown();
54 polling_thread_.reset();
55 thread_pool_.reset();
56 cache_->Unref();
57 }
58
Start()59 void TpuCompilationCacheService::Start() {
60 server_ = server_builder_->BuildAndStart();
61 ThreadOptions opts;
62 polling_thread_.reset(Env::Default()->StartThread(
63 opts, "TpuCompilationCachePoller", [this]() { HandleRPCsLoop(); }));
64 }
65
Shutdown(int timeout_sec)66 bool TpuCompilationCacheService::Shutdown(int timeout_sec) {
67 if (server_ != nullptr) {
68 std::chrono::system_clock::time_point timeout =
69 std::chrono::system_clock::now() + std::chrono::seconds(timeout_sec);
70 server_->Shutdown(std::chrono::system_clock::now() +
71 std::chrono::seconds(timeout_sec));
72 if (std::chrono::system_clock::now() >= timeout) {
73 return false;
74 }
75 return true;
76 } else {
77 return false;
78 }
79 }
80
SetMemoryQuota(size_t max_bytes)81 void TpuCompilationCacheService::SetMemoryQuota(size_t max_bytes) {
82 ::grpc::ResourceQuota quota;
83 quota.Resize(max_bytes);
84 server_builder_->SetResourceQuota(quota);
85 }
86
87 // Fetch a cache result for the given request and serialize the result directly
88 // into a ByteBuffer.
GetTpuProgram(GetTpuProgramCall * call)89 void TpuCompilationCacheService::GetTpuProgram(GetTpuProgramCall* call) {
90 std::unique_ptr<CompilationCacheEntryRef> entry;
91
92 VLOG(1) << "GetTpuProgram: " << call->request.DebugString();
93 Status s;
94 switch (call->request.key_oneof_case()) {
95 case tpu::GetTpuProgramRequest::kKey:
96 s = cache_->Lookup(call->request.key(), &entry);
97 break;
98
99 case tpu::GetTpuProgramRequest::kUidAndIndex:
100 s = cache_->Lookup(call->request.uid_and_index().uid(),
101 call->request.uid_and_index().proto_index(), &entry);
102 break;
103
104 default:
105 s = errors::Internal("Bad GetTpuProgram RPC request oneof case ",
106 call->request.key_oneof_case());
107 break;
108 }
109 if (!s.ok()) {
110 return call->SendResponse(ToGrpcStatus(s));
111 }
112
113 s = entry->ToSubEntryRef(call->request.fetch_target());
114 if (!s.ok()) {
115 return call->SendResponse(::grpc::Status(
116 ::grpc::StatusCode::INVALID_ARGUMENT,
117 absl::StrCat(
118 "Error getting the fetching target ",
119 CompilationCacheFetchTarget_Name(call->request.fetch_target())),
120 s.error_message()));
121 }
122
123 TpuCompilationCacheEntry cache_entry = entry->get();
124 if (cache_entry.tpu_program_group() == nullptr) {
125 // It's possible that the sharding/unsharding entry does not exist, but the
126 // main entry must exist.
127 CHECK_NE(call->request.fetch_target(),
128 tpu::CompilationCacheFetchTarget::MAIN);
129 }
130
131 xla::StatusOr<std::vector<::grpc::Slice>> buffer_slices =
132 tpu::SerializeCacheEntryToBufferSlices(cache_entry);
133
134 if (!buffer_slices.ok()) {
135 return call->SendResponse(ToGrpcStatus(buffer_slices.status()));
136 }
137
138 call->response =
139 ::grpc::ByteBuffer{&buffer_slices.ValueOrDie()[0], buffer_slices->size()};
140 return call->SendResponse(::grpc::Status());
141 }
142
HandleGetTpuProgram(GetTpuProgramCall * call)143 void TpuCompilationCacheService::HandleGetTpuProgram(GetTpuProgramCall* call) {
144 thread_pool_->Schedule([this, call]() { GetTpuProgram(call); });
145 if (running_) {
146 GetTpuProgramCall::EnqueueRequestForMethod(
147 &service_, cq_.get(),
148 static_cast<int>(ServiceType::MethodId::kGetTpuProgram),
149 &TpuCompilationCacheService::HandleGetTpuProgram,
150 /*supports_cancel=*/false);
151 }
152 }
153
HandleRPCsLoop()154 void TpuCompilationCacheService::HandleRPCsLoop() {
155 void* tag;
156 bool ok;
157
158 for (int i = 0; i < 50; ++i) {
159 GetTpuProgramCall::EnqueueRequestForMethod(
160 &service_, cq_.get(),
161 static_cast<int>(ServiceType::MethodId::kGetTpuProgram),
162 &TpuCompilationCacheService::HandleGetTpuProgram,
163 /*supports_cancel=*/false);
164 }
165
166 while (cq_->Next(&tag, &ok)) {
167 VLOG(2) << "HandleRPCS: " << tag;
168 UntypedCall<TpuCompilationCacheService>::Tag* callback_tag =
169 static_cast<UntypedCall<TpuCompilationCacheService>::Tag*>(tag);
170 callback_tag->OnCompleted(this, ok);
171 }
172
173 VLOG(2) << "Cache thread shutting down.";
174 }
175 } // namespace tensorflow
176