1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// GrpcMasterService implements the RPC service MasterService.
17 //
18 // A GrpcMasterService maintains the state of live graph computation
19 // sessions, each session orchestrates both local and remote devices
20 // to carry out the graph computation.
21 //
22 // A GrpcMasterService knows ahead of time local devices available as
23 // client devices.
24 //
25 // A GrpcMasterService discovers remote devices in the background and
26 // keeps track of statistics of those remote devices.
27 //
28 // Each session analyzes the graph, places nodes across available
29 // devices, and ultimately drives the graph computation by initiating
30 // RunGraph on workers.
31 #include "tensorflow/core/distributed_runtime/rpc/grpc_master_service.h"
32
33 #include "grpcpp/alarm.h"
34 #include "grpcpp/server_builder.h"
35
36 #include "tensorflow/core/distributed_runtime/master.h"
37 #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
38 #include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
39 #include "tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h"
40 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
41 #include "tensorflow/core/platform/logging.h"
42 #include "tensorflow/core/platform/macros.h"
43 #include "tensorflow/core/platform/tracing.h"
44 #include "tensorflow/core/protobuf/master.pb.h"
45
46 namespace tensorflow {
47
48 class GrpcMasterService : public AsyncServiceInterface {
49 public:
GrpcMasterService(Master * master,const ConfigProto & default_session_config,::grpc::ServerBuilder * builder)50 GrpcMasterService(Master* master, const ConfigProto& default_session_config,
51 ::grpc::ServerBuilder* builder)
52 : master_impl_(master),
53 is_shutdown_(false),
54 default_session_config_(default_session_config) {
55 builder->RegisterService(&master_service_);
56 cq_ = builder->AddCompletionQueue();
57 }
58
~GrpcMasterService()59 ~GrpcMasterService() override { delete shutdown_alarm_; }
60
Shutdown()61 void Shutdown() override {
62 bool did_shutdown = false;
63 {
64 mutex_lock l(mu_);
65 if (!is_shutdown_) {
66 LOG(INFO) << "Shutting down GrpcMasterService.";
67 is_shutdown_ = true;
68 did_shutdown = true;
69 }
70 }
71 if (did_shutdown) {
72 // NOTE(mrry): This enqueues a special event (with a null tag)
73 // that causes the completion queue to be shut down on the
74 // polling thread.
75 shutdown_alarm_ =
76 new ::grpc::Alarm(cq_.get(), gpr_now(GPR_CLOCK_MONOTONIC), nullptr);
77 }
78 }
79
80 // This macro creates a new request for the given RPC method name
81 // (e.g., `ENQUEUE_REQUEST(RunStep);`), and enqueues it on
82 // `this->cq_`.
83 //
84 // This macro is invoked one or more times for each RPC method to
85 // ensure that there are sufficient completion queue entries to
86 // handle incoming requests without blocking.
87 //
88 // The implementation of the request handler for each RPC method
89 // must ensure that it calls ENQUEUE_REQUEST() for that RPC method,
90 // to keep accepting new requests.
91 #define ENQUEUE_REQUEST(method, supports_cancel) \
92 do { \
93 mutex_lock l(mu_); \
94 if (!is_shutdown_) { \
95 Call<GrpcMasterService, grpc::MasterService::AsyncService, \
96 method##Request, method##Response>:: \
97 EnqueueRequest(&master_service_, cq_.get(), \
98 &grpc::MasterService::AsyncService::Request##method, \
99 &GrpcMasterService::method##Handler, \
100 (supports_cancel)); \
101 } \
102 } while (0)
103
HandleRPCsLoop()104 void HandleRPCsLoop() override {
105 ENQUEUE_REQUEST(CreateSession, true);
106 ENQUEUE_REQUEST(ExtendSession, false);
107 for (int i = 0; i < 100; ++i) {
108 ENQUEUE_REQUEST(PartialRunSetup, false);
109 ENQUEUE_REQUEST(RunStep, true);
110 }
111 ENQUEUE_REQUEST(CloseSession, false);
112 ENQUEUE_REQUEST(ListDevices, false);
113 ENQUEUE_REQUEST(Reset, false);
114 ENQUEUE_REQUEST(MakeCallable, false);
115 for (int i = 0; i < 100; ++i) {
116 ENQUEUE_REQUEST(RunCallable, true);
117 }
118 ENQUEUE_REQUEST(ReleaseCallable, false);
119
120 void* tag;
121 bool ok;
122 while (cq_->Next(&tag, &ok)) {
123 UntypedCall<GrpcMasterService>::Tag* callback_tag =
124 static_cast<UntypedCall<GrpcMasterService>::Tag*>(tag);
125 if (callback_tag) {
126 callback_tag->OnCompleted(this, ok);
127 } else {
128 // NOTE(mrry): A null `callback_tag` indicates that this is
129 // the shutdown alarm.
130 cq_->Shutdown();
131 }
132 }
133 }
134
135 private:
136 Master* master_impl_ = nullptr; // Not owned.
137 std::unique_ptr<::grpc::ServerCompletionQueue> cq_;
138 grpc::MasterService::AsyncService master_service_;
139
140 mutex mu_;
141 bool is_shutdown_ GUARDED_BY(mu_);
142 const ConfigProto default_session_config_;
143 ::grpc::Alarm* shutdown_alarm_ = nullptr;
144
145 template <class RequestMessage, class ResponseMessage>
146 using MasterCall = Call<GrpcMasterService, grpc::MasterService::AsyncService,
147 RequestMessage, ResponseMessage>;
148
149 // RPC handler for creating a session.
CreateSessionHandler(MasterCall<CreateSessionRequest,CreateSessionResponse> * call)150 void CreateSessionHandler(
151 MasterCall<CreateSessionRequest, CreateSessionResponse>* call) {
152 CreateSessionRequest* rewritten_req = new CreateSessionRequest;
153 rewritten_req->mutable_config()->MergeFrom(default_session_config_);
154 rewritten_req->MergeFrom(call->request);
155 master_impl_->CreateSession(rewritten_req, &call->response,
156 [call, rewritten_req](const Status& status) {
157 call->SendResponse(ToGrpcStatus(status));
158 delete rewritten_req;
159 });
160 ENQUEUE_REQUEST(CreateSession, true);
161 }
162
163 // RPC handler for extending a session.
ExtendSessionHandler(MasterCall<ExtendSessionRequest,ExtendSessionResponse> * call)164 void ExtendSessionHandler(
165 MasterCall<ExtendSessionRequest, ExtendSessionResponse>* call) {
166 master_impl_->ExtendSession(&call->request, &call->response,
167 [call](const Status& status) {
168 call->SendResponse(ToGrpcStatus(status));
169 });
170 ENQUEUE_REQUEST(ExtendSession, false);
171 }
172
173 // RPC handler for setting up a partial run call.
PartialRunSetupHandler(MasterCall<PartialRunSetupRequest,PartialRunSetupResponse> * call)174 void PartialRunSetupHandler(
175 MasterCall<PartialRunSetupRequest, PartialRunSetupResponse>* call) {
176 master_impl_->PartialRunSetup(&call->request, &call->response,
177 [call](const Status& status) {
178 call->SendResponse(ToGrpcStatus(status));
179 });
180 ENQUEUE_REQUEST(PartialRunSetup, false);
181 }
182
183 // RPC handler for running one step in a session.
RunStepHandler(MasterCall<RunStepRequest,RunStepResponse> * call)184 void RunStepHandler(MasterCall<RunStepRequest, RunStepResponse>* call) {
185 auto* trace = TraceRpc("RunStep/Server", call->client_metadata());
186 CallOptions* call_opts = new CallOptions;
187 if (call->request.options().timeout_in_ms() > 0) {
188 call_opts->SetTimeout(call->request.options().timeout_in_ms());
189 } else {
190 call_opts->SetTimeout(default_session_config_.operation_timeout_in_ms());
191 }
192 RunStepRequestWrapper* wrapped_request =
193 new ProtoRunStepRequest(&call->request);
194 MutableRunStepResponseWrapper* wrapped_response =
195 new NonOwnedProtoRunStepResponse(&call->response);
196 call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
197 master_impl_->RunStep(
198 call_opts, wrapped_request, wrapped_response,
199 [call, call_opts, wrapped_request, wrapped_response,
200 trace](const Status& status) {
201 call->ClearCancelCallback();
202 delete call_opts;
203 delete wrapped_request;
204 delete trace;
205 if (call->request.store_errors_in_response_body() && !status.ok()) {
206 call->response.set_status_code(status.code());
207 call->response.set_status_error_message(status.error_message());
208 call->SendResponse(ToGrpcStatus(Status::OK()));
209 } else {
210 call->SendResponse(ToGrpcStatus(status));
211 }
212 });
213 ENQUEUE_REQUEST(RunStep, true);
214 }
215
216 // RPC handler for deleting a session.
CloseSessionHandler(MasterCall<CloseSessionRequest,CloseSessionResponse> * call)217 void CloseSessionHandler(
218 MasterCall<CloseSessionRequest, CloseSessionResponse>* call) {
219 master_impl_->CloseSession(&call->request, &call->response,
220 [call](const Status& status) {
221 call->SendResponse(ToGrpcStatus(status));
222 });
223 ENQUEUE_REQUEST(CloseSession, false);
224 }
225
226 // RPC handler for listing devices.
ListDevicesHandler(MasterCall<ListDevicesRequest,ListDevicesResponse> * call)227 void ListDevicesHandler(
228 MasterCall<ListDevicesRequest, ListDevicesResponse>* call) {
229 master_impl_->ListDevices(&call->request, &call->response,
230 [call](const Status& status) {
231 call->SendResponse(ToGrpcStatus(status));
232 });
233 ENQUEUE_REQUEST(ListDevices, false);
234 }
235
236 // RPC handler for resetting all sessions.
ResetHandler(MasterCall<ResetRequest,ResetResponse> * call)237 void ResetHandler(MasterCall<ResetRequest, ResetResponse>* call) {
238 master_impl_->Reset(&call->request, &call->response,
239 [call](const Status& status) {
240 call->SendResponse(ToGrpcStatus(status));
241 });
242 ENQUEUE_REQUEST(Reset, false);
243 }
244
245 // RPC handler for making a callable.
MakeCallableHandler(MasterCall<MakeCallableRequest,MakeCallableResponse> * call)246 void MakeCallableHandler(
247 MasterCall<MakeCallableRequest, MakeCallableResponse>* call) {
248 master_impl_->MakeCallable(&call->request, &call->response,
249 [call](const Status& status) {
250 call->SendResponse(ToGrpcStatus(status));
251 });
252 ENQUEUE_REQUEST(MakeCallable, false);
253 }
254
255 // RPC handler for running a callable.
RunCallableHandler(MasterCall<RunCallableRequest,RunCallableResponse> * call)256 void RunCallableHandler(
257 MasterCall<RunCallableRequest, RunCallableResponse>* call) {
258 auto* trace = TraceRpc("RunCallable/Server", call->client_metadata());
259 CallOptions* call_opts = new CallOptions;
260 // The timeout may be overridden by a non-zero timeout in the
261 // callable's `RunOptions`; this overriding will happen inside the
262 // `MasterSession` implementation.
263 call_opts->SetTimeout(default_session_config_.operation_timeout_in_ms());
264 call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
265 master_impl_->RunCallable(call_opts, &call->request, &call->response,
266 [call, call_opts, trace](const Status& status) {
267 call->ClearCancelCallback();
268 delete call_opts;
269 delete trace;
270 call->SendResponse(ToGrpcStatus(status));
271 });
272 ENQUEUE_REQUEST(RunCallable, false);
273 }
274
275 // RPC handler for making a callable.
ReleaseCallableHandler(MasterCall<ReleaseCallableRequest,ReleaseCallableResponse> * call)276 void ReleaseCallableHandler(
277 MasterCall<ReleaseCallableRequest, ReleaseCallableResponse>* call) {
278 master_impl_->ReleaseCallable(&call->request, &call->response,
279 [call](const Status& status) {
280 call->SendResponse(ToGrpcStatus(status));
281 });
282 ENQUEUE_REQUEST(ReleaseCallable, false);
283 }
284
285 #undef ENQUEUE_REQUEST
286
287 // Start tracing, including the ID attached to the RPC.
TraceRpc(StringPiece name,const std::multimap<::grpc::string_ref,::grpc::string_ref> & metadata)288 tracing::ScopedActivity* TraceRpc(
289 StringPiece name,
290 const std::multimap<::grpc::string_ref, ::grpc::string_ref>& metadata) {
291 StringPiece id;
292 auto it = metadata.find(GrpcIdKey());
293 if (it != metadata.end()) {
294 id = StringPiece(it->second.data(), it->second.size());
295 }
296 return new tracing::ScopedActivity(name, id);
297 }
298
299 TF_DISALLOW_COPY_AND_ASSIGN(GrpcMasterService);
300 };
301
NewGrpcMasterService(Master * master,const ConfigProto & default_session_config,::grpc::ServerBuilder * builder)302 AsyncServiceInterface* NewGrpcMasterService(
303 Master* master, const ConfigProto& default_session_config,
304 ::grpc::ServerBuilder* builder) {
305 return new GrpcMasterService(master, default_session_config, builder);
306 }
307
308 } // end namespace tensorflow
309