1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
17
18 #include <deque>
19 #include <memory>
20 #include <unordered_map>
21 #include <vector>
22
23 #include "grpcpp/alarm.h"
24 #include "grpcpp/server_builder.h"
25 #include "absl/container/flat_hash_map.h"
26 #include "tensorflow/core/common_runtime/buf_rendezvous.h"
27 #include "tensorflow/core/common_runtime/copy_tensor.h"
28 #include "tensorflow/core/common_runtime/device.h"
29 #include "tensorflow/core/common_runtime/device_mgr.h"
30 #include "tensorflow/core/common_runtime/dma_helper.h"
31 #include "tensorflow/core/common_runtime/local_device.h"
32 #include "tensorflow/core/common_runtime/process_util.h"
33 #include "tensorflow/core/common_runtime/step_stats_collector.h"
34 #include "tensorflow/core/distributed_runtime/graph_mgr.h"
35 #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
36 #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
37 #include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
38 #include "tensorflow/core/distributed_runtime/rpc/grpc_response_cache.h"
39 #include "tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.h"
40 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
41 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h"
42 #include "tensorflow/core/distributed_runtime/worker.h"
43 #include "tensorflow/core/distributed_runtime/worker_cache.h"
44 #include "tensorflow/core/distributed_runtime/worker_session.h"
45 #include "tensorflow/core/framework/cancellation.h"
46 #include "tensorflow/core/framework/collective.h"
47 #include "tensorflow/core/framework/tensor.h"
48 #include "tensorflow/core/lib/core/errors.h"
49 #include "tensorflow/core/lib/core/status.h"
50 #include "tensorflow/core/lib/gtl/map_util.h"
51 #include "tensorflow/core/lib/strings/strcat.h"
52 #include "tensorflow/core/lib/strings/stringprintf.h"
53 #include "tensorflow/core/platform/logging.h"
54 #include "tensorflow/core/platform/mutex.h"
55 #include "tensorflow/core/platform/tracing.h"
56 #include "tensorflow/core/protobuf/transport_options.pb.h"
57 #include "tensorflow/core/protobuf/worker.pb.h"
58
59 namespace tensorflow {
60
61 namespace {
62
// ENQUEUE_REQUEST(method, supports_cancel) posts one new pending request for
// the RPC named `method` (e.g. `ENQUEUE_REQUEST(GetStatus, false);`) on
// `this->cq_`, routed to `GrpcWorkerServiceThread::<method>Handler`.
//
// The request is only posted while `is_shutdown_` is false; the flag is read
// with `shutdown_mu_` held.
//
// The macro is expanded one or more times per RPC method so that the
// completion queue holds enough entries to accept incoming requests without
// blocking. Every request handler must expand it again for its own method,
// otherwise no further requests of that kind are accepted.
#define ENQUEUE_REQUEST(method, supports_cancel)                           \
  do {                                                                     \
    using PendingCall =                                                    \
        Call<GrpcWorkerServiceThread, grpc::WorkerService::AsyncService,   \
             method##Request, method##Response>;                           \
    mutex_lock shutdown_lock(shutdown_mu_);                                \
    if (!is_shutdown_) {                                                   \
      PendingCall::EnqueueRequestForMethod(                                \
          worker_service_, cq_.get(),                                      \
          static_cast<int>(GrpcWorkerMethod::k##method),                   \
          &GrpcWorkerServiceThread::method##Handler, (supports_cancel));   \
    }                                                                      \
  } while (0)
86
// SETUP_FOR_REQUEST(method, default_depth, supports_cancel) primes the
// completion queue for `method` by expanding ENQUEUE_REQUEST once per slot.
// The number of slots is the entry for `method` in `queue_depth_`, or
// `default_depth` when the map has no entry for it.
//
// The depth is looked up a single time (by value, so no reference to a
// temporary default outlives the lookup), and the expansion is wrapped in
// `do { ... } while (0)` so it behaves as exactly one statement that is
// terminated by the `;` written at the call site.
#define SETUP_FOR_REQUEST(method, default_depth, supports_cancel)        \
  do {                                                                   \
    const int queue_slots = gtl::FindWithDefault(                        \
        queue_depth_, static_cast<int>(GrpcWorkerMethod::k##method),     \
        (default_depth));                                                \
    for (int i = 0; i < queue_slots; ++i) {                              \
      ENQUEUE_REQUEST(method, supports_cancel);                          \
    }                                                                    \
  } while (0)
95
96 // GrpcWorkerService spawns one or more GrpcWorkerServiceThreads to service
97 // requests. Each thread operates on an independent completion queue.
98 class GrpcWorkerServiceThread {
99 public:
GrpcWorkerServiceThread(GrpcWorker * worker,::grpc::ServerBuilder * builder,std::unordered_map<int,int> queue_depth,GrpcResponseCache * cache,grpc::WorkerService::AsyncService * worker_service)100 explicit GrpcWorkerServiceThread(
101 GrpcWorker* worker, ::grpc::ServerBuilder* builder,
102 std::unordered_map<int, int> queue_depth, GrpcResponseCache* cache,
103 grpc::WorkerService::AsyncService* worker_service)
104 : worker_(worker),
105 queue_depth_(queue_depth),
106 cache_(cache),
107 worker_service_(worker_service),
108 is_shutdown_(false) {
109 cq_ = builder->AddCompletionQueue();
110 }
111
Start()112 void Start() {
113 thread_.reset(
114 worker_->env()->env->StartThread(ThreadOptions(), "grpc_worker_service",
115 [this]() { HandleRPCsLoop(); }));
116 }
117
Join()118 void Join() { thread_.reset(); } // Blocks until thread exits
119
Shutdown()120 void Shutdown() {
121 {
122 mutex_lock lock(shutdown_mu_);
123 is_shutdown_ = true;
124 }
125 cq_->Shutdown();
126 }
127
128 private:
129 // Add one or more completion queue entries for each worker method, then
130 // begin servicing requests from the completion queue.
HandleRPCsLoop()131 void HandleRPCsLoop() {
132 // TODO(ncteisen): This may require performance engineering. We can
133 // change the number of threads, the number of handlers per thread,
134 // or even decide to specialize certain threads to certain methods.
135 SETUP_FOR_REQUEST(GetStatus, 1, false);
136 SETUP_FOR_REQUEST(CreateWorkerSession, 1, false);
137 SETUP_FOR_REQUEST(DeleteWorkerSession, 1, false);
138 SETUP_FOR_REQUEST(CleanupAll, 1, false);
139 SETUP_FOR_REQUEST(RegisterGraph, 1, false);
140 SETUP_FOR_REQUEST(DeregisterGraph, 1, false);
141 SETUP_FOR_REQUEST(Logging, 1, false);
142 SETUP_FOR_REQUEST(Tracing, 1, false);
143 SETUP_FOR_REQUEST(CompleteGroup, 10, true);
144 SETUP_FOR_REQUEST(CompleteInstance, 10, true);
145 SETUP_FOR_REQUEST(GetStepSequence, 10, true);
146 SETUP_FOR_REQUEST(RecvBuf, 500, true);
147 SETUP_FOR_REQUEST(RunGraph, 100, true);
148 SETUP_FOR_REQUEST(CleanupGraph, 100, false);
149 SETUP_FOR_REQUEST(MarkRecvFinished, 10, false);
150
151 // TODO(ncteisen): Determine a better policy for enqueuing the
152 // appropriate number of each request type.
153 for (int i = 0;
154 i < gtl::FindWithDefault(
155 queue_depth_, static_cast<int>(GrpcWorkerMethod::kRecvTensor),
156 1000);
157 ++i) {
158 EnqueueRecvTensorRequestRaw();
159 }
160
161 void* tag;
162 bool ok;
163
164 while (cq_->Next(&tag, &ok)) {
165 UntypedCall<GrpcWorkerServiceThread>::Tag* callback_tag =
166 static_cast<UntypedCall<GrpcWorkerServiceThread>::Tag*>(tag);
167 CHECK(callback_tag);
168 callback_tag->OnCompleted(this, ok);
169 }
170 }
171
172 private:
Schedule(std::function<void ()> f)173 void Schedule(std::function<void()> f) {
174 worker_->env()->compute_pool->Schedule(std::move(f));
175 }
176
// The following section contains one request handler method per
// RPC. The `FooHandler` method is called (indirectly) by
// `HandleRPCsLoop()` when the next Foo RPC is received. Each
// `FooHandler` call typically schedules a closure (most of them on
// `worker_->env()->compute_pool`), and is responsible for requesting the
// next Foo call by calling `ENQUEUE_REQUEST(Foo, supports_cancel)`.
183 template <class RequestMessage, class ResponseMessage>
184 using WorkerCall =
185 Call<GrpcWorkerServiceThread, grpc::WorkerService::AsyncService,
186 RequestMessage, ResponseMessage>;
187
// HANDLE_CALL(method, may_block_on_compute_pool) defines `<method>Handler`
// for a simple, non-cancellable RPC: it calls `worker_-><method>()`
// synchronously inside a closure, sends the resulting status back to the
// caller, and re-enqueues a request for the same method.
//
// `may_block_on_compute_pool` says whether the operation may block on
// activities (such as op execution) that run on the compute pool. When it is
// true the closure is handed to `Env::SchedClosure` instead of the compute
// pool itself.
#define HANDLE_CALL(method, may_block_on_compute_pool)                        \
  void method##Handler(WorkerCall<method##Request, method##Response>* call) { \
    auto run_and_reply = [this, call]() {                                     \
      const Status status =                                                   \
          worker_->method(&call->request, &call->response);                   \
      if (!status.ok()) {                                                     \
        VLOG(3) << "Bad response from " << #method << ": " << status;         \
      }                                                                       \
      call->SendResponse(ToGrpcStatus(status));                               \
    };                                                                        \
    if (!(may_block_on_compute_pool)) {                                       \
      worker_->env()->compute_pool->Schedule(std::move(run_and_reply));       \
    } else {                                                                  \
      worker_->env()->env->SchedClosure(std::move(run_and_reply));            \
    }                                                                         \
    ENQUEUE_REQUEST(method, false);                                           \
  }
208
209 HANDLE_CALL(GetStatus, false);
210 HANDLE_CALL(CreateWorkerSession, false);
211 HANDLE_CALL(DeleteWorkerSession, true);
212 HANDLE_CALL(CleanupAll, false);
213 HANDLE_CALL(RegisterGraph, false);
214 HANDLE_CALL(DeregisterGraph, false);
215 HANDLE_CALL(CleanupGraph, false);
216 HANDLE_CALL(Logging, false);
217 HANDLE_CALL(Tracing, false);
218
219 #undef HANDLE_CALL
220
GetStepSequenceHandler(WorkerCall<GetStepSequenceRequest,GetStepSequenceResponse> * call)221 void GetStepSequenceHandler(
222 WorkerCall<GetStepSequenceRequest, GetStepSequenceResponse>* call) {
223 Schedule([this, call]() {
224 worker_->GetStepSequenceAsync(
225 &call->request, &call->response, [call](const Status& s) {
226 VLOG(3) << "Bad response from GetStepSequence:" << s;
227 call->SendResponse(ToGrpcStatus(s));
228 });
229 });
230 ENQUEUE_REQUEST(GetStepSequence, true);
231 }
232
MarkRecvFinishedHandler(WorkerCall<MarkRecvFinishedRequest,MarkRecvFinishedResponse> * call)233 void MarkRecvFinishedHandler(
234 WorkerCall<MarkRecvFinishedRequest, MarkRecvFinishedResponse>* call) {
235 VLOG(3) << "Clean cache entry for request " << call->request.request_id();
236 worker_->RemoveCacheEntryForId(call->request.request_id());
237 call->SendResponse(::grpc::Status::OK);
238 ENQUEUE_REQUEST(MarkRecvFinished, false);
239 }
240
RunGraphHandler(WorkerCall<RunGraphRequest,RunGraphResponse> * call)241 void RunGraphHandler(WorkerCall<RunGraphRequest, RunGraphResponse>* call) {
242 Schedule([this, call]() {
243 CallOptions* call_opts = new CallOptions;
244 ProtoRunGraphRequest* wrapped_request =
245 new ProtoRunGraphRequest(&call->request);
246 NonOwnedProtoRunGraphResponse* wrapped_response =
247 new NonOwnedProtoRunGraphResponse(&call->response);
248 call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
249 worker_->RunGraphAsync(call_opts, wrapped_request, wrapped_response,
250 [call, call_opts, wrapped_request,
251 wrapped_response](const Status& s) {
252 VLOG(3) << "RunGraph::Done";
253 if (!s.ok()) {
254 VLOG(3) << "Bad response from RunGraph:" << s;
255 }
256 call->ClearCancelCallback();
257 delete call_opts;
258 delete wrapped_request;
259 delete wrapped_response;
260 call->SendResponse(ToGrpcStatus(s));
261 });
262 });
263 ENQUEUE_REQUEST(RunGraph, true);
264 }
265
RecvTensorHandlerRaw(WorkerCall<RecvTensorRequest,::grpc::ByteBuffer> * call)266 void RecvTensorHandlerRaw(
267 WorkerCall<RecvTensorRequest, ::grpc::ByteBuffer>* call) {
268 Schedule([this, call]() {
269 CallOptions* call_opts = new CallOptions;
270 call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
271
272 worker_->GrpcRecvTensorAsync(
273 call_opts, &call->request, &call->response,
274 [call, call_opts](const Status& s) {
275 call->ClearCancelCallback();
276 delete call_opts;
277 if (!s.ok()) {
278 VLOG(3) << "Bad response from RecvTensor:" << s;
279 }
280 call->SendResponse(ToGrpcStatus(s));
281 });
282 });
283 EnqueueRecvTensorRequestRaw();
284 }
285
RecvBufHandler(WorkerCall<RecvBufRequest,RecvBufResponse> * call)286 void RecvBufHandler(WorkerCall<RecvBufRequest, RecvBufResponse>* call) {
287 Schedule([this, call]() {
288 CallOptions* call_opts = new CallOptions;
289 call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
290 worker_->RecvBufAsync(call_opts, &call->request, &call->response,
291 [call, call_opts](const Status& s) {
292 call->ClearCancelCallback();
293 delete call_opts;
294 if (!s.ok()) {
295 VLOG(3) << "Bad response from RecvBuf:" << s;
296 }
297 call->SendResponse(ToGrpcStatus(s));
298 });
299 });
300 ENQUEUE_REQUEST(RecvBuf, true);
301 }
302
CompleteGroupHandler(WorkerCall<CompleteGroupRequest,CompleteGroupResponse> * call)303 void CompleteGroupHandler(
304 WorkerCall<CompleteGroupRequest, CompleteGroupResponse>* call) {
305 Schedule([this, call]() {
306 CallOptions* call_opts = new CallOptions;
307 call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
308 worker_->CompleteGroupAsync(
309 call_opts, &call->request, &call->response,
310 [call, call_opts](const Status& s) {
311 call->ClearCancelCallback();
312 delete call_opts;
313 if (!s.ok()) {
314 VLOG(3) << "Bad response from CompleteGroup:" << s;
315 }
316 call->SendResponse(ToGrpcStatus(s));
317 });
318 });
319 ENQUEUE_REQUEST(CompleteGroup, true);
320 }
321
CompleteInstanceHandler(WorkerCall<CompleteInstanceRequest,CompleteInstanceResponse> * call)322 void CompleteInstanceHandler(
323 WorkerCall<CompleteInstanceRequest, CompleteInstanceResponse>* call) {
324 Schedule([this, call]() {
325 CallOptions* call_opts = new CallOptions;
326 call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
327 worker_->CompleteInstanceAsync(
328 call_opts, &call->request, &call->response,
329 [call, call_opts](const Status& s) {
330 call->ClearCancelCallback();
331 delete call_opts;
332 if (!s.ok()) {
333 VLOG(3) << "Bad response from CompleteInstance:" << s;
334 }
335 call->SendResponse(ToGrpcStatus(s));
336 });
337 });
338 ENQUEUE_REQUEST(CompleteInstance, false);
339 }
340 #undef ENQUEUE_REQUEST
341
EnqueueRecvTensorRequestRaw()342 void EnqueueRecvTensorRequestRaw() {
343 mutex_lock l(shutdown_mu_);
344 if (!is_shutdown_) {
345 Call<GrpcWorkerServiceThread, grpc::WorkerService::AsyncService,
346 RecvTensorRequest, ::grpc::ByteBuffer>::
347 EnqueueRequestForMethod(
348 worker_service_, cq_.get(),
349 static_cast<int>(GrpcWorkerMethod::kRecvTensor),
350 &GrpcWorkerServiceThread::RecvTensorHandlerRaw,
351 true /* supports cancel*/);
352 }
353 }
354
355 GrpcWorker* const worker_ = nullptr; // Not owned.
356 std::unique_ptr<::grpc::ServerCompletionQueue> cq_;
357 std::unique_ptr<Thread> thread_;
358 std::unordered_map<int, int> queue_depth_;
359 GrpcResponseCache* cache_;
360 grpc::WorkerService::AsyncService* const worker_service_;
361
362 mutex shutdown_mu_;
363 bool is_shutdown_ TF_GUARDED_BY(shutdown_mu_);
364 TF_DISALLOW_COPY_AND_ASSIGN(GrpcWorkerServiceThread);
365 };
366
367 class GrpcWorkerService : public AsyncServiceInterface {
368 public:
GrpcWorkerService(GrpcWorker * worker,::grpc::ServerBuilder * builder,GrpcWorkerServiceOptions options)369 GrpcWorkerService(GrpcWorker* worker, ::grpc::ServerBuilder* builder,
370 GrpcWorkerServiceOptions options)
371 : is_shutdown_(false) {
372 builder->RegisterService(&worker_service_);
373
374 for (int i = 0; i < options.num_serving_threads; i++) {
375 threads_.emplace_back(
376 new GrpcWorkerServiceThread(worker, builder, options.queue_depth,
377 cache_.get(), &worker_service_));
378 }
379 }
380
Shutdown()381 void Shutdown() override {
382 bool did_shutdown = false;
383 {
384 mutex_lock l(service_shutdown_mu_);
385 if (!is_shutdown_) {
386 LOG(INFO) << "Shutting down GrpcWorkerService.";
387 is_shutdown_ = true;
388 did_shutdown = true;
389 }
390 }
391 if (did_shutdown) {
392 for (auto& worker_thread : threads_) {
393 worker_thread->Shutdown();
394 }
395 }
396 }
397
398 // This method blocks forever handling requests from the completion queue.
HandleRPCsLoop()399 void HandleRPCsLoop() override {
400 for (auto& worker_thread : threads_) {
401 worker_thread->Start();
402 }
403 for (auto& worker_thread : threads_) {
404 worker_thread->Join();
405 }
406 }
407
408 private:
409 grpc::WorkerService::AsyncService worker_service_;
410 std::vector<std::unique_ptr<GrpcWorkerServiceThread>> threads_;
411
412 std::unique_ptr<GrpcResponseCache> cache_;
413 mutex service_shutdown_mu_;
414 bool is_shutdown_ TF_GUARDED_BY(service_shutdown_mu_);
415
416 TF_DISALLOW_COPY_AND_ASSIGN(GrpcWorkerService);
417 };
418
419 } // namespace
420
GrpcWorker(WorkerEnv * worker_env,const ConfigProto & config)421 GrpcWorker::GrpcWorker(WorkerEnv* worker_env, const ConfigProto& config)
422 : Worker(worker_env),
423 recv_buf_max_chunk_(
424 config.experimental().recv_buf_max_chunk() > 0
425 ? config.experimental().recv_buf_max_chunk()
426 : (config.experimental().recv_buf_max_chunk() < 0 ? 0 : 4096)) {
427 if (config.rpc_options().cache_rpc_response()) {
428 EnableResponseCache();
429 }
430 }
431
EnableResponseCache()432 void GrpcWorker::EnableResponseCache() {
433 VLOG(3) << "Enabling gRPC tensor response cache.";
434 response_cache_ = absl::make_unique<GrpcResponseCache>();
435 }
436
437 // GrpcRecvTensorAsync: unlike the other Worker methods, which use protocol
438 // buffers for a response object, to avoid extra protocol buffer serialization
439 // overhead we generate our response directly into a ::grpc::ByteBuffer object
GrpcRecvTensorAsync(CallOptions * opts,const RecvTensorRequest * request,::grpc::ByteBuffer * response,StatusCallback done)440 void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts,
441 const RecvTensorRequest* request,
442 ::grpc::ByteBuffer* response,
443 StatusCallback done) {
444 VLOG(3) << "GrpcRecvTensorAsync req: " << request->DebugString();
445 const int64_t request_id = request->request_id();
446 const int64_t step_id = request->step_id();
447
448 bool cache_enabled = (response_cache_ != nullptr && request_id != 0);
449
450 auto do_response = [response, done, cache_enabled](const Tensor& tensor,
451 bool is_dead,
452 const Status& status) {
453 if (status.ok()) {
454 grpc::EncodeTensorToByteBuffer(is_dead, tensor, cache_enabled, response);
455 }
456 done(status);
457 };
458
459 // If response cache is enabled and the response cache already contains the
460 // request, we delegate this retry request to the response cache. Otherwise,
461 // we add the request to the response cache and start the computation to
462 // retrieve the requested data.
463 if (cache_enabled &&
464 response_cache_->QueueRequest(request_id, step_id, do_response)) {
465 return;
466 }
467
468 auto rendezvous_done = [this, request_id, do_response, cache_enabled](
469 const Tensor& tensor, bool is_dead,
470 const Status& status) {
471 if (cache_enabled) {
472 // Data is ready. Process all pending requests in the response cache.
473 response_cache_->OnRequestFinished(request_id, tensor, is_dead, status);
474 } else {
475 do_response(tensor, is_dead, status);
476 }
477 };
478
479 auto fail = [&rendezvous_done](const Status& status) {
480 rendezvous_done(Tensor(), false, status);
481 };
482
483 Status s = recent_request_ids_.TrackUnique(
484 request_id, "RecvTensor (GrpcWorker)", *request);
485 if (!s.ok()) {
486 fail(s);
487 return;
488 }
489
490 const string& key = request->rendezvous_key();
491 TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str());
492 Rendezvous::ParsedKey parsed;
493 s = Rendezvous::ParseKey(key, &parsed);
494 Device* src_dev = nullptr;
495 if (s.ok()) {
496 s = PrepareRecvTensor(parsed, &src_dev);
497 }
498 if (!s.ok()) {
499 fail(s);
500 return;
501 }
502
503 // Request the tensor associated with the rendezvous key.
504 // Note that we log the cancellation here but do not abort the current step.
505 // gRPC can generate cancellations in response to transient network failures,
506 // and aborting the step eliminates the opportunity for client side retries.
507 // Repeated client failures will eventually cause the step to be aborted by
508 // the client.
509 opts->SetCancelCallback(
510 [step_id]() { LOG(WARNING) << "RecvTensor cancelled for " << step_id; });
511 env_->rendezvous_mgr->RecvLocalAsync(
512 step_id, parsed,
513 [opts, rendezvous_done, src_dev, request](
514 const Status& status, const Rendezvous::Args& send_args,
515 const Rendezvous::Args& recv_args, const Tensor& val,
516 const bool is_dead) {
517 opts->ClearCancelCallback();
518 if (status.ok()) {
519 // DMA can only be used for Tensors that do not fall into
520 // the following three odd edge cases: 1) a zero-size
521 // buffer, 2) a dead tensor which has an uninit value, and
522 // 3) the tensor has the on_host allocation attribute,
523 // i.e. it's in CPU RAM *independent of its assigned
524 // device type*.
525 const bool on_host = send_args.alloc_attrs.on_host();
526 {
527 // Non-DMA cases.
528 if (src_dev->tensorflow_gpu_device_info() && (!on_host)) {
529 DeviceContext* send_dev_context = send_args.device_context;
530 AllocatorAttributes alloc_attrs;
531 alloc_attrs.set_gpu_compatible(true);
532 alloc_attrs.set_on_host(true);
533 Allocator* alloc = src_dev->GetAllocator(alloc_attrs);
534 Tensor* copy = new Tensor(alloc, val.dtype(), val.shape());
535 CHECK(send_dev_context)
536 << "send dev name: " << src_dev->name()
537 << " gpu_info: " << src_dev->tensorflow_gpu_device_info();
538 // "val" is on an accelerator device. Uses the device_context to
539 // fill the copy on host.
540 StatusCallback copy_ready = [rendezvous_done, copy,
541 is_dead](const Status& s) {
542 // The value is now ready to be returned on the wire.
543 rendezvous_done(*copy, is_dead, s);
544 delete copy;
545 };
546
547 CopyDeviceToHost(&val, alloc, alloc, request->rendezvous_key(),
548 src_dev, copy, send_dev_context, copy_ready);
549 return;
550 }
551 }
552 }
553
554 rendezvous_done(val, is_dead, status);
555 });
556 }
557
558 namespace {
559 // If RecvBufRespExtra.tensor_content is a single large string, then gRPC
560 // can stall on the recv side when the string buffer needs to be enlarged,
561 // since the size is not sent in advance. Changing this field to a sequence
562 // of small strings costs some extra time on the send side, since we do
563 // some otherwise unnecessary copies, but it improves runtime overall by
564 // improving flow control. Best performance is likely achieved with a
565 // max_chunk_bytes equal to the memory page size.
566 //
567 // TODO(tucker): When proto3 supports [ctype=CORD] then change
568 // RecvBufRespExtra.tensor_content to a cord instead of a repeated string,
569 // and remove this function.
SetTensorInRecvBufResp(int64_t max_chunk_bytes,const Tensor * tensor,RecvBufResponse * response)570 void SetTensorInRecvBufResp(int64_t max_chunk_bytes, const Tensor* tensor,
571 RecvBufResponse* response) {
572 RecvBufRespExtra extra;
573 int64_t num_bytes = tensor->TotalBytes();
574 const char* head = reinterpret_cast<const char*>(DMAHelper::base(tensor));
575 while (num_bytes > 0) {
576 int64_t bytes =
577 max_chunk_bytes > 0 ? std::min(num_bytes, max_chunk_bytes) : num_bytes;
578 extra.add_tensor_content(std::string(head, bytes));
579 head += bytes;
580 num_bytes -= bytes;
581 }
582 response->mutable_transport_options()->PackFrom(extra);
583 }
584 } // namespace
585
RecvBufAsync(CallOptions * opts,const RecvBufRequest * request,RecvBufResponse * response,StatusCallback done)586 void GrpcWorker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
587 RecvBufResponse* response, StatusCallback done) {
588 const int64_t request_id = request->request_id();
589 const int64_t step_id = request->step_id();
590 bool cache_enabled = (response_cache_ != nullptr && request_id != 0);
591
592 auto do_response = [this, response, done, cache_enabled](
593 const Tensor& tensor, bool is_dead,
594 const Status& status) {
595 if (status.ok()) {
596 SetTensorInRecvBufResp(recv_buf_max_chunk_, &tensor, response);
597 }
598 response->set_send_start_micros(env_->env->NowMicros());
599 response->set_require_ack(cache_enabled);
600 done(status);
601 };
602
603 // If response cache is enabled and the response cache already contains the
604 // request, we delegate this retry request to the response cache. Otherwise,
605 // we add the request to the response cache and start the computation to
606 // retrieve the requested data.
607 if (cache_enabled &&
608 response_cache_->QueueRequest(request_id, step_id, do_response)) {
609 return;
610 }
611
612 auto rendezvous_done = [this, request_id, do_response, cache_enabled](
613 const Tensor& tensor, const Status& status) {
614 if (cache_enabled) {
615 // Data is ready. Process all pending requests in the response cache.
616 response_cache_->OnRequestFinished(request_id, tensor, false, status);
617 } else {
618 do_response(tensor, false, status);
619 }
620 };
621
622 auto fail = [&rendezvous_done](const Status& status) {
623 rendezvous_done(Tensor(), status);
624 };
625
626 // This is a generic, low performance implementation appropriate for grpc.
627 Status s = recent_request_ids_.TrackUnique(request_id, "RecvBuf (GrpcWorker)",
628 *request);
629 if (!s.ok()) {
630 fail(s);
631 return;
632 }
633
634 CollectiveExecutor::Handle ce_handle(
635 env_->collective_executor_mgr->FindOrCreate(step_id), true);
636 CollectiveRemoteAccess* rma = ce_handle.get()->remote_access();
637 auto consumer_callback = [this, request, rendezvous_done](
638 const Status& status,
639 BufRendezvous::Hook* hook) {
640 Status s = status;
641 if (s.ok()) {
642 if (hook == nullptr) {
643 s = errors::Internal("Invalid null hook for key ",
644 request->buf_rendezvous_key());
645 }
646 if (!DMAHelper::CanUseDMA(hook->prod_value)) {
647 s = errors::Internal("Tensor value for key ",
648 request->buf_rendezvous_key(),
649 " is not of a type supported by RecvBuf");
650 }
651 } else {
652 if (hook != nullptr) {
653 LOG(ERROR) << "Got hook " << hook << " with status " << s
654 << " from ConsumeBuf";
655 }
656 }
657
658 if (s.ok()) {
659 // The RPC source tensor needs to be in CPU RAM. If not already
660 // there make a copy using memory appropriate to the purpose.
661 const size_t num_bytes = hook->prod_value->TotalBytes();
662 const bool on_host =
663 hook->prod_dev->attributes().device_type() == "CPU" ||
664 hook->prod_attr.on_host();
665 if ((!on_host) && (num_bytes > 0)) {
666 Device* cpu_dev = nullptr;
667 s = env_->device_mgr->LookupDevice("CPU:0", &cpu_dev);
668 if (s.ok()) {
669 AllocatorAttributes cpu_attr;
670 cpu_attr.set_gpu_compatible(true);
671 cpu_attr.set_nic_compatible(true);
672 ScopedMemoryDebugAnnotation op_annotation(
673 "GrpcWorker::RecvBufAsync::consumer_callback", request->step_id(),
674 "dynamic", hook->prod_value->dtype(), &hook->prod_value->shape());
675 Tensor* cpu_tensor =
676 new Tensor(cpu_dev->GetAllocator(cpu_attr),
677 hook->prod_value->dtype(), hook->prod_value->shape());
678 hook->prod_ctx->CopyDeviceTensorToCPU(
679 hook->prod_value, "empty_name", hook->prod_dev, cpu_tensor,
680 [hook, cpu_tensor, rendezvous_done](const Status& s) {
681 rendezvous_done(*cpu_tensor, s);
682 BufRendezvous::DoneWithHook(hook);
683 delete cpu_tensor;
684 });
685 return;
686 }
687 }
688 }
689
690 if (hook == nullptr) {
691 rendezvous_done(Tensor(), s);
692 } else {
693 rendezvous_done(*hook->prod_value, s);
694 BufRendezvous::DoneWithHook(hook);
695 }
696 };
697 rma->buf_rendezvous()->ConsumeBuf(
698 request->buf_rendezvous_key(), request->src_device(),
699 request->src_incarnation(), consumer_callback,
700 /*cancellation_manager=*/nullptr);
701 }
702
LoggingAsync(const LoggingRequest * request,LoggingResponse * response,StatusCallback done)703 void GrpcWorker::LoggingAsync(const LoggingRequest* request,
704 LoggingResponse* response, StatusCallback done) {
705 auto env = this->env();
706 if (env) {
707 auto session_mgr = env->session_mgr;
708 if (session_mgr) {
709 if (request->enable_rpc_logging()) {
710 session_mgr->SetLogging(true);
711 }
712 // NOTE(mrry): Handle old masters that disable RPC logging by setting
713 // `request->enable_rpc_logging` to `false`.
714 if (request->disable_rpc_logging() ||
715 (!request->enable_rpc_logging() &&
716 request->fetch_step_id_size() == 0)) {
717 session_mgr->SetLogging(false);
718 }
719 for (const auto& step_id : request->fetch_step_id()) {
720 session_mgr->RetrieveLogs(step_id, response);
721 }
722 if (request->clear()) {
723 session_mgr->ClearLogs();
724 }
725 }
726 }
727 done(Status::OK());
728 }
729
CleanupGraphAsync(const CleanupGraphRequest * request,CleanupGraphResponse * response,StatusCallback done)730 void GrpcWorker::CleanupGraphAsync(const CleanupGraphRequest* request,
731 CleanupGraphResponse* response,
732 StatusCallback done) {
733 if (response_cache_) {
734 // Cleanup any stale response cache entries for this step. This can occur if
735 // a worker crashes before acking a request.
736 response_cache_->CleanEntriesForStep(request->step_id());
737 }
738 Worker::CleanupGraphAsync(request, response, done);
739 }
740
env()741 WorkerEnv* GrpcWorker::env() { return env_; }
742
RemoveCacheEntryForId(int64_t request_id)743 void GrpcWorker::RemoveCacheEntryForId(int64_t request_id) {
744 if (response_cache_) {
745 response_cache_->EraseRequestId(request_id);
746 }
747 }
748
NewGrpcWorker(WorkerEnv * env,const ConfigProto & config)749 std::unique_ptr<GrpcWorker> NewGrpcWorker(WorkerEnv* env,
750 const ConfigProto& config) {
751 return std::unique_ptr<GrpcWorker>(new GrpcWorker(env, config));
752 }
753
NewGrpcWorkerService(GrpcWorker * worker,::grpc::ServerBuilder * builder,GrpcWorkerServiceOptions options)754 std::unique_ptr<AsyncServiceInterface> NewGrpcWorkerService(
755 GrpcWorker* worker, ::grpc::ServerBuilder* builder,
756 GrpcWorkerServiceOptions options) {
757 return std::unique_ptr<AsyncServiceInterface>(
758 new GrpcWorkerService(worker, builder, options));
759 }
760
761 } // namespace tensorflow
762