/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"

#include <deque>
#include <memory>
#include <unordered_map>
#include <vector>

#include "grpcpp/alarm.h"
#include "grpcpp/server_builder.h"
#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/common_runtime/buf_rendezvous.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/distributed_runtime/graph_mgr.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_response_cache.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h"
#include "tensorflow/core/distributed_runtime/worker.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_session.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/protobuf/transport_options.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"

namespace tensorflow {

namespace {

// This macro creates a new request for the given RPC method name
// (e.g., `ENQUEUE_REQUEST(GetStatus, false);`), and enqueues it on
// `this->cq_`.
//
// This macro is invoked one or more times for each RPC method to
// ensure that there are sufficient completion queue entries to
// handle incoming requests without blocking.
//
// The implementation of the request handler for each RPC method
// must ensure that it calls ENQUEUE_REQUEST() for that RPC method,
// to keep accepting new requests.
#define ENQUEUE_REQUEST(method, supports_cancel)                             \
  do {                                                                       \
    mutex_lock l(shutdown_mu_);                                              \
    if (!is_shutdown_) {                                                     \
      Call<GrpcWorkerServiceThread, grpc::WorkerService::AsyncService,       \
           method##Request, method##Response>::                              \
          EnqueueRequestForMethod(                                           \
              worker_service_, cq_.get(),                                    \
              static_cast<int>(GrpcWorkerMethod::k##method),                 \
              &GrpcWorkerServiceThread::method##Handler, (supports_cancel)); \
    }                                                                        \
  } while (0)

#define SETUP_FOR_REQUEST(method, default_depth, supports_cancel)              \
  for (int i = 0;                                                              \
       i < gtl::FindWithDefault(queue_depth_,                                  \
                                static_cast<int>(GrpcWorkerMethod::k##method), \
                                default_depth);                                \
       ++i) {                                                                  \
    ENQUEUE_REQUEST(method, supports_cancel);                                  \
  }
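// For example, `SETUP_FOR_REQUEST(RunGraph, 100, true);` pre-posts 100 pending
// RunGraph calls on the completion queue (or the per-method count configured
// in `queue_depth_`, if present), each of which supports cancellation.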

// GrpcWorkerService spawns one or more GrpcWorkerServiceThreads to service
// requests.  Each thread operates on an independent completion queue.
class GrpcWorkerServiceThread {
 public:
  explicit GrpcWorkerServiceThread(
      GrpcWorker* worker, ::grpc::ServerBuilder* builder,
      std::unordered_map<int, int> queue_depth, GrpcResponseCache* cache,
      grpc::WorkerService::AsyncService* worker_service)
      : worker_(worker),
        queue_depth_(queue_depth),
        cache_(cache),
        worker_service_(worker_service),
        is_shutdown_(false) {
    cq_ = builder->AddCompletionQueue();
  }

  void Start() {
    thread_.reset(
        worker_->env()->env->StartThread(ThreadOptions(), "grpc_worker_service",
                                         [this]() { HandleRPCsLoop(); }));
  }

  void Join() { thread_.reset(); }  // Blocks until thread exits

  void Shutdown() {
    {
      mutex_lock lock(shutdown_mu_);
      is_shutdown_ = true;
    }
    cq_->Shutdown();
  }

 private:
  // Add one or more completion queue entries for each worker method, then
  // begin servicing requests from the completion queue.
  void HandleRPCsLoop() {
    // TODO(ncteisen): This may require performance engineering. We can
    // change the number of threads, the number of handlers per thread,
    // or even decide to specialize certain threads to certain methods.
    SETUP_FOR_REQUEST(GetStatus, 1, false);
    SETUP_FOR_REQUEST(CreateWorkerSession, 1, false);
    SETUP_FOR_REQUEST(DeleteWorkerSession, 1, false);
    SETUP_FOR_REQUEST(CleanupAll, 1, false);
    SETUP_FOR_REQUEST(RegisterGraph, 1, false);
    SETUP_FOR_REQUEST(DeregisterGraph, 1, false);
    SETUP_FOR_REQUEST(Logging, 1, false);
    SETUP_FOR_REQUEST(Tracing, 1, false);
    SETUP_FOR_REQUEST(CompleteGroup, 10, true);
    SETUP_FOR_REQUEST(CompleteInstance, 10, true);
    SETUP_FOR_REQUEST(GetStepSequence, 10, true);
    SETUP_FOR_REQUEST(RecvBuf, 500, true);
    SETUP_FOR_REQUEST(RunGraph, 100, true);
    SETUP_FOR_REQUEST(CleanupGraph, 100, false);
    SETUP_FOR_REQUEST(MarkRecvFinished, 10, false);

    // TODO(ncteisen): Determine a better policy for enqueuing the
    // appropriate number of each request type.
    for (int i = 0;
         i < gtl::FindWithDefault(
                 queue_depth_, static_cast<int>(GrpcWorkerMethod::kRecvTensor),
                 1000);
         ++i) {
      EnqueueRecvTensorRequestRaw();
    }

    void* tag;
    bool ok;

    while (cq_->Next(&tag, &ok)) {
      UntypedCall<GrpcWorkerServiceThread>::Tag* callback_tag =
          static_cast<UntypedCall<GrpcWorkerServiceThread>::Tag*>(tag);
      CHECK(callback_tag);
      callback_tag->OnCompleted(this, ok);
    }
  }

 private:
  void Schedule(std::function<void()> f) {
    worker_->env()->compute_pool->Schedule(std::move(f));
  }

  // The following section contains one request handler method per
  // RPC. The `FooHandler` method is called (indirectly) by
  // `HandleRPCsLoop()` when the next Foo RPC is received. Each
  // `FooHandler` call schedules a closure on `worker_->env()->compute_pool`,
  // and is responsible for requesting the next Foo call by calling
  // `ENQUEUE_REQUEST(Foo)`.
  template <class RequestMessage, class ResponseMessage>
  using WorkerCall =
      Call<GrpcWorkerServiceThread, grpc::WorkerService::AsyncService,
           RequestMessage, ResponseMessage>;

  // Handle all non-cancellable simple methods with a standard wrapper.
  // The boolean `may_block_on_compute_pool` indicates whether or not the
  // operation may block on activities (such as op execution) that run on the
  // compute pool.
#define HANDLE_CALL(method, may_block_on_compute_pool)                        \
  void method##Handler(WorkerCall<method##Request, method##Response>* call) { \
    auto closure = [this, call]() {                                           \
      Status s = worker_->method(&call->request, &call->response);            \
      if (!s.ok()) {                                                          \
        VLOG(1) << "Bad response from " << #method << ": " << s;              \
      }                                                                       \
      call->SendResponse(ToGrpcStatus(s));                                    \
    };                                                                        \
    if ((may_block_on_compute_pool)) {                                        \
      worker_->env()->env->SchedClosure(std::move(closure));                  \
    } else {                                                                  \
      worker_->env()->compute_pool->Schedule(std::move(closure));             \
    }                                                                         \
    ENQUEUE_REQUEST(method, false);                                           \
  }

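  // Each HANDLE_CALL(method, may_block) invocation below generates a
  // `methodHandler` that forwards the request to the corresponding
  // synchronous `worker_->method()` implementation and then re-enqueues the
  // next call of that type.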
  HANDLE_CALL(GetStatus, false);
  HANDLE_CALL(CreateWorkerSession, false);
  HANDLE_CALL(DeleteWorkerSession, true);
  HANDLE_CALL(CleanupAll, false);
  HANDLE_CALL(RegisterGraph, false);
  HANDLE_CALL(DeregisterGraph, false);
  HANDLE_CALL(CleanupGraph, false);
  HANDLE_CALL(Logging, false);
  HANDLE_CALL(Tracing, false);

#undef HANDLE_CALL

  void GetStepSequenceHandler(
      WorkerCall<GetStepSequenceRequest, GetStepSequenceResponse>* call) {
    Schedule([this, call]() {
      worker_->GetStepSequenceAsync(
          &call->request, &call->response, [call](const Status& s) {
            if (!s.ok()) {
              VLOG(1) << "Bad response from GetStepSequence:" << s;
            }
            call->SendResponse(ToGrpcStatus(s));
          });
    });
    ENQUEUE_REQUEST(GetStepSequence, true);
  }

  void MarkRecvFinishedHandler(
      WorkerCall<MarkRecvFinishedRequest, MarkRecvFinishedResponse>* call) {
    VLOG(1) << "Clean cache entry for request " << call->request.request_id();
    worker_->RemoveCacheEntryForId(call->request.request_id());
    call->SendResponse(::grpc::Status::OK);
    ENQUEUE_REQUEST(MarkRecvFinished, false);
  }

  void RunGraphHandler(WorkerCall<RunGraphRequest, RunGraphResponse>* call) {
    Schedule([this, call]() {
      CallOptions* call_opts = new CallOptions;
      ProtoRunGraphRequest* wrapped_request =
          new ProtoRunGraphRequest(&call->request);
      NonOwnedProtoRunGraphResponse* wrapped_response =
          new NonOwnedProtoRunGraphResponse(&call->response);
      call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
      worker_->RunGraphAsync(call_opts, wrapped_request, wrapped_response,
                             [call, call_opts, wrapped_request,
                              wrapped_response](const Status& s) {
                               VLOG(1) << "RunGraph::Done";
                               if (!s.ok()) {
                                 VLOG(1) << "Bad response from RunGraph:" << s;
                               }
                               call->ClearCancelCallback();
                               delete call_opts;
                               delete wrapped_request;
                               delete wrapped_response;
                               call->SendResponse(ToGrpcStatus(s));
                             });
    });
    ENQUEUE_REQUEST(RunGraph, true);
  }

  void RecvTensorHandlerRaw(
      WorkerCall<RecvTensorRequest, ::grpc::ByteBuffer>* call) {
    Schedule([this, call]() {
      CallOptions* call_opts = new CallOptions;
      call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });

      worker_->GrpcRecvTensorAsync(
          call_opts, &call->request, &call->response,
          [call, call_opts](const Status& s) {
            call->ClearCancelCallback();
            delete call_opts;
            if (!s.ok()) {
              VLOG(1) << "Bad response from RecvTensor:" << s;
            }
            call->SendResponse(ToGrpcStatus(s));
          });
    });
    EnqueueRecvTensorRequestRaw();
  }

  void RecvBufHandler(WorkerCall<RecvBufRequest, RecvBufResponse>* call) {
    Schedule([this, call]() {
      CallOptions* call_opts = new CallOptions;
      call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
      worker_->RecvBufAsync(call_opts, &call->request, &call->response,
                            [call, call_opts](const Status& s) {
                              call->ClearCancelCallback();
                              delete call_opts;
                              if (!s.ok()) {
                                VLOG(1) << "Bad response from RecvBuf:" << s;
                              }
                              call->SendResponse(ToGrpcStatus(s));
                            });
    });
    ENQUEUE_REQUEST(RecvBuf, true);
  }

  void CompleteGroupHandler(
      WorkerCall<CompleteGroupRequest, CompleteGroupResponse>* call) {
    Schedule([this, call]() {
      CallOptions* call_opts = new CallOptions;
      call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
      worker_->CompleteGroupAsync(
          call_opts, &call->request, &call->response,
          [call, call_opts](const Status& s) {
            call->ClearCancelCallback();
            delete call_opts;
            if (!s.ok()) {
              VLOG(1) << "Bad response from CompleteGroup:" << s;
            }
            call->SendResponse(ToGrpcStatus(s));
          });
    });
    ENQUEUE_REQUEST(CompleteGroup, true);
  }

  void CompleteInstanceHandler(
      WorkerCall<CompleteInstanceRequest, CompleteInstanceResponse>* call) {
    Schedule([this, call]() {
      CallOptions* call_opts = new CallOptions;
      call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
      worker_->CompleteInstanceAsync(
          call_opts, &call->request, &call->response,
          [call, call_opts](const Status& s) {
            call->ClearCancelCallback();
            delete call_opts;
            if (!s.ok()) {
              VLOG(1) << "Bad response from CompleteInstance:" << s;
            }
            call->SendResponse(ToGrpcStatus(s));
          });
    });
    ENQUEUE_REQUEST(CompleteInstance, false);
  }
#undef ENQUEUE_REQUEST

  void EnqueueRecvTensorRequestRaw() {
    mutex_lock l(shutdown_mu_);
    if (!is_shutdown_) {
      Call<GrpcWorkerServiceThread, grpc::WorkerService::AsyncService,
           RecvTensorRequest, ::grpc::ByteBuffer>::
          EnqueueRequestForMethod(
              worker_service_, cq_.get(),
              static_cast<int>(GrpcWorkerMethod::kRecvTensor),
              &GrpcWorkerServiceThread::RecvTensorHandlerRaw,
              true /* supports cancel */);
    }
  }

  GrpcWorker* const worker_ = nullptr;  // Not owned.
  std::unique_ptr<::grpc::ServerCompletionQueue> cq_;
  std::unique_ptr<Thread> thread_;
  std::unordered_map<int, int> queue_depth_;
  GrpcResponseCache* cache_;
  grpc::WorkerService::AsyncService* const worker_service_;

  mutex shutdown_mu_;
  bool is_shutdown_ GUARDED_BY(shutdown_mu_);
  TF_DISALLOW_COPY_AND_ASSIGN(GrpcWorkerServiceThread);
};

class GrpcWorkerService : public AsyncServiceInterface {
 public:
  GrpcWorkerService(GrpcWorker* worker, ::grpc::ServerBuilder* builder,
                    GrpcWorkerServiceOptions options)
      : is_shutdown_(false) {
    builder->RegisterService(&worker_service_);

    for (int i = 0; i < options.num_serving_threads; i++) {
      threads_.emplace_back(
          new GrpcWorkerServiceThread(worker, builder, options.queue_depth,
                                      cache_.get(), &worker_service_));
    }
  }

  void Shutdown() override {
    bool did_shutdown = false;
    {
      mutex_lock l(service_shutdown_mu_);
      if (!is_shutdown_) {
        LOG(INFO) << "Shutting down GrpcWorkerService.";
        is_shutdown_ = true;
        did_shutdown = true;
      }
    }
    if (did_shutdown) {
      for (auto& worker_thread : threads_) {
        worker_thread->Shutdown();
      }
    }
  }

  // This method blocks forever handling requests from the completion queue.
  void HandleRPCsLoop() override {
    for (auto& worker_thread : threads_) {
      worker_thread->Start();
    }
    for (auto& worker_thread : threads_) {
      worker_thread->Join();
    }
  }

 private:
  grpc::WorkerService::AsyncService worker_service_;
  std::vector<std::unique_ptr<GrpcWorkerServiceThread>> threads_;

  std::unique_ptr<GrpcResponseCache> cache_;
  mutex service_shutdown_mu_;
  bool is_shutdown_ GUARDED_BY(service_shutdown_mu_);

  TF_DISALLOW_COPY_AND_ASSIGN(GrpcWorkerService);
};

}  // namespace

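// recv_buf_max_chunk_ controls how RecvBuf responses are chunked: a positive
// value from the config is used as the chunk size, a negative value disables
// chunking (stored as 0, i.e. the tensor is sent in one piece), and 0 selects
// the 4096-byte default.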
GrpcWorker::GrpcWorker(WorkerEnv* worker_env, const ConfigProto& config)
    : Worker(worker_env),
      recv_buf_max_chunk_(
          config.experimental().recv_buf_max_chunk() > 0
              ? config.experimental().recv_buf_max_chunk()
              : (config.experimental().recv_buf_max_chunk() < 0 ? 0 : 4096)) {
  if (config.rpc_options().cache_rpc_response()) {
    EnableResponseCache();
  }
}

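// Lazily constructs the response cache used to deduplicate RecvTensor/RecvBuf
// requests, so that a repeated request_id (e.g. from a client retry) can be
// served from the cache instead of re-running the rendezvous.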
void GrpcWorker::EnableResponseCache() {
  VLOG(1) << "Enabling gRPC tensor response cache.";
  response_cache_ = absl::make_unique<GrpcResponseCache>();
}

// GrpcRecvTensorAsync: unlike the other Worker methods, which use protocol
// buffers for the response object, this method writes its response directly
// into a ::grpc::ByteBuffer to avoid extra protocol buffer serialization
// overhead.
void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts,
                                     const RecvTensorRequest* request,
                                     ::grpc::ByteBuffer* response,
                                     StatusCallback done) {
  VLOG(1) << "GrpcRecvTensorAsync req: " << request->DebugString();
  const int64 request_id = request->request_id();
  const int64 step_id = request->step_id();

  bool cache_enabled = (response_cache_ != nullptr && request_id != 0);

  auto do_response = [response, done, cache_enabled](const Tensor& tensor,
                                                     bool is_dead,
                                                     const Status& status) {
    if (status.ok()) {
      grpc::EncodeTensorToByteBuffer(is_dead, tensor, cache_enabled, response);
    }
    done(status);
  };

  // If response cache is enabled and the response cache already contains the
  // request, we delegate this retry request to the response cache. Otherwise,
  // we add the request to the response cache and start the computation to
  // retrieve the requested data.
  if (cache_enabled &&
      response_cache_->QueueRequest(request_id, step_id, do_response)) {
    return;
  }

  auto rendezvous_done = [this, request_id, do_response, cache_enabled](
                             const Tensor& tensor, bool is_dead,
                             const Status& status) {
    if (cache_enabled) {
      // Data is ready. Process all pending requests in the response cache.
      response_cache_->OnRequestFinished(request_id, tensor, is_dead, status);
    } else {
      do_response(tensor, is_dead, status);
    }
  };

  auto fail = [&rendezvous_done](const Status& status) {
    rendezvous_done(Tensor(), false, status);
  };

  Status s = recent_request_ids_.TrackUnique(
      request_id, "RecvTensor (GrpcWorker)", *request);
  if (!s.ok()) {
    fail(s);
    return;
  }

  const string& key = request->rendezvous_key();
  TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str());
  Rendezvous::ParsedKey parsed;
  s = Rendezvous::ParseKey(key, &parsed);
  Device* src_dev = nullptr;
  if (s.ok()) {
    s = PrepareRecvTensor(parsed, &src_dev);
  }
  if (!s.ok()) {
    fail(s);
    return;
  }

  // Request the tensor associated with the rendezvous key.
  // Note that we log the cancellation here but do not abort the current step.
  // gRPC can generate cancellations in response to transient network failures,
  // and aborting the step eliminates the opportunity for client side retries.
  // Repeated client failures will eventually cause the step to be aborted by
  // the client.
  opts->SetCancelCallback(
      [step_id]() { LOG(WARNING) << "RecvTensor cancelled for " << step_id; });
  env_->rendezvous_mgr->RecvLocalAsync(
      step_id, parsed,
      [opts, rendezvous_done, src_dev, request](
          const Status& status, const Rendezvous::Args& send_args,
          const Rendezvous::Args& recv_args, const Tensor& val,
          const bool is_dead) {
        opts->ClearCancelCallback();
        if (status.ok()) {
          // DMA can only be used for Tensors that do not fall into
          // the following three odd edge cases: 1) a zero-size
          // buffer, 2) a dead tensor which has an uninit value, and
          // 3) the tensor has the on_host allocation attribute,
          // i.e. it's in CPU RAM *independent of its assigned
          // device type*.
          const bool on_host = send_args.alloc_attrs.on_host();
          {
            // Non-DMA cases.
            if (src_dev->tensorflow_gpu_device_info() && (!on_host)) {
              DeviceContext* send_dev_context = send_args.device_context;
              AllocatorAttributes alloc_attrs;
              alloc_attrs.set_gpu_compatible(true);
              alloc_attrs.set_on_host(true);
              Allocator* alloc = src_dev->GetAllocator(alloc_attrs);
              Tensor* copy = new Tensor(alloc, val.dtype(), val.shape());
              CHECK(send_dev_context)
                  << "send dev name: " << src_dev->name()
                  << " gpu_info: " << src_dev->tensorflow_gpu_device_info();
              // "val" is on an accelerator device. Uses the device_context to
              // fill the copy on host.
              StatusCallback copy_ready = [rendezvous_done, copy,
                                           is_dead](const Status& s) {
                // The value is now ready to be returned on the wire.
                rendezvous_done(*copy, is_dead, s);
                delete copy;
              };

              CopyDeviceToHost(&val, alloc, alloc, request->rendezvous_key(),
                               src_dev, copy, send_dev_context, copy_ready);
              return;
            }
          }
        }

        rendezvous_done(val, is_dead, status);
      });
}

namespace {
// If RecvBufRespExtra.tensor_content is a single large string, then gRPC
// can stall on the recv side when the string buffer needs to be enlarged,
// since the size is not sent in advance.  Changing this field to a sequence
// of small strings costs some extra time on the send side, since we do
// some otherwise unnecessary copies, but it improves runtime overall by
// improving flow control.  Best performance is likely achieved with a
// max_chunk_bytes equal to the memory page size.
//
// TODO(tucker): When proto3 supports [ctype=CORD] then change
// RecvBufRespExtra.tensor_content to a cord instead of a repeated string,
// and remove this function.
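//
// For example, with max_chunk_bytes == 4096 a 10000-byte tensor is split into
// chunks of 4096, 4096 and 1808 bytes; with max_chunk_bytes <= 0 the whole
// buffer is sent as a single chunk.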
void SetTensorInRecvBufResp(int64 max_chunk_bytes, const Tensor* tensor,
                            RecvBufResponse* response) {
  RecvBufRespExtra extra;
  int64 num_bytes = tensor->TotalBytes();
  const char* head = reinterpret_cast<const char*>(DMAHelper::base(tensor));
  while (num_bytes > 0) {
    int64 bytes =
        max_chunk_bytes > 0 ? std::min(num_bytes, max_chunk_bytes) : num_bytes;
    extra.add_tensor_content(std::string(head, bytes));
    head += bytes;
    num_bytes -= bytes;
  }
  response->mutable_transport_options()->PackFrom(extra);
}
}  // namespace

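// RecvBufAsync: serves a peer's request for the contents of a collective
// buffer. The tensor is looked up in the step's BufRendezvous and, if it
// resides on an accelerator device, copied into host memory before being
// chunked into the response.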
void GrpcWorker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
                              RecvBufResponse* response, StatusCallback done) {
  const int64 request_id = request->request_id();
  const int64 step_id = request->step_id();
  bool cache_enabled = (response_cache_ != nullptr && request_id != 0);

  auto do_response = [this, response, done, cache_enabled](
                         const Tensor& tensor, bool is_dead,
                         const Status& status) {
    if (status.ok()) {
      SetTensorInRecvBufResp(recv_buf_max_chunk_, &tensor, response);
    }
    response->set_send_start_micros(env_->env->NowMicros());
    response->set_require_ack(cache_enabled);
    done(status);
  };

  // If response cache is enabled and the response cache already contains the
  // request, we delegate this retry request to the response cache. Otherwise,
  // we add the request to the response cache and start the computation to
  // retrieve the requested data.
  if (cache_enabled &&
      response_cache_->QueueRequest(request_id, step_id, do_response)) {
    return;
  }

  auto rendezvous_done = [this, request_id, do_response, cache_enabled](
                             const Tensor& tensor, const Status& status) {
    if (cache_enabled) {
      // Data is ready. Process all pending requests in the response cache.
      response_cache_->OnRequestFinished(request_id, tensor, false, status);
    } else {
      do_response(tensor, false, status);
    }
  };

  auto fail = [&rendezvous_done](const Status& status) {
    rendezvous_done(Tensor(), status);
  };

  // This is a generic, low performance implementation appropriate for grpc.
  Status s = recent_request_ids_.TrackUnique(request_id, "RecvBuf (GrpcWorker)",
                                             *request);
  if (!s.ok()) {
    fail(s);
    return;
  }

  CollectiveExecutor::Handle ce_handle(
      env_->collective_executor_mgr->FindOrCreate(step_id), true);
  CollectiveRemoteAccess* rma = ce_handle.get()->remote_access();
  auto consumer_callback = [this, request, rendezvous_done](
                               const Status& status,
                               BufRendezvous::Hook* hook) {
    Status s = status;
    if (s.ok()) {
      if (hook == nullptr) {
        s = errors::Internal("Invalid null hook for key ",
                             request->buf_rendezvous_key());
      } else if (!DMAHelper::CanUseDMA(hook->prod_value)) {
        s = errors::Internal("Tensor value for key ",
                             request->buf_rendezvous_key(),
                             " is not of a type supported by RecvBuf");
      }
    } else {
      if (hook != nullptr) {
        LOG(ERROR) << "Got hook " << hook << " with status " << s
                   << " from ConsumeBuf";
      }
    }

    if (s.ok()) {
      // The RPC source tensor needs to be in CPU RAM.  If not already
      // there make a copy using memory appropriate to the purpose.
      const size_t num_bytes = hook->prod_value->TotalBytes();
      const bool on_host =
          hook->prod_dev->attributes().device_type() == "CPU" ||
          hook->prod_attr.on_host();
      if ((!on_host) && (num_bytes > 0)) {
        Device* cpu_dev = nullptr;
        s = env_->device_mgr->LookupDevice("CPU:0", &cpu_dev);
        if (s.ok()) {
          AllocatorAttributes cpu_attr;
          cpu_attr.set_gpu_compatible(true);
          cpu_attr.set_nic_compatible(true);
          Tensor* cpu_tensor =
              new Tensor(cpu_dev->GetAllocator(cpu_attr),
                         hook->prod_value->dtype(), hook->prod_value->shape());
          hook->prod_ctx->CopyDeviceTensorToCPU(
              hook->prod_value, "empty_name", hook->prod_dev, cpu_tensor,
              [hook, cpu_tensor, rendezvous_done](const Status& s) {
                rendezvous_done(*cpu_tensor, s);
                BufRendezvous::DoneWithHook(hook);
                delete cpu_tensor;
              });
          return;
        }
      }
    }

    if (hook == nullptr) {
      rendezvous_done(Tensor(), s);
    } else {
      rendezvous_done(*hook->prod_value, s);
      BufRendezvous::DoneWithHook(hook);
    }
  };
  rma->buf_rendezvous()->ConsumeBuf(
      request->buf_rendezvous_key(), request->src_device(),
      request->src_incarnation(), consumer_callback);
}

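// LoggingAsync: enables or disables RPC logging in the SessionMgr and copies
// the logs requested via `fetch_step_id` into the response.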
void GrpcWorker::LoggingAsync(const LoggingRequest* request,
                              LoggingResponse* response, StatusCallback done) {
  auto env = this->env();
  if (env) {
    auto session_mgr = env->session_mgr;
    if (session_mgr) {
      if (request->enable_rpc_logging()) {
        session_mgr->SetLogging(true);
      }
      // NOTE(mrry): Handle old masters that disable RPC logging by setting
      // `request->enable_rpc_logging` to `false`.
      if (request->disable_rpc_logging() ||
          (!request->enable_rpc_logging() &&
           request->fetch_step_id_size() == 0)) {
        session_mgr->SetLogging(false);
      }
      for (const auto& step_id : request->fetch_step_id()) {
        session_mgr->RetrieveLogs(step_id, response);
      }
      if (request->clear()) {
        session_mgr->ClearLogs();
      }
    }
  }
  done(Status::OK());
}

void GrpcWorker::CleanupGraphAsync(const CleanupGraphRequest* request,
                                   CleanupGraphResponse* response,
                                   StatusCallback done) {
  if (response_cache_) {
    // Clean up any stale response cache entries for this step. This can occur
    // if a worker crashes before acking a request.
    response_cache_->CleanEntriesForStep(request->step_id());
  }
  Worker::CleanupGraphAsync(request, response, done);
}

WorkerEnv* GrpcWorker::env() { return env_; }

void GrpcWorker::RemoveCacheEntryForId(int64 request_id) {
  if (response_cache_) {
    response_cache_->EraseRequestId(request_id);
  }
}

std::unique_ptr<GrpcWorker> NewGrpcWorker(WorkerEnv* env,
                                          const ConfigProto& config) {
  return std::unique_ptr<GrpcWorker>(new GrpcWorker(env, config));
}

std::unique_ptr<AsyncServiceInterface> NewGrpcWorkerService(
    GrpcWorker* worker, ::grpc::ServerBuilder* builder,
    GrpcWorkerServiceOptions options) {
  return std::unique_ptr<AsyncServiceInterface>(
      new GrpcWorkerService(worker, builder, options));
}

}  // namespace tensorflow