1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// GrpcMasterService implements the RPC service MasterService.
17 //
18 // A GrpcMasterService maintains the state of live graph computation
19 // sessions, each session orchestrates both local and remote devices
20 // to carry out the graph computation.
21 //
22 // A GrpcMasterService knows ahead of time local devices available as
23 // client devices.
24 //
25 // A GrpcMasterService discovers remote devices in the background and
26 // keeps track of statistics of those remote devices.
27 //
28 // Each session analyzes the graph, places nodes across available
29 // devices, and ultimately drives the graph computation by initiating
30 // RunGraph on workers.
31 #include "tensorflow/core/distributed_runtime/rpc/grpc_master_service.h"
32
33 #include "grpcpp/alarm.h"
34 #include "grpcpp/server_builder.h"
35 #include "tensorflow/core/distributed_runtime/master.h"
36 #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
37 #include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
38 #include "tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h"
39 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
40 #include "tensorflow/core/platform/logging.h"
41 #include "tensorflow/core/platform/macros.h"
42 #include "tensorflow/core/platform/tracing.h"
43 #include "tensorflow/core/profiler/lib/traceme.h"
44 #include "tensorflow/core/protobuf/master.pb.h"
45
46 namespace tensorflow {
47
48 class GrpcMasterService : public AsyncServiceInterface {
49 public:
GrpcMasterService(Master * master,const ConfigProto & default_session_config,::grpc::ServerBuilder * builder)50 GrpcMasterService(Master* master, const ConfigProto& default_session_config,
51 ::grpc::ServerBuilder* builder)
52 : master_impl_(master),
53 is_shutdown_(false),
54 default_session_config_(default_session_config) {
55 builder->RegisterService(&master_service_);
56 cq_ = builder->AddCompletionQueue();
57 }
58
~GrpcMasterService()59 ~GrpcMasterService() override { delete shutdown_alarm_; }
60
Shutdown()61 void Shutdown() override {
62 bool did_shutdown = false;
63 {
64 mutex_lock l(mu_);
65 if (!is_shutdown_) {
66 LOG(INFO) << "Shutting down GrpcMasterService.";
67 is_shutdown_ = true;
68 did_shutdown = true;
69 }
70 }
71 if (did_shutdown) {
72 // NOTE(mrry): This enqueues a special event (with a null tag)
73 // that causes the completion queue to be shut down on the
74 // polling thread.
75 shutdown_alarm_ =
76 new ::grpc::Alarm(cq_.get(), gpr_now(GPR_CLOCK_MONOTONIC), nullptr);
77 }
78 }
79
80 // This macro creates a new request for the given RPC method name
81 // (e.g., `ENQUEUE_REQUEST(RunStep);`), and enqueues it on
82 // `this->cq_`.
83 //
84 // This macro is invoked one or more times for each RPC method to
85 // ensure that there are sufficient completion queue entries to
86 // handle incoming requests without blocking.
87 //
88 // The implementation of the request handler for each RPC method
89 // must ensure that it calls ENQUEUE_REQUEST() for that RPC method,
90 // to keep accepting new requests.
91 #define ENQUEUE_REQUEST(method, supports_cancel) \
92 do { \
93 mutex_lock l(mu_); \
94 if (!is_shutdown_) { \
95 Call<GrpcMasterService, grpc::MasterService::AsyncService, \
96 method##Request, method##Response>:: \
97 EnqueueRequest(&master_service_, cq_.get(), \
98 &grpc::MasterService::AsyncService::Request##method, \
99 &GrpcMasterService::method##Handler, \
100 (supports_cancel)); \
101 } \
102 } while (0)
103
HandleRPCsLoop()104 void HandleRPCsLoop() override {
105 ENQUEUE_REQUEST(CreateSession, true);
106 ENQUEUE_REQUEST(ExtendSession, false);
107 for (int i = 0; i < 100; ++i) {
108 ENQUEUE_REQUEST(PartialRunSetup, false);
109 ENQUEUE_REQUEST(RunStep, true);
110 }
111 ENQUEUE_REQUEST(CloseSession, false);
112 ENQUEUE_REQUEST(ListDevices, false);
113 ENQUEUE_REQUEST(Reset, false);
114 ENQUEUE_REQUEST(MakeCallable, false);
115 for (int i = 0; i < 100; ++i) {
116 ENQUEUE_REQUEST(RunCallable, true);
117 }
118 ENQUEUE_REQUEST(ReleaseCallable, false);
119
120 void* tag;
121 bool ok;
122 while (cq_->Next(&tag, &ok)) {
123 UntypedCall<GrpcMasterService>::Tag* callback_tag =
124 static_cast<UntypedCall<GrpcMasterService>::Tag*>(tag);
125 if (callback_tag) {
126 callback_tag->OnCompleted(this, ok);
127 } else {
128 // NOTE(mrry): A null `callback_tag` indicates that this is
129 // the shutdown alarm.
130 cq_->Shutdown();
131 }
132 }
133 }
134
135 private:
136 Master* master_impl_ = nullptr; // Not owned.
137 std::unique_ptr<::grpc::ServerCompletionQueue> cq_;
138 grpc::MasterService::AsyncService master_service_;
139
140 mutex mu_;
141 bool is_shutdown_ GUARDED_BY(mu_);
142 const ConfigProto default_session_config_;
143 ::grpc::Alarm* shutdown_alarm_ = nullptr;
144
145 template <class RequestMessage, class ResponseMessage>
146 using MasterCall = Call<GrpcMasterService, grpc::MasterService::AsyncService,
147 RequestMessage, ResponseMessage>;
148
149 // RPC handler for creating a session.
CreateSessionHandler(MasterCall<CreateSessionRequest,CreateSessionResponse> * call)150 void CreateSessionHandler(
151 MasterCall<CreateSessionRequest, CreateSessionResponse>* call) {
152 CreateSessionRequest* rewritten_req = new CreateSessionRequest;
153 rewritten_req->mutable_config()->MergeFrom(default_session_config_);
154 rewritten_req->MergeFrom(call->request);
155 master_impl_->CreateSession(rewritten_req, &call->response,
156 [call, rewritten_req](const Status& status) {
157 call->SendResponse(ToGrpcStatus(status));
158 delete rewritten_req;
159 });
160 ENQUEUE_REQUEST(CreateSession, true);
161 }
162
163 // RPC handler for extending a session.
ExtendSessionHandler(MasterCall<ExtendSessionRequest,ExtendSessionResponse> * call)164 void ExtendSessionHandler(
165 MasterCall<ExtendSessionRequest, ExtendSessionResponse>* call) {
166 master_impl_->ExtendSession(&call->request, &call->response,
167 [call](const Status& status) {
168 call->SendResponse(ToGrpcStatus(status));
169 });
170 ENQUEUE_REQUEST(ExtendSession, false);
171 }
172
173 // RPC handler for setting up a partial run call.
PartialRunSetupHandler(MasterCall<PartialRunSetupRequest,PartialRunSetupResponse> * call)174 void PartialRunSetupHandler(
175 MasterCall<PartialRunSetupRequest, PartialRunSetupResponse>* call) {
176 master_impl_->PartialRunSetup(&call->request, &call->response,
177 [call](const Status& status) {
178 call->SendResponse(ToGrpcStatus(status));
179 });
180 ENQUEUE_REQUEST(PartialRunSetup, false);
181 }
182
183 // RPC handler for running one step in a session.
RunStepHandler(MasterCall<RunStepRequest,RunStepResponse> * call)184 void RunStepHandler(MasterCall<RunStepRequest, RunStepResponse>* call) {
185 auto* trace = TraceRpc("RunStep/Server", call->client_metadata());
186 CallOptions* call_opts = new CallOptions;
187 if (call->request.options().timeout_in_ms() > 0) {
188 call_opts->SetTimeout(call->request.options().timeout_in_ms());
189 } else {
190 call_opts->SetTimeout(default_session_config_.operation_timeout_in_ms());
191 }
192 RunStepRequestWrapper* wrapped_request =
193 new ProtoRunStepRequest(&call->request);
194 MutableRunStepResponseWrapper* wrapped_response =
195 new NonOwnedProtoRunStepResponse(&call->response);
196 call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
197 master_impl_->RunStep(
198 call_opts, wrapped_request, wrapped_response,
199 [call, call_opts, wrapped_request, trace](const Status& status) {
200 call->ClearCancelCallback();
201 delete call_opts;
202 delete wrapped_request;
203 delete trace;
204 if (call->request.store_errors_in_response_body() && !status.ok()) {
205 call->response.set_status_code(status.code());
206 call->response.set_status_error_message(status.error_message());
207 call->SendResponse(ToGrpcStatus(Status::OK()));
208 } else {
209 call->SendResponse(ToGrpcStatus(status));
210 }
211 });
212 ENQUEUE_REQUEST(RunStep, true);
213 }
214
215 // RPC handler for deleting a session.
CloseSessionHandler(MasterCall<CloseSessionRequest,CloseSessionResponse> * call)216 void CloseSessionHandler(
217 MasterCall<CloseSessionRequest, CloseSessionResponse>* call) {
218 master_impl_->CloseSession(&call->request, &call->response,
219 [call](const Status& status) {
220 call->SendResponse(ToGrpcStatus(status));
221 });
222 ENQUEUE_REQUEST(CloseSession, false);
223 }
224
225 // RPC handler for listing devices.
ListDevicesHandler(MasterCall<ListDevicesRequest,ListDevicesResponse> * call)226 void ListDevicesHandler(
227 MasterCall<ListDevicesRequest, ListDevicesResponse>* call) {
228 master_impl_->ListDevices(&call->request, &call->response,
229 [call](const Status& status) {
230 call->SendResponse(ToGrpcStatus(status));
231 });
232 ENQUEUE_REQUEST(ListDevices, false);
233 }
234
235 // RPC handler for resetting all sessions.
ResetHandler(MasterCall<ResetRequest,ResetResponse> * call)236 void ResetHandler(MasterCall<ResetRequest, ResetResponse>* call) {
237 master_impl_->Reset(&call->request, &call->response,
238 [call](const Status& status) {
239 call->SendResponse(ToGrpcStatus(status));
240 });
241 ENQUEUE_REQUEST(Reset, false);
242 }
243
244 // RPC handler for making a callable.
MakeCallableHandler(MasterCall<MakeCallableRequest,MakeCallableResponse> * call)245 void MakeCallableHandler(
246 MasterCall<MakeCallableRequest, MakeCallableResponse>* call) {
247 master_impl_->MakeCallable(&call->request, &call->response,
248 [call](const Status& status) {
249 call->SendResponse(ToGrpcStatus(status));
250 });
251 ENQUEUE_REQUEST(MakeCallable, false);
252 }
253
254 // RPC handler for running a callable.
RunCallableHandler(MasterCall<RunCallableRequest,RunCallableResponse> * call)255 void RunCallableHandler(
256 MasterCall<RunCallableRequest, RunCallableResponse>* call) {
257 auto* trace = TraceRpc("RunCallable/Server", call->client_metadata());
258 CallOptions* call_opts = new CallOptions;
259 // The timeout may be overridden by a non-zero timeout in the
260 // callable's `RunOptions`; this overriding will happen inside the
261 // `MasterSession` implementation.
262 call_opts->SetTimeout(default_session_config_.operation_timeout_in_ms());
263 call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
264 master_impl_->RunCallable(call_opts, &call->request, &call->response,
265 [call, call_opts, trace](const Status& status) {
266 call->ClearCancelCallback();
267 delete call_opts;
268 delete trace;
269 call->SendResponse(ToGrpcStatus(status));
270 });
271 ENQUEUE_REQUEST(RunCallable, false);
272 }
273
274 // RPC handler for making a callable.
ReleaseCallableHandler(MasterCall<ReleaseCallableRequest,ReleaseCallableResponse> * call)275 void ReleaseCallableHandler(
276 MasterCall<ReleaseCallableRequest, ReleaseCallableResponse>* call) {
277 master_impl_->ReleaseCallable(&call->request, &call->response,
278 [call](const Status& status) {
279 call->SendResponse(ToGrpcStatus(status));
280 });
281 ENQUEUE_REQUEST(ReleaseCallable, false);
282 }
283
284 #undef ENQUEUE_REQUEST
285
286 // Start tracing, including the ID attached to the RPC.
TraceRpc(StringPiece name,const std::multimap<::grpc::string_ref,::grpc::string_ref> & metadata)287 profiler::TraceMe* TraceRpc(
288 StringPiece name,
289 const std::multimap<::grpc::string_ref, ::grpc::string_ref>& metadata) {
290 StringPiece id;
291 auto it = metadata.find(GrpcIdKey());
292 if (it != metadata.end()) {
293 id = StringPiece(it->second.data(), it->second.size());
294 }
295 return new profiler::TraceMe([&] { return strings::StrCat(name, ":", id); },
296 profiler::TraceMeLevel::kInfo);
297 }
298
299 TF_DISALLOW_COPY_AND_ASSIGN(GrpcMasterService);
300 };
301
NewGrpcMasterService(Master * master,const ConfigProto & default_session_config,::grpc::ServerBuilder * builder)302 AsyncServiceInterface* NewGrpcMasterService(
303 Master* master, const ConfigProto& default_session_config,
304 ::grpc::ServerBuilder* builder) {
305 return new GrpcMasterService(master, default_session_config, builder);
306 }
307
308 } // end namespace tensorflow
309