1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h"
17
18 #include <utility>
19
20 #include "grpcpp/generic/generic_stub.h"
21 #include "grpcpp/grpcpp.h"
22 #include "tensorflow/core/common_runtime/process_util.h"
23 #include "tensorflow/core/distributed_runtime/call_options.h"
24 #include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
25 #include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
26 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
27 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h"
28 #include "tensorflow/core/distributed_runtime/tensor_coding.h"
29 #include "tensorflow/core/distributed_runtime/worker_cache_logger.h"
30 #include "tensorflow/core/distributed_runtime/worker_interface.h"
31 #include "tensorflow/core/lib/core/errors.h"
32 #include "tensorflow/core/lib/core/status.h"
33 #include "tensorflow/core/lib/core/threadpool.h"
34 #include "tensorflow/core/lib/strings/str_util.h"
35 #include "tensorflow/core/platform/logging.h"
36 #include "tensorflow/core/platform/tracing.h"
37 #include "tensorflow/core/protobuf/transport_options.pb.h"
38 #include "tensorflow/core/protobuf/worker.pb.h"
39 #include "tensorflow/core/util/env_var.h"
40
41 namespace tensorflow {
42
43 class GrpcRemoteWorker : public WorkerInterface {
44 public:
GrpcRemoteWorker(SharedGrpcChannelPtr channel,::grpc::CompletionQueue * completion_queue,thread::ThreadPool * callback_threadpool,WorkerCacheLogger * logger,const string & target)45 explicit GrpcRemoteWorker(SharedGrpcChannelPtr channel,
46 ::grpc::CompletionQueue* completion_queue,
47 thread::ThreadPool* callback_threadpool,
48 WorkerCacheLogger* logger, const string& target)
49 : channel_(std::move(channel)),
50 stub_(channel_),
51 cq_(completion_queue),
52 callback_threadpool_(callback_threadpool),
53 getstatus_(Method(GrpcWorkerMethod::kGetStatus)),
54 createworkersession_(Method(GrpcWorkerMethod::kCreateWorkerSession)),
55 deleteworkersession_(Method(GrpcWorkerMethod::kDeleteWorkerSession)),
56 registergraph_(Method(GrpcWorkerMethod::kRegisterGraph)),
57 deregistergraph_(Method(GrpcWorkerMethod::kDeregisterGraph)),
58 rungraph_(Method(GrpcWorkerMethod::kRunGraph)),
59 cleanupgraph_(Method(GrpcWorkerMethod::kCleanupGraph)),
60 cleanupall_(Method(GrpcWorkerMethod::kCleanupAll)),
61 recvtensor_(Method(GrpcWorkerMethod::kRecvTensor)),
62 recvbuf_(Method(GrpcWorkerMethod::kRecvBuf)),
63 logging_(Method(GrpcWorkerMethod::kLogging)),
64 tracing_(Method(GrpcWorkerMethod::kTracing)),
65 completegroup_(Method(GrpcWorkerMethod::kCompleteGroup)),
66 instancesource_(Method(GrpcWorkerMethod::kCompleteInstance)),
67 getstepsequence_(Method(GrpcWorkerMethod::kGetStepSequence)),
68 markrecvfinished_(Method(GrpcWorkerMethod::kMarkRecvFinished)),
69 logger_(logger),
70 target_(target) {}
71
~GrpcRemoteWorker()72 ~GrpcRemoteWorker() override {}
73
GetStatusAsync(CallOptions * call_opts,const GetStatusRequest * request,GetStatusResponse * response,bool fail_fast,StatusCallback done)74 void GetStatusAsync(CallOptions* call_opts, const GetStatusRequest* request,
75 GetStatusResponse* response, bool fail_fast,
76 StatusCallback done) override {
77 IssueRequest(request, response, getstatus_, std::move(done), call_opts,
78 fail_fast);
79 }
80
CreateWorkerSessionAsync(const CreateWorkerSessionRequest * request,CreateWorkerSessionResponse * response,StatusCallback done)81 void CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request,
82 CreateWorkerSessionResponse* response,
83 StatusCallback done) override {
84 IssueRequest(request, response, createworkersession_, std::move(done));
85 }
86
DeleteWorkerSessionAsync(CallOptions * call_opts,const DeleteWorkerSessionRequest * request,DeleteWorkerSessionResponse * response,StatusCallback done)87 void DeleteWorkerSessionAsync(CallOptions* call_opts,
88 const DeleteWorkerSessionRequest* request,
89 DeleteWorkerSessionResponse* response,
90 StatusCallback done) override {
91 IssueRequest(request, response, deleteworkersession_, std::move(done),
92 call_opts);
93 }
94
RegisterGraphAsync(const RegisterGraphRequest * request,RegisterGraphResponse * response,StatusCallback done)95 void RegisterGraphAsync(const RegisterGraphRequest* request,
96 RegisterGraphResponse* response,
97 StatusCallback done) override {
98 IssueRequest(request, response, registergraph_, std::move(done));
99 }
100
DeregisterGraphAsync(const DeregisterGraphRequest * request,DeregisterGraphResponse * response,StatusCallback done)101 void DeregisterGraphAsync(const DeregisterGraphRequest* request,
102 DeregisterGraphResponse* response,
103 StatusCallback done) override {
104 IssueRequest(request, response, deregistergraph_, std::move(done));
105 }
106
RunGraphAsync(CallOptions * call_opts,const RunGraphRequest * request,RunGraphResponse * response,StatusCallback done)107 void RunGraphAsync(CallOptions* call_opts, const RunGraphRequest* request,
108 RunGraphResponse* response, StatusCallback done) override {
109 IssueRequest(request, response, rungraph_, std::move(done), call_opts);
110 }
RunGraphAsync(CallOptions * call_opts,RunGraphRequestWrapper * request,MutableRunGraphResponseWrapper * response,StatusCallback done)111 void RunGraphAsync(CallOptions* call_opts, RunGraphRequestWrapper* request,
112 MutableRunGraphResponseWrapper* response,
113 StatusCallback done) override {
114 IssueRequest(&request->ToProto(), get_proto_from_wrapper(response),
115 rungraph_, std::move(done), call_opts);
116 }
117
CleanupGraphAsync(const CleanupGraphRequest * request,CleanupGraphResponse * response,StatusCallback done)118 void CleanupGraphAsync(const CleanupGraphRequest* request,
119 CleanupGraphResponse* response,
120 StatusCallback done) override {
121 IssueRequest(request, response, cleanupgraph_, std::move(done));
122 }
123
CleanupAllAsync(const CleanupAllRequest * request,CleanupAllResponse * response,StatusCallback done)124 void CleanupAllAsync(const CleanupAllRequest* request,
125 CleanupAllResponse* response,
126 StatusCallback done) override {
127 IssueRequest(request, response, cleanupall_, std::move(done));
128 }
129
RecvBufAsync(CallOptions * call_opts,const RecvBufRequest * request,RecvBufResponse * response,StatusCallback done)130 void RecvBufAsync(CallOptions* call_opts, const RecvBufRequest* request,
131 RecvBufResponse* response, StatusCallback done) override {
132 int64 start_usec = Env::Default()->NowMicros();
133 // Type-specialized logging for this method.
134 bool logging_active = logger_->LoggingActive() || VLOG_IS_ON(2);
135
136 auto callback = [this, request, response, done, start_usec,
137 logging_active](Status s) {
138 if (logging_active) {
139 if (logger_->LoggingActive()) {
140 int64 end_usec = Env::Default()->NowMicros();
141 int64 step_id = request->step_id();
142 RecvBufRespExtra extra;
143 response->transport_options().UnpackTo(&extra);
144 int64 num_bytes = 0;
145 for (const auto& chunk : extra.tensor_content()) {
146 num_bytes += chunk.size();
147 }
148 int64 send_start_usec = start_usec;
149 // Prefer start time reported by the sender, if available.
150 if (response->send_start_micros()) {
151 send_start_usec = std::max(
152 start_usec, static_cast<int64>(response->send_start_micros()));
153 send_start_usec = std::min(send_start_usec, end_usec - 1);
154 }
155 const string& key = request->buf_rendezvous_key();
156 logger_->RecordDataTransfer(
157 step_id, send_start_usec, end_usec, key, request->src_device(),
158 request->dst_device(), num_bytes, "", "RecvBuf");
159 }
160 VLOG(2) << "done callback, req: " << request->DebugString()
161 << " response " << response->DebugString();
162 }
163
164 // Note done() can delete this worker object, so we need to call done()
165 // last.
166 if (response->require_ack()) {
167 IssueMarkRecvFinishedRequest(request->request_id());
168 }
169 done(s);
170 };
171
172 IssueRequest(request, response, recvbuf_, callback, call_opts);
173 }
174
CompleteGroupAsync(CallOptions * call_opts,const CompleteGroupRequest * request,CompleteGroupResponse * response,StatusCallback done)175 void CompleteGroupAsync(CallOptions* call_opts,
176 const CompleteGroupRequest* request,
177 CompleteGroupResponse* response,
178 StatusCallback done) override {
179 IssueRequest(request, response, completegroup_, std::move(done), call_opts,
180 /*fail_fast=*/false);
181 }
182
CompleteInstanceAsync(CallOptions * call_opts,const CompleteInstanceRequest * request,CompleteInstanceResponse * response,StatusCallback done)183 void CompleteInstanceAsync(CallOptions* call_opts,
184 const CompleteInstanceRequest* request,
185 CompleteInstanceResponse* response,
186 StatusCallback done) override {
187 IssueRequest(request, response, instancesource_, std::move(done),
188 call_opts);
189 }
190
GetStepSequenceAsync(const GetStepSequenceRequest * request,GetStepSequenceResponse * response,StatusCallback done)191 void GetStepSequenceAsync(const GetStepSequenceRequest* request,
192 GetStepSequenceResponse* response,
193 StatusCallback done) override {
194 IssueRequest(request, response, getstepsequence_, std::move(done));
195 }
196
RecvTensorAsync(CallOptions * call_opts,const RecvTensorRequest * request,TensorResponse * response,StatusCallback done)197 void RecvTensorAsync(CallOptions* call_opts, const RecvTensorRequest* request,
198 TensorResponse* response, StatusCallback done) override {
199 VLOG(1) << "RecvTensorAsync req: " << request->DebugString();
200 int64 start_usec = Env::Default()->NowMicros();
201 // Type-specialized logging for this method.
202 bool logging_active = logger_->LoggingActive() || VLOG_IS_ON(2);
203
204 auto callback = [this, request, response, done, start_usec,
205 logging_active](Status s) {
206 if (logging_active) {
207 if (logger_->LoggingActive()) {
208 int64 end_usec = Env::Default()->NowMicros();
209 int64 step_id = request->step_id();
210 int64 bytes = response->tensor().TotalBytes();
211 int64 send_start_usec = start_usec;
212 // If a send start time was reported by the other side, use
213 // that instead. Maybe we should mark the display if we're using
214 // our local time instead of the remote start time?
215 if (response->metadata().send_start_micros()) {
216 // send_start_micros is the timestamp taken when the
217 // remote machine began to send the RecvTensor response.
218 // Due to clock skew between source and dest machines, it
219 // is possible that send_start_micros can be larger than
220 // end_usec or less than start_usec.
221 //
222 // To respect causality, we enforce the invariants that
223 // the RecvTensor response can not have been sent before
224 // the RecvTensor request, and must have been sent before
225 // it was received.
226 send_start_usec = std::max(
227 start_usec,
228 static_cast<int64>(response->metadata().send_start_micros()));
229 send_start_usec = std::min(send_start_usec, end_usec - 1);
230 }
231 const string& key = request->rendezvous_key();
232 std::vector<string> key_parts = str_util::Split(key, ';');
233 if (key_parts.size() != 5) {
234 LOG(WARNING) << "Bad key: " << key;
235 } else {
236 logger_->RecordRecvTensor(step_id, send_start_usec, end_usec,
237 key_parts[3], // tensor name
238 key_parts[0], // src_device
239 key_parts[2], // dst_device
240 bytes);
241 }
242 }
243 VLOG(2) << "done callback, req: " << request->DebugString()
244 << " response " << response->metadata().DebugString();
245 }
246
247 // Note done() can delete this worker object, so we need to call done()
248 // last.
249 if (response->metadata().require_ack()) {
250 IssueMarkRecvFinishedRequest(request->request_id());
251 }
252 done(s);
253 };
254
255 IssueRequest(request, response, recvtensor_, callback, call_opts);
256 }
257
LoggingAsync(const LoggingRequest * request,LoggingResponse * response,StatusCallback done)258 void LoggingAsync(const LoggingRequest* request, LoggingResponse* response,
259 StatusCallback done) override {
260 IssueRequest(request, response, logging_, done);
261 }
262
TracingAsync(const TracingRequest * request,TracingResponse * response,StatusCallback done)263 void TracingAsync(const TracingRequest* request, TracingResponse* response,
264 StatusCallback done) override {
265 IssueRequest(request, response, tracing_, done);
266 }
267
268 private:
269 // Utility method for issuing a generic asynchronous request. The
270 // given callback, `done`, will be called when the RPC completes.
IssueRequest(const protobuf::Message * request,protobuf::Message * response,const::grpc::string & method,StatusCallback done,CallOptions * call_opts=nullptr,bool fail_fast=true)271 void IssueRequest(const protobuf::Message* request,
272 protobuf::Message* response, const ::grpc::string& method,
273 StatusCallback done, CallOptions* call_opts = nullptr,
274 bool fail_fast = true) {
275 new RPCState<protobuf::Message>(
276 &stub_, cq_, method, *request, response, std::move(done), call_opts,
277 callback_threadpool_, MaxRetries(), fail_fast, &target_);
278 }
279
IssueRequest(const protobuf::Message * request,TensorResponse * response,const::grpc::string & method,StatusCallback done,CallOptions * call_opts=nullptr)280 void IssueRequest(const protobuf::Message* request, TensorResponse* response,
281 const ::grpc::string& method, StatusCallback done,
282 CallOptions* call_opts = nullptr) {
283 new RPCState<TensorResponse>(&stub_, cq_, method, *request, response,
284 std::move(done), call_opts,
285 callback_threadpool_, MaxRetries(),
286 /*fail_fast=*/true, &target_);
287 }
288
IssueMarkRecvFinishedRequest(int64 request_id)289 void IssueMarkRecvFinishedRequest(int64 request_id) {
290 VLOG(2) << "Send MarkRecvFinishedRequest for request " << request_id;
291 MarkRecvFinishedRequest request;
292 request.set_request_id(request_id);
293
294 MarkRecvFinishedResponse* response = new MarkRecvFinishedResponse();
295 auto done = [response](Status status) { delete response; };
296 IssueRequest(&request, response, markrecvfinished_, done);
297 }
298
299 // Helper function for initializing the RpcMethod objects below.
Method(GrpcWorkerMethod id)300 const char* Method(GrpcWorkerMethod id) { return GrpcWorkerMethodName(id); }
301
302 // Helper function for configuring max GRPC retries. Defaults to 0 (no
303 // retries).
MaxRetries()304 const int64 MaxRetries() {
305 int64 max_retries = -1;
306 TF_CHECK_OK(ReadInt64FromEnvVar("GRPC_MAX_RETRIES", 0, &max_retries));
307 return max_retries;
308 }
309
310 SharedGrpcChannelPtr channel_;
311 ::grpc::GenericStub stub_;
312 ::grpc::CompletionQueue* cq_;
313 thread::ThreadPool* callback_threadpool_;
314
315 const ::grpc::string getstatus_;
316 const ::grpc::string createworkersession_;
317 const ::grpc::string deleteworkersession_;
318 const ::grpc::string registergraph_;
319 const ::grpc::string deregistergraph_;
320 const ::grpc::string rungraph_;
321 const ::grpc::string cleanupgraph_;
322 const ::grpc::string cleanupall_;
323 const ::grpc::string recvtensor_;
324 const ::grpc::string recvbuf_;
325 const ::grpc::string logging_;
326 const ::grpc::string tracing_;
327 const ::grpc::string completegroup_;
328 const ::grpc::string instancesource_;
329 const ::grpc::string getstepsequence_;
330 const ::grpc::string markrecvfinished_;
331
332 // Support for logging.
333 WorkerCacheLogger* logger_;
334 const string target_;
335
336 TF_DISALLOW_COPY_AND_ASSIGN(GrpcRemoteWorker);
337 };
338
NewGrpcRemoteWorker(SharedGrpcChannelPtr channel,::grpc::CompletionQueue * completion_queue,thread::ThreadPool * callback_threadpool,WorkerCacheLogger * logger,const string & target)339 WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel,
340 ::grpc::CompletionQueue* completion_queue,
341 thread::ThreadPool* callback_threadpool,
342 WorkerCacheLogger* logger,
343 const string& target) {
344 return new GrpcRemoteWorker(std::move(channel), completion_queue,
345 callback_threadpool, logger, target);
346 }
347
348 } // namespace tensorflow
349