• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h"
17 
18 #include <utility>
19 
20 #include "grpcpp/generic/generic_stub.h"
21 #include "grpcpp/grpcpp.h"
22 #include "tensorflow/core/common_runtime/process_util.h"
23 #include "tensorflow/core/distributed_runtime/call_options.h"
24 #include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
25 #include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
26 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
27 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h"
28 #include "tensorflow/core/distributed_runtime/tensor_coding.h"
29 #include "tensorflow/core/distributed_runtime/worker_cache_logger.h"
30 #include "tensorflow/core/distributed_runtime/worker_interface.h"
31 #include "tensorflow/core/lib/core/errors.h"
32 #include "tensorflow/core/lib/core/status.h"
33 #include "tensorflow/core/lib/core/threadpool.h"
34 #include "tensorflow/core/lib/strings/str_util.h"
35 #include "tensorflow/core/platform/logging.h"
36 #include "tensorflow/core/platform/tracing.h"
37 #include "tensorflow/core/protobuf/transport_options.pb.h"
38 #include "tensorflow/core/protobuf/worker.pb.h"
39 #include "tensorflow/core/util/env_var.h"
40 
41 namespace tensorflow {
42 
43 class GrpcRemoteWorker : public WorkerInterface {
44  public:
GrpcRemoteWorker(SharedGrpcChannelPtr channel,::grpc::CompletionQueue * completion_queue,thread::ThreadPool * callback_threadpool,WorkerCacheLogger * logger,const string & target)45   explicit GrpcRemoteWorker(SharedGrpcChannelPtr channel,
46                             ::grpc::CompletionQueue* completion_queue,
47                             thread::ThreadPool* callback_threadpool,
48                             WorkerCacheLogger* logger, const string& target)
49       : channel_(std::move(channel)),
50         stub_(channel_),
51         cq_(completion_queue),
52         callback_threadpool_(callback_threadpool),
53         getstatus_(Method(GrpcWorkerMethod::kGetStatus)),
54         createworkersession_(Method(GrpcWorkerMethod::kCreateWorkerSession)),
55         deleteworkersession_(Method(GrpcWorkerMethod::kDeleteWorkerSession)),
56         registergraph_(Method(GrpcWorkerMethod::kRegisterGraph)),
57         deregistergraph_(Method(GrpcWorkerMethod::kDeregisterGraph)),
58         rungraph_(Method(GrpcWorkerMethod::kRunGraph)),
59         cleanupgraph_(Method(GrpcWorkerMethod::kCleanupGraph)),
60         cleanupall_(Method(GrpcWorkerMethod::kCleanupAll)),
61         recvtensor_(Method(GrpcWorkerMethod::kRecvTensor)),
62         recvbuf_(Method(GrpcWorkerMethod::kRecvBuf)),
63         logging_(Method(GrpcWorkerMethod::kLogging)),
64         tracing_(Method(GrpcWorkerMethod::kTracing)),
65         completegroup_(Method(GrpcWorkerMethod::kCompleteGroup)),
66         instancesource_(Method(GrpcWorkerMethod::kCompleteInstance)),
67         getstepsequence_(Method(GrpcWorkerMethod::kGetStepSequence)),
68         markrecvfinished_(Method(GrpcWorkerMethod::kMarkRecvFinished)),
69         logger_(logger),
70         target_(target) {}
71 
~GrpcRemoteWorker()72   ~GrpcRemoteWorker() override {}
73 
GetStatusAsync(CallOptions * call_opts,const GetStatusRequest * request,GetStatusResponse * response,bool fail_fast,StatusCallback done)74   void GetStatusAsync(CallOptions* call_opts, const GetStatusRequest* request,
75                       GetStatusResponse* response, bool fail_fast,
76                       StatusCallback done) override {
77     IssueRequest(request, response, getstatus_, std::move(done), call_opts,
78                  fail_fast);
79   }
80 
CreateWorkerSessionAsync(const CreateWorkerSessionRequest * request,CreateWorkerSessionResponse * response,StatusCallback done)81   void CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request,
82                                 CreateWorkerSessionResponse* response,
83                                 StatusCallback done) override {
84     IssueRequest(request, response, createworkersession_, std::move(done));
85   }
86 
DeleteWorkerSessionAsync(CallOptions * call_opts,const DeleteWorkerSessionRequest * request,DeleteWorkerSessionResponse * response,StatusCallback done)87   void DeleteWorkerSessionAsync(CallOptions* call_opts,
88                                 const DeleteWorkerSessionRequest* request,
89                                 DeleteWorkerSessionResponse* response,
90                                 StatusCallback done) override {
91     IssueRequest(request, response, deleteworkersession_, std::move(done),
92                  call_opts);
93   }
94 
RegisterGraphAsync(const RegisterGraphRequest * request,RegisterGraphResponse * response,StatusCallback done)95   void RegisterGraphAsync(const RegisterGraphRequest* request,
96                           RegisterGraphResponse* response,
97                           StatusCallback done) override {
98     IssueRequest(request, response, registergraph_, std::move(done));
99   }
100 
DeregisterGraphAsync(const DeregisterGraphRequest * request,DeregisterGraphResponse * response,StatusCallback done)101   void DeregisterGraphAsync(const DeregisterGraphRequest* request,
102                             DeregisterGraphResponse* response,
103                             StatusCallback done) override {
104     IssueRequest(request, response, deregistergraph_, std::move(done));
105   }
106 
RunGraphAsync(CallOptions * call_opts,const RunGraphRequest * request,RunGraphResponse * response,StatusCallback done)107   void RunGraphAsync(CallOptions* call_opts, const RunGraphRequest* request,
108                      RunGraphResponse* response, StatusCallback done) override {
109     IssueRequest(request, response, rungraph_, std::move(done), call_opts);
110   }
RunGraphAsync(CallOptions * call_opts,RunGraphRequestWrapper * request,MutableRunGraphResponseWrapper * response,StatusCallback done)111   void RunGraphAsync(CallOptions* call_opts, RunGraphRequestWrapper* request,
112                      MutableRunGraphResponseWrapper* response,
113                      StatusCallback done) override {
114     IssueRequest(&request->ToProto(), get_proto_from_wrapper(response),
115                  rungraph_, std::move(done), call_opts);
116   }
117 
CleanupGraphAsync(const CleanupGraphRequest * request,CleanupGraphResponse * response,StatusCallback done)118   void CleanupGraphAsync(const CleanupGraphRequest* request,
119                          CleanupGraphResponse* response,
120                          StatusCallback done) override {
121     IssueRequest(request, response, cleanupgraph_, std::move(done));
122   }
123 
CleanupAllAsync(const CleanupAllRequest * request,CleanupAllResponse * response,StatusCallback done)124   void CleanupAllAsync(const CleanupAllRequest* request,
125                        CleanupAllResponse* response,
126                        StatusCallback done) override {
127     IssueRequest(request, response, cleanupall_, std::move(done));
128   }
129 
RecvBufAsync(CallOptions * call_opts,const RecvBufRequest * request,RecvBufResponse * response,StatusCallback done)130   void RecvBufAsync(CallOptions* call_opts, const RecvBufRequest* request,
131                     RecvBufResponse* response, StatusCallback done) override {
132     int64 start_usec = Env::Default()->NowMicros();
133     // Type-specialized logging for this method.
134     bool logging_active = logger_->LoggingActive() || VLOG_IS_ON(2);
135 
136     auto callback = [this, request, response, done, start_usec,
137                      logging_active](Status s) {
138       if (logging_active) {
139         if (logger_->LoggingActive()) {
140           int64 end_usec = Env::Default()->NowMicros();
141           int64 step_id = request->step_id();
142           RecvBufRespExtra extra;
143           response->transport_options().UnpackTo(&extra);
144           int64 num_bytes = 0;
145           for (const auto& chunk : extra.tensor_content()) {
146             num_bytes += chunk.size();
147           }
148           int64 send_start_usec = start_usec;
149           // Prefer start time reported by the sender, if available.
150           if (response->send_start_micros()) {
151             send_start_usec = std::max(
152                 start_usec, static_cast<int64>(response->send_start_micros()));
153             send_start_usec = std::min(send_start_usec, end_usec - 1);
154           }
155           const string& key = request->buf_rendezvous_key();
156           logger_->RecordDataTransfer(
157               step_id, send_start_usec, end_usec, key, request->src_device(),
158               request->dst_device(), num_bytes, "", "RecvBuf");
159         }
160         VLOG(2) << "done callback, req: " << request->DebugString()
161                 << " response " << response->DebugString();
162       }
163 
164       // Note done() can delete this worker object, so we need to call done()
165       // last.
166       if (response->require_ack()) {
167         IssueMarkRecvFinishedRequest(request->request_id());
168       }
169       done(s);
170     };
171 
172     IssueRequest(request, response, recvbuf_, callback, call_opts);
173   }
174 
CompleteGroupAsync(CallOptions * call_opts,const CompleteGroupRequest * request,CompleteGroupResponse * response,StatusCallback done)175   void CompleteGroupAsync(CallOptions* call_opts,
176                           const CompleteGroupRequest* request,
177                           CompleteGroupResponse* response,
178                           StatusCallback done) override {
179     IssueRequest(request, response, completegroup_, std::move(done), call_opts,
180                  /*fail_fast=*/false);
181   }
182 
CompleteInstanceAsync(CallOptions * call_opts,const CompleteInstanceRequest * request,CompleteInstanceResponse * response,StatusCallback done)183   void CompleteInstanceAsync(CallOptions* call_opts,
184                              const CompleteInstanceRequest* request,
185                              CompleteInstanceResponse* response,
186                              StatusCallback done) override {
187     IssueRequest(request, response, instancesource_, std::move(done),
188                  call_opts);
189   }
190 
GetStepSequenceAsync(const GetStepSequenceRequest * request,GetStepSequenceResponse * response,StatusCallback done)191   void GetStepSequenceAsync(const GetStepSequenceRequest* request,
192                             GetStepSequenceResponse* response,
193                             StatusCallback done) override {
194     IssueRequest(request, response, getstepsequence_, std::move(done));
195   }
196 
RecvTensorAsync(CallOptions * call_opts,const RecvTensorRequest * request,TensorResponse * response,StatusCallback done)197   void RecvTensorAsync(CallOptions* call_opts, const RecvTensorRequest* request,
198                        TensorResponse* response, StatusCallback done) override {
199     VLOG(1) << "RecvTensorAsync req: " << request->DebugString();
200     int64 start_usec = Env::Default()->NowMicros();
201     // Type-specialized logging for this method.
202     bool logging_active = logger_->LoggingActive() || VLOG_IS_ON(2);
203 
204     auto callback = [this, request, response, done, start_usec,
205                      logging_active](Status s) {
206       if (logging_active) {
207         if (logger_->LoggingActive()) {
208           int64 end_usec = Env::Default()->NowMicros();
209           int64 step_id = request->step_id();
210           int64 bytes = response->tensor().TotalBytes();
211           int64 send_start_usec = start_usec;
212           // If a send start time was reported by the other side, use
213           // that instead.  Maybe we should mark the display if we're using
214           // our local time instead of the remote start time?
215           if (response->metadata().send_start_micros()) {
216             // send_start_micros is the timestamp taken when the
217             // remote machine began to send the RecvTensor response.
218             // Due to clock skew between source and dest machines, it
219             // is possible that send_start_micros can be larger than
220             // end_usec or less than start_usec.
221             //
222             // To respect causality, we enforce the invariants that
223             // the RecvTensor response can not have been sent before
224             // the RecvTensor request, and must have been sent before
225             // it was received.
226             send_start_usec = std::max(
227                 start_usec,
228                 static_cast<int64>(response->metadata().send_start_micros()));
229             send_start_usec = std::min(send_start_usec, end_usec - 1);
230           }
231           const string& key = request->rendezvous_key();
232           std::vector<string> key_parts = str_util::Split(key, ';');
233           if (key_parts.size() != 5) {
234             LOG(WARNING) << "Bad key: " << key;
235           } else {
236             logger_->RecordRecvTensor(step_id, send_start_usec, end_usec,
237                                       key_parts[3],  // tensor name
238                                       key_parts[0],  // src_device
239                                       key_parts[2],  // dst_device
240                                       bytes);
241           }
242         }
243         VLOG(2) << "done callback, req: " << request->DebugString()
244                 << " response " << response->metadata().DebugString();
245       }
246 
247       // Note done() can delete this worker object, so we need to call done()
248       // last.
249       if (response->metadata().require_ack()) {
250         IssueMarkRecvFinishedRequest(request->request_id());
251       }
252       done(s);
253     };
254 
255     IssueRequest(request, response, recvtensor_, callback, call_opts);
256   }
257 
LoggingAsync(const LoggingRequest * request,LoggingResponse * response,StatusCallback done)258   void LoggingAsync(const LoggingRequest* request, LoggingResponse* response,
259                     StatusCallback done) override {
260     IssueRequest(request, response, logging_, done);
261   }
262 
TracingAsync(const TracingRequest * request,TracingResponse * response,StatusCallback done)263   void TracingAsync(const TracingRequest* request, TracingResponse* response,
264                     StatusCallback done) override {
265     IssueRequest(request, response, tracing_, done);
266   }
267 
268  private:
269   // Utility method for issuing a generic asynchronous request. The
270   // given callback, `done`, will be called when the RPC completes.
IssueRequest(const protobuf::Message * request,protobuf::Message * response,const::grpc::string & method,StatusCallback done,CallOptions * call_opts=nullptr,bool fail_fast=true)271   void IssueRequest(const protobuf::Message* request,
272                     protobuf::Message* response, const ::grpc::string& method,
273                     StatusCallback done, CallOptions* call_opts = nullptr,
274                     bool fail_fast = true) {
275     new RPCState<protobuf::Message>(
276         &stub_, cq_, method, *request, response, std::move(done), call_opts,
277         callback_threadpool_, MaxRetries(), fail_fast, &target_);
278   }
279 
IssueRequest(const protobuf::Message * request,TensorResponse * response,const::grpc::string & method,StatusCallback done,CallOptions * call_opts=nullptr)280   void IssueRequest(const protobuf::Message* request, TensorResponse* response,
281                     const ::grpc::string& method, StatusCallback done,
282                     CallOptions* call_opts = nullptr) {
283     new RPCState<TensorResponse>(&stub_, cq_, method, *request, response,
284                                  std::move(done), call_opts,
285                                  callback_threadpool_, MaxRetries(),
286                                  /*fail_fast=*/true, &target_);
287   }
288 
IssueMarkRecvFinishedRequest(int64 request_id)289   void IssueMarkRecvFinishedRequest(int64 request_id) {
290     VLOG(2) << "Send MarkRecvFinishedRequest for request " << request_id;
291     MarkRecvFinishedRequest request;
292     request.set_request_id(request_id);
293 
294     MarkRecvFinishedResponse* response = new MarkRecvFinishedResponse();
295     auto done = [response](Status status) { delete response; };
296     IssueRequest(&request, response, markrecvfinished_, done);
297   }
298 
299   // Helper function for initializing the RpcMethod objects below.
Method(GrpcWorkerMethod id)300   const char* Method(GrpcWorkerMethod id) { return GrpcWorkerMethodName(id); }
301 
302   // Helper function for configuring max GRPC retries. Defaults to 0 (no
303   // retries).
MaxRetries()304   const int64 MaxRetries() {
305     int64 max_retries = -1;
306     TF_CHECK_OK(ReadInt64FromEnvVar("GRPC_MAX_RETRIES", 0, &max_retries));
307     return max_retries;
308   }
309 
310   SharedGrpcChannelPtr channel_;
311   ::grpc::GenericStub stub_;
312   ::grpc::CompletionQueue* cq_;
313   thread::ThreadPool* callback_threadpool_;
314 
315   const ::grpc::string getstatus_;
316   const ::grpc::string createworkersession_;
317   const ::grpc::string deleteworkersession_;
318   const ::grpc::string registergraph_;
319   const ::grpc::string deregistergraph_;
320   const ::grpc::string rungraph_;
321   const ::grpc::string cleanupgraph_;
322   const ::grpc::string cleanupall_;
323   const ::grpc::string recvtensor_;
324   const ::grpc::string recvbuf_;
325   const ::grpc::string logging_;
326   const ::grpc::string tracing_;
327   const ::grpc::string completegroup_;
328   const ::grpc::string instancesource_;
329   const ::grpc::string getstepsequence_;
330   const ::grpc::string markrecvfinished_;
331 
332   // Support for logging.
333   WorkerCacheLogger* logger_;
334   const string target_;
335 
336   TF_DISALLOW_COPY_AND_ASSIGN(GrpcRemoteWorker);
337 };
338 
NewGrpcRemoteWorker(SharedGrpcChannelPtr channel,::grpc::CompletionQueue * completion_queue,thread::ThreadPool * callback_threadpool,WorkerCacheLogger * logger,const string & target)339 WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel,
340                                      ::grpc::CompletionQueue* completion_queue,
341                                      thread::ThreadPool* callback_threadpool,
342                                      WorkerCacheLogger* logger,
343                                      const string& target) {
344   return new GrpcRemoteWorker(std::move(channel), completion_queue,
345                               callback_threadpool, logger, target);
346 }
347 
348 }  // namespace tensorflow
349