/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"

#include <deque>
#include <memory>
#include <unordered_map>
#include <vector>

#include "grpcpp/alarm.h"
#include "grpcpp/server_builder.h"

#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/common_runtime/buf_rendezvous.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/distributed_runtime/graph_mgr.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_response_cache.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h"
#include "tensorflow/core/distributed_runtime/worker.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_session.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/protobuf/transport_options.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"

namespace tensorflow {

namespace {

// This macro creates a new request for the given RPC method name
// (e.g., `ENQUEUE_REQUEST(GetStatus, false);`), and enqueues it on
// `this->cq_`.
//
// This macro is invoked one or more times for each RPC method to
// ensure that there are sufficient completion queue entries to
// handle incoming requests without blocking.
//
// The implementation of the request handler for each RPC method
// must ensure that it calls ENQUEUE_REQUEST() for that RPC method,
// to keep accepting new requests.
#define ENQUEUE_REQUEST(method, supports_cancel)                             \
  do {                                                                       \
    mutex_lock l(shutdown_mu_);                                              \
    if (!is_shutdown_) {                                                     \
      Call<GrpcWorkerServiceThread, grpc::WorkerService::AsyncService,       \
           method##Request, method##Response>::                              \
          EnqueueRequestForMethod(                                           \
              worker_service_, cq_.get(),                                    \
              static_cast<int>(GrpcWorkerMethod::k##method),                 \
              &GrpcWorkerServiceThread::method##Handler, (supports_cancel)); \
    }                                                                        \
  } while (0)

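// Pre-populates the completion queue with pending calls for `method`:
// `default_depth` of them, unless a per-method override is present in
// `queue_depth_`.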
#define SETUP_FOR_REQUEST(method, default_depth, supports_cancel)              \
  for (int i = 0;                                                              \
       i < gtl::FindWithDefault(queue_depth_,                                  \
                                static_cast<int>(GrpcWorkerMethod::k##method), \
                                default_depth);                                \
       ++i) {                                                                  \
    ENQUEUE_REQUEST(method, supports_cancel);                                  \
  }

// GrpcWorkerService spawns one or more GrpcWorkerServiceThreads to service
// requests. Each thread operates on an independent completion queue.
class GrpcWorkerServiceThread {
 public:
  explicit GrpcWorkerServiceThread(
      GrpcWorker* worker, ::grpc::ServerBuilder* builder,
      std::unordered_map<int, int> queue_depth, GrpcResponseCache* cache,
      grpc::WorkerService::AsyncService* worker_service)
      : worker_(worker),
        queue_depth_(queue_depth),
        cache_(cache),
        worker_service_(worker_service),
        is_shutdown_(false) {
    cq_ = builder->AddCompletionQueue();
  }

  void Start() {
    thread_.reset(
        worker_->env()->env->StartThread(ThreadOptions(), "grpc_worker_service",
                                         [this]() { HandleRPCsLoop(); }));
  }

  void Join() { thread_.reset(); }  // Blocks until thread exits

  void Shutdown() {
    {
      mutex_lock lock(shutdown_mu_);
      is_shutdown_ = true;
    }
    cq_->Shutdown();
  }

 private:
  // Add one or more completion queue entries for each worker method, then
  // begin servicing requests from the completion queue.
  void HandleRPCsLoop() {
    // TODO(ncteisen): This may require performance engineering. We can
    // change the number of threads, the number of handlers per thread,
    // or even decide to specialize certain threads to certain methods.
    SETUP_FOR_REQUEST(GetStatus, 1, false);
    SETUP_FOR_REQUEST(CreateWorkerSession, 1, false);
    SETUP_FOR_REQUEST(DeleteWorkerSession, 1, false);
    SETUP_FOR_REQUEST(CleanupAll, 1, false);
    SETUP_FOR_REQUEST(RegisterGraph, 1, false);
    SETUP_FOR_REQUEST(DeregisterGraph, 1, false);
    SETUP_FOR_REQUEST(Logging, 1, false);
    SETUP_FOR_REQUEST(Tracing, 1, false);
    SETUP_FOR_REQUEST(CompleteGroup, 10, true);
    SETUP_FOR_REQUEST(CompleteInstance, 10, true);
    SETUP_FOR_REQUEST(GetStepSequence, 10, true);
    SETUP_FOR_REQUEST(RecvBuf, 500, true);
    SETUP_FOR_REQUEST(RunGraph, 100, true);
    SETUP_FOR_REQUEST(CleanupGraph, 100, false);

    // TODO(ncteisen): Determine a better policy for enqueuing the
    // appropriate number of each request type.
    for (int i = 0;
         i < gtl::FindWithDefault(
                 queue_depth_, static_cast<int>(GrpcWorkerMethod::kRecvTensor),
                 1000);
         ++i) {
      EnqueueRecvTensorRequestRaw();
    }

    void* tag;
    bool ok;

    while (cq_->Next(&tag, &ok)) {
      UntypedCall<GrpcWorkerServiceThread>::Tag* callback_tag =
          static_cast<UntypedCall<GrpcWorkerServiceThread>::Tag*>(tag);
      CHECK(callback_tag);
      callback_tag->OnCompleted(this, ok);
    }
  }

 private:
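  // Runs `f` on the worker's compute thread pool.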
  void Schedule(std::function<void()> f) {
    worker_->env()->compute_pool->Schedule(std::move(f));
  }

  // The following section contains one request handler method per
  // RPC. The `FooHandler` method is called (indirectly) by
  // `HandleRPCsLoop()` when the next Foo RPC is received. Each
  // `FooHandler` call schedules a closure on `worker_->env()->compute_pool`,
  // and is responsible for requesting the next Foo call by calling
  // `ENQUEUE_REQUEST(Foo)`.
  template <class RequestMessage, class ResponseMessage>
  using WorkerCall =
      Call<GrpcWorkerServiceThread, grpc::WorkerService::AsyncService,
           RequestMessage, ResponseMessage>;

  // Handle all non-cancellable simple methods with a standard wrapper.
#define HANDLE_CALL(method)                                                   \
  void method##Handler(WorkerCall<method##Request, method##Response>* call) { \
    Schedule([this, call]() {                                                 \
      Status s = worker_->method(&call->request, &call->response);            \
      if (!s.ok()) {                                                          \
        VLOG(1) << "Bad response from " << #method << ": " << s;              \
      }                                                                       \
      call->SendResponse(ToGrpcStatus(s));                                    \
    });                                                                       \
    ENQUEUE_REQUEST(method, false);                                           \
  }

  HANDLE_CALL(GetStatus);
  HANDLE_CALL(CreateWorkerSession);
  HANDLE_CALL(DeleteWorkerSession);
  HANDLE_CALL(CleanupAll);
  HANDLE_CALL(RegisterGraph);
  HANDLE_CALL(DeregisterGraph);
  HANDLE_CALL(CleanupGraph);
  HANDLE_CALL(Logging);
  HANDLE_CALL(Tracing);

#undef HANDLE_CALL

  void GetStepSequenceHandler(
      WorkerCall<GetStepSequenceRequest, GetStepSequenceResponse>* call) {
    Schedule([this, call]() {
      worker_->GetStepSequenceAsync(
          &call->request, &call->response, [call](const Status& s) {
            if (!s.ok()) {
              VLOG(1) << "Bad response from GetStepSequence:" << s;
            }
            call->SendResponse(ToGrpcStatus(s));
          });
    });
    ENQUEUE_REQUEST(GetStepSequence, true);
  }

  void RunGraphHandler(WorkerCall<RunGraphRequest, RunGraphResponse>* call) {
    Schedule([this, call]() {
      CallOptions* call_opts = new CallOptions;
      ProtoRunGraphRequest* wrapped_request =
          new ProtoRunGraphRequest(&call->request);
      NonOwnedProtoRunGraphResponse* wrapped_response =
          new NonOwnedProtoRunGraphResponse(&call->response);
      call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
      auto done_cb = [call, call_opts, wrapped_request,
                      wrapped_response](const Status& s) {
        VLOG(1) << "RunGraph::Done";
        if (!s.ok()) {
          VLOG(1) << "Bad response from RunGraph:" << s;
        }
        call->ClearCancelCallback();
        delete call_opts;
        delete wrapped_request;
        delete wrapped_response;
        call->SendResponse(ToGrpcStatus(s));
      };

      auto compute_fn = [this, call_opts, wrapped_request,
                         wrapped_response](StatusCallback done) {
        worker_->RunGraphAsync(call_opts, wrapped_request, wrapped_response,
                               done);
      };

      if (cache_) {
        string request_key = call->request.ShortDebugString();
        cache_->LookupOrCompute(request_key, RPCResponse(&call->response),
                                compute_fn, done_cb);
      } else {
        compute_fn(done_cb);
      }
    });
    ENQUEUE_REQUEST(RunGraph, true);
  }

  void RecvTensorHandlerRaw(
      WorkerCall<RecvTensorRequest, ::grpc::ByteBuffer>* call) {
    Schedule([this, call]() {
      CallOptions* call_opts = new CallOptions;
      call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });

      auto done_cb = [call, call_opts](const Status& s) {
        call->ClearCancelCallback();
        delete call_opts;
        if (!s.ok()) {
          VLOG(1) << "Bad response from RecvTensor:" << s;
        }
        call->SendResponse(ToGrpcStatus(s));
      };

      auto compute_fn = [this, &call_opts, &call](StatusCallback done) {
        worker_->GrpcRecvTensorAsync(call_opts, &call->request, &call->response,
                                     done);
      };

      if (cache_) {
        string request_key = call->request.ShortDebugString();
        cache_->LookupOrCompute(request_key, RPCResponse(&call->response),
                                compute_fn, done_cb);
      } else {
        compute_fn(done_cb);
      }
    });
    EnqueueRecvTensorRequestRaw();
  }

  void RecvBufHandler(WorkerCall<RecvBufRequest, RecvBufResponse>* call) {
    Schedule([this, call]() {
      CallOptions* call_opts = new CallOptions;
      call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
      worker_->RecvBufAsync(call_opts, &call->request, &call->response,
                            [call, call_opts](const Status& s) {
                              call->ClearCancelCallback();
                              delete call_opts;
                              if (!s.ok()) {
                                VLOG(1) << "Bad response from RecvBuf:" << s;
                              }
                              call->SendResponse(ToGrpcStatus(s));
                            });
    });
    ENQUEUE_REQUEST(RecvBuf, true);
  }

  void CompleteGroupHandler(
      WorkerCall<CompleteGroupRequest, CompleteGroupResponse>* call) {
    Schedule([this, call]() {
      CallOptions* call_opts = new CallOptions;
      call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
      worker_->CompleteGroupAsync(
          call_opts, &call->request, &call->response,
          [call, call_opts](const Status& s) {
            call->ClearCancelCallback();
            delete call_opts;
            if (!s.ok()) {
              VLOG(1) << "Bad response from CompleteGroup:" << s;
            }
            call->SendResponse(ToGrpcStatus(s));
          });
    });
    ENQUEUE_REQUEST(CompleteGroup, true);
  }

  void CompleteInstanceHandler(
      WorkerCall<CompleteInstanceRequest, CompleteInstanceResponse>* call) {
    Schedule([this, call]() {
      CallOptions* call_opts = new CallOptions;
      call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
      worker_->CompleteInstanceAsync(
          call_opts, &call->request, &call->response,
          [call, call_opts](const Status& s) {
            call->ClearCancelCallback();
            delete call_opts;
            if (!s.ok()) {
              VLOG(1) << "Bad response from CompleteInstance:" << s;
            }
            call->SendResponse(ToGrpcStatus(s));
          });
    });
    ENQUEUE_REQUEST(CompleteInstance, false);
  }
#undef ENQUEUE_REQUEST

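  // RecvTensor responses are raw ::grpc::ByteBuffers rather than protobuf
  // messages (see GrpcWorker::GrpcRecvTensorAsync below), so the generic
  // ENQUEUE_REQUEST macro does not apply and the request is enqueued
  // explicitly here.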
  void EnqueueRecvTensorRequestRaw() {
    mutex_lock l(shutdown_mu_);
    if (!is_shutdown_) {
      Call<GrpcWorkerServiceThread, grpc::WorkerService::AsyncService,
           RecvTensorRequest, ::grpc::ByteBuffer>::
          EnqueueRequestForMethod(
              worker_service_, cq_.get(),
              static_cast<int>(GrpcWorkerMethod::kRecvTensor),
              &GrpcWorkerServiceThread::RecvTensorHandlerRaw,
              true /* supports cancel*/);
    }
  }

  GrpcWorker* const worker_ = nullptr;  // Not owned.
  std::unique_ptr<::grpc::ServerCompletionQueue> cq_;
  std::unique_ptr<Thread> thread_;
  std::unordered_map<int, int> queue_depth_;
  GrpcResponseCache* cache_;
  grpc::WorkerService::AsyncService* const worker_service_;

  mutex shutdown_mu_;
  bool is_shutdown_ GUARDED_BY(shutdown_mu_);
  TF_DISALLOW_COPY_AND_ASSIGN(GrpcWorkerServiceThread);
};

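// Registers the worker service with the gRPC server builder and fans incoming
// RPCs out to `num_serving_threads` GrpcWorkerServiceThreads, each with its
// own completion queue. The optional response cache, when enabled, is shared
// by all threads.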
class GrpcWorkerService : public AsyncServiceInterface {
 public:
  GrpcWorkerService(GrpcWorker* worker, ::grpc::ServerBuilder* builder,
                    GrpcWorkerServiceOptions options)
      : is_shutdown_(false) {
    builder->RegisterService(&worker_service_);
    if (options.response_cache_bytes > 0) {
      cache_.reset(
          new GrpcResponseCache(options.response_cache_bytes,
                                options.response_cache_expires_seconds));
    }

    for (int i = 0; i < options.num_serving_threads; i++) {
      threads_.emplace_back(
          new GrpcWorkerServiceThread(worker, builder, options.queue_depth,
                                      cache_.get(), &worker_service_));
    }
  }

  void Shutdown() override {
    bool did_shutdown = false;
    {
      mutex_lock l(service_shutdown_mu_);
      if (!is_shutdown_) {
        LOG(INFO) << "Shutting down GrpcWorkerService.";
        is_shutdown_ = true;
        did_shutdown = true;
      }
    }
    if (did_shutdown) {
      for (auto& worker_thread : threads_) {
        worker_thread->Shutdown();
      }
    }
  }

  // This method blocks forever handling requests from the completion queue.
  void HandleRPCsLoop() override {
    for (auto& worker_thread : threads_) {
      worker_thread->Start();
    }
    for (auto& worker_thread : threads_) {
      worker_thread->Join();
    }
  }

 private:
  grpc::WorkerService::AsyncService worker_service_;
  std::vector<std::unique_ptr<GrpcWorkerServiceThread>> threads_;

  std::unique_ptr<GrpcResponseCache> cache_;
  mutex service_shutdown_mu_;
  bool is_shutdown_ GUARDED_BY(service_shutdown_mu_);

  TF_DISALLOW_COPY_AND_ASSIGN(GrpcWorkerService);
};

}  // namespace

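// A negative `recv_buf_max_chunk` in the config disables chunking of RecvBuf
// responses (stored here as 0); zero or unset selects the 4096-byte default;
// any positive value is used as given.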
GrpcWorker::GrpcWorker(WorkerEnv* worker_env, const ConfigProto& config)
    : Worker(worker_env),
      recent_request_ids_(100000),
      recv_buf_max_chunk_(
          config.experimental().recv_buf_max_chunk() > 0
              ? config.experimental().recv_buf_max_chunk()
              : (config.experimental().recv_buf_max_chunk() < 0 ? 0 : 4096)) {}

// GrpcRecvTensorAsync: unlike the other Worker methods, which use protocol
// buffer response objects, this method writes its response directly into a
// ::grpc::ByteBuffer to avoid an extra protocol buffer serialization.
void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts,
                                     const RecvTensorRequest* request,
                                     ::grpc::ByteBuffer* response,
                                     StatusCallback done) {
  Status s = recent_request_ids_.TrackUnique(
      request->request_id(), "RecvTensor (GrpcWorker)", *request);
  if (!s.ok()) {
    done(s);
    return;
  }

  const int64 step_id = request->step_id();
  const string& key = request->rendezvous_key();
  TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str());
  Rendezvous::ParsedKey parsed;
  s = Rendezvous::ParseKey(key, &parsed);
  Device* src_dev = nullptr;
  if (s.ok()) {
    s = PrepareRecvTensor(parsed, &src_dev);
  }
  if (!s.ok()) {
    done(s);
    return;
  }

  // Request the tensor associated with the rendezvous key.
  // Note that we log the cancellation here but do not abort the current step.
  // gRPC can generate cancellations in response to transient network failures,
  // and aborting the step eliminates the opportunity for client side retries.
  // Repeated client failures will eventually cause the step to be aborted by
  // the client.
  opts->SetCancelCallback(
      [step_id]() { LOG(WARNING) << "RecvTensor cancelled for " << step_id; });
  env_->rendezvous_mgr->RecvLocalAsync(
      step_id, parsed,
      [opts, response, done, src_dev, request](
          const Status& status, const Rendezvous::Args& send_args,
          const Rendezvous::Args& recv_args, const Tensor& val,
          const bool is_dead) {
        opts->ClearCancelCallback();
        if (status.ok()) {
          // DMA can only be used for Tensors that do not fall into
          // the following three odd edge cases: 1) a zero-size
          // buffer, 2) a dead tensor which has an uninit value, and
          // 3) the tensor has the on_host allocation attribute,
          // i.e. it's in CPU RAM *independent of its assigned
          // device type*.
          const bool on_host = send_args.alloc_attrs.on_host();
          {
            // Non-DMA cases.
            if (src_dev->tensorflow_gpu_device_info() && (!on_host)) {
              DeviceContext* send_dev_context = send_args.device_context;
              AllocatorAttributes alloc_attrs;
              alloc_attrs.set_gpu_compatible(true);
              alloc_attrs.set_on_host(true);
              Allocator* alloc = src_dev->GetAllocator(alloc_attrs);
              Tensor* copy = new Tensor(alloc, val.dtype(), val.shape());
              CHECK(send_dev_context)
                  << "send dev name: " << src_dev->name()
                  << " gpu_info: " << src_dev->tensorflow_gpu_device_info();
              // "val" is on an accelerator device. Uses the device_context to
              // fill the copy on host.
              StatusCallback copy_ready = [response, done, copy,
                                           is_dead](const Status& s) {
                // The value is now ready to be returned on the wire.
                grpc::EncodeTensorToByteBuffer(is_dead, *copy, response);
                done(s);
                delete copy;
              };

              send_dev_context->CopyDeviceTensorToCPU(
                  &val, request->rendezvous_key(), src_dev, copy, copy_ready);
            } else {
              grpc::EncodeTensorToByteBuffer(is_dead, val, response);
              done(Status::OK());
            }
          }
        } else {
          // !s.ok()
          done(status);
        }
      });
}

namespace {
// If RecvBufRespExtra.tensor_content is a single large string, then gRPC
// can stall on the recv side when the string buffer needs to be enlarged,
// since the size is not sent in advance. Changing this field to a sequence
// of small strings costs some extra time on the send side, since we do
// some otherwise unnecessary copies, but it improves runtime overall by
// improving flow control. Best performance is likely achieved with a
// max_chunk_bytes equal to the memory page size.
//
// TODO(tucker): When proto3 supports [ctype=CORD] then change
// RecvBufRespExtra.tensor_content to a cord instead of a repeated string,
// and remove this function.
void SetTensorInRecvBufResp(int64 max_chunk_bytes, const Tensor* tensor,
                            int64 num_bytes, RecvBufResponse* response) {
  RecvBufRespExtra extra;
  const char* head = reinterpret_cast<const char*>(DMAHelper::base(tensor));
  while (num_bytes > 0) {
    int64 bytes =
        max_chunk_bytes > 0 ? std::min(num_bytes, max_chunk_bytes) : num_bytes;
    extra.add_tensor_content(std::string(head, bytes));
    head += bytes;
    num_bytes -= bytes;
  }
  response->mutable_transport_options()->PackFrom(extra);
}
}  // namespace

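// Serves a RecvBuf request by locating the source buffer in the step's
// BufRendezvous. Tensors that live on an accelerator device are first copied
// into gRPC/NIC-friendly CPU memory before being chunked into the response.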
void GrpcWorker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
                              RecvBufResponse* response, StatusCallback done) {
  // This is a generic, low performance implementation appropriate for grpc.
  Status s = recent_request_ids_.TrackUnique(request->request_id(),
                                             "RecvBuf (GrpcWorker)", *request);
  if (!s.ok()) {
    done(s);
    return;
  }
  CollectiveExecutor::Handle ce_handle(
      env_->collective_executor_mgr->FindOrCreate(request->step_id()), true);
  CollectiveRemoteAccess* rma = ce_handle.get()->remote_access();
  rma->buf_rendezvous()->ConsumeBuf(
      request->buf_rendezvous_key(),
      [this, request, response, done](const Status& status,
                                      BufRendezvous::Hook* hook) {
        Status s = status;
        if (s.ok()) {
          if (!DMAHelper::CanUseDMA(hook->prod_value)) {
            s = errors::Internal("Tensor value for key ",
                                 request->buf_rendezvous_key(),
                                 " is not of a type supported by RecvBuf");
          }
        }
        if (s.ok()) {
          // The RPC source tensor needs to be in CPU RAM. If not already
          // there make a copy using memory appropriate to the purpose.
          const size_t num_bytes = hook->prod_value->TotalBytes();
          const bool on_host =
              hook->prod_dev->attributes().device_type() == "CPU" ||
              hook->prod_attr.on_host();
          if ((!on_host) && (num_bytes > 0)) {
            Device* cpu_dev = nullptr;
            s = env_->device_mgr->LookupDevice("CPU:0", &cpu_dev);
            if (s.ok()) {
              AllocatorAttributes cpu_attr;
              cpu_attr.set_gpu_compatible(true);
              cpu_attr.set_nic_compatible(true);
              Tensor* cpu_tensor = new Tensor(cpu_dev->GetAllocator(cpu_attr),
                                              hook->prod_value->dtype(),
                                              hook->prod_value->shape());
              hook->prod_ctx->CopyDeviceTensorToCPU(
                  hook->prod_value, "empty_name", hook->prod_dev, cpu_tensor,
                  [this, num_bytes, response, done, hook,
                   cpu_tensor](const Status& s) {
                    if (s.ok()) {
                      SetTensorInRecvBufResp(recv_buf_max_chunk_, cpu_tensor,
                                             num_bytes, response);
                    }
                    response->set_send_start_micros(env_->env->NowMicros());
                    done(s);
                    BufRendezvous::DoneWithHook(hook);
                    delete cpu_tensor;
                  });
              return;
            }
          } else {
            // Tensor is on CPU.
            SetTensorInRecvBufResp(recv_buf_max_chunk_, hook->prod_value,
                                   num_bytes, response);
          }
        }
        response->set_send_start_micros(env_->env->NowMicros());
        done(s);
        BufRendezvous::DoneWithHook(hook);
      });
}

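// Toggles RPC logging on the worker's SessionMgr as requested, forwards any
// logs recorded for the requested step ids into the response, and optionally
// clears the accumulated logs.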
void GrpcWorker::LoggingAsync(const LoggingRequest* request,
                              LoggingResponse* response, StatusCallback done) {
  auto env = this->env();
  if (env) {
    auto session_mgr = env->session_mgr;
    if (session_mgr) {
      if (request->enable_rpc_logging()) {
        session_mgr->SetLogging(true);
      }
      // NOTE(mrry): Handle old masters that disable RPC logging by setting
      // `request->enable_rpc_logging` to `false`.
      if (request->disable_rpc_logging() ||
          (!request->enable_rpc_logging() &&
           request->fetch_step_id_size() == 0)) {
        session_mgr->SetLogging(false);
      }
      for (const auto& step_id : request->fetch_step_id()) {
        session_mgr->RetrieveLogs(step_id, response);
      }
      if (request->clear()) {
        session_mgr->ClearLogs();
      }
    }
  }
  done(Status::OK());
}

WorkerEnv* GrpcWorker::env() { return env_; }

std::unique_ptr<GrpcWorker> NewGrpcWorker(WorkerEnv* env,
                                          const ConfigProto& config) {
  return std::unique_ptr<GrpcWorker>(new GrpcWorker(env, config));
}

std::unique_ptr<AsyncServiceInterface> NewGrpcWorkerService(
    GrpcWorker* worker, ::grpc::ServerBuilder* builder,
    GrpcWorkerServiceOptions options) {
  return std::unique_ptr<AsyncServiceInterface>(
      new GrpcWorkerService(worker, builder, options));
}

}  // namespace tensorflow