1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_ 18 19 #include <queue> 20 #include <utility> 21 22 #include "grpcpp/generic/generic_stub.h" 23 #include "grpcpp/grpcpp.h" 24 #include "tensorflow/core/distributed_runtime/call_options.h" 25 #include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h" 26 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" 27 #include "tensorflow/core/distributed_runtime/tensor_coding.h" 28 #include "tensorflow/core/lib/core/refcount.h" 29 #include "tensorflow/core/lib/core/status.h" 30 #include "tensorflow/core/lib/core/threadpool.h" 31 #include "tensorflow/core/lib/strings/strcat.h" 32 #include "tensorflow/core/platform/mutex.h" 33 #include "tensorflow/core/platform/notification.h" 34 #include "tensorflow/core/util/env_var.h" 35 36 namespace tensorflow { 37 38 // Object allocated per active RPC. 39 // Manage the state of a single asynchronous RPC request. If `max_retries` 40 // is greater than 0, the request will be retried for any transient failures. 41 template <class Response> 42 class RPCState : public GrpcClientCQTag { 43 public: 44 RPCState(::grpc::GenericStub* stub, ::grpc::CompletionQueue* cq, 45 const ::grpc::string& method, const protobuf::Message& request, 46 Response* response, StatusCallback done, CallOptions* call_opts, 47 thread::ThreadPool* threadpool, int32 max_retries = 0, 48 bool fail_fast = true) 49 : RPCState( 50 stub, cq, method, request, response, std::move(done), call_opts, 51 threadpool, 52 // 1) If GRPC_FAIL_FAST is specified, fail_fast=$GRPC_FAIL_FAST. 53 // See b/141948186. 54 // 2) Otherwise, if the platform is Google, use the fail_fast from 55 // the caller. See b/140260119. 56 // 3) Otherwise, use fail_fast=false. 57 [fail_fast]() -> bool { 58 bool x; 59 #if defined(PLATFORM_GOOGLE) 60 TF_CHECK_OK(ReadBoolFromEnvVar("GRPC_FAIL_FAST", fail_fast, &x)); 61 #else 62 TF_CHECK_OK(ReadBoolFromEnvVar("GRPC_FAIL_FAST", false, &x)); 63 #endif // PLATFORM_GOOGLE 64 return x; 65 }(), 66 /*timeout_in_ms=*/0, max_retries) { 67 } 68 69 template <typename Request> RPCState(::grpc::GenericStub * stub,::grpc::CompletionQueue * cq,const::grpc::string & method,const Request & request,Response * response,StatusCallback done,CallOptions * call_opts,thread::ThreadPool * threadpool,bool fail_fast,int64 timeout_in_ms,int32 max_retries)70 RPCState(::grpc::GenericStub* stub, ::grpc::CompletionQueue* cq, 71 const ::grpc::string& method, const Request& request, 72 Response* response, StatusCallback done, CallOptions* call_opts, 73 thread::ThreadPool* threadpool, bool fail_fast, int64 timeout_in_ms, 74 int32 max_retries) 75 : call_opts_(call_opts), 76 threadpool_(threadpool), 77 done_(std::move(done)), 78 timeout_in_ms_(timeout_in_ms), 79 max_retries_(max_retries), 80 cq_(cq), 81 stub_(stub), 82 method_(method), 83 fail_fast_(fail_fast) { 84 response_ = response; 85 ::grpc::Status s = GrpcMaybeUnparseProto(request, &request_buf_); 86 if (!s.ok()) { 87 LOG(ERROR) << "GrpcMaybeUnparseProto returned with non-ok status: " 88 << s.error_message(); 89 // Skip retry logic if we fail to parse our request. 90 done_(FromGrpcStatus(s)); 91 delete this; 92 return; 93 } 94 StartCall(); 95 } 96 StartCall()97 void StartCall() { 98 context_.reset(new ::grpc::ClientContext()); 99 context_->set_wait_for_ready(!fail_fast_); 100 if (timeout_in_ms_ > 0) { 101 context_->set_deadline( 102 gpr_time_from_millis(timeout_in_ms_, GPR_TIMESPAN)); 103 } 104 if (call_opts_) { 105 call_opts_->SetCancelCallback([this]() { context_->TryCancel(); }); 106 } 107 108 VLOG(2) << "Starting call: " << method_; 109 110 call_ = stub_->PrepareUnaryCall(context_.get(), method_, request_buf_, cq_); 111 call_->StartCall(); 112 call_->Finish(&response_buf_, &status_, this); 113 } 114 OnCompleted(bool ok)115 void OnCompleted(bool ok) override { 116 if (call_opts_) { 117 call_opts_->ClearCancelCallback(); 118 } 119 120 VLOG(2) << "Completed call: " << method_; 121 122 Status s = FromGrpcStatus(status_); 123 if (s.ok() && !ok) { 124 // Since this function is only being used for processing the response 125 // to Finish for client-side unary calls, ok should never be false 126 s.Update( 127 errors::Internal("GRPC status is okay but CompletionQueueStatus is " 128 "not. This should never happen.")); 129 } 130 131 if (s.ok()) { 132 if (threadpool_) { 133 // Run parse and callback in another thread, returning this 134 // one to service more RPCs. 135 threadpool_->Schedule([this]() { ParseAndCallDone(); }); 136 } else { 137 ParseAndCallDone(); 138 } 139 return; 140 } 141 142 VLOG(1) << method_ << " returned with non-ok status: " << s 143 << " Retries: " << num_retries_ << " Max: " << max_retries_ << "\n" 144 << context_->debug_error_string(); 145 // Retry if we have any attempts left 146 if (++num_retries_ <= max_retries_ && 147 (errors::IsUnavailable(s) || errors::IsUnknown(s))) { 148 response_buf_.Clear(); 149 VLOG(1) << "Retrying call for " << method_ << "Retry: " << num_retries_ 150 << " of " << max_retries_; 151 // TODO(b/139945426) Allow user to configure the retry backoff time. 152 StartCall(); 153 } else { 154 // Attach additional GRPC error information if any to the final status 155 s = Status(s.code(), 156 strings::StrCat(s.error_message(), 157 "\nAdditional GRPC error information:\n", 158 context_->debug_error_string())); 159 // Always treat gRPC cancellation as a derived error. This ensures that 160 // other error types are preferred during status aggregation. (gRPC 161 // cancellation messages do not contain the original status message). 162 if (s.code() == tensorflow::error::Code::CANCELLED) { 163 s = StatusGroup::MakeDerived(s); 164 } 165 166 done_(s); 167 delete this; 168 } 169 } 170 ParseAndCallDone()171 void ParseAndCallDone() { 172 Status s; 173 if (!GrpcMaybeParseProto(&response_buf_, response_)) { 174 s.Update(errors::Internal("could not parse rpc response")); 175 } 176 done_(s); 177 delete this; 178 } 179 180 private: 181 CallOptions* call_opts_; 182 std::unique_ptr<::grpc::ClientContext> context_; 183 thread::ThreadPool* threadpool_; 184 std::unique_ptr<::grpc::GenericClientAsyncResponseReader> call_; 185 Response* response_; 186 ::grpc::ByteBuffer request_buf_; 187 ::grpc::ByteBuffer response_buf_; 188 ::grpc::Status status_; 189 StatusCallback done_; 190 int64 timeout_in_ms_; 191 192 size_t num_retries_ = 0; 193 size_t max_retries_; 194 195 ::grpc::CompletionQueue* cq_; 196 ::grpc::GenericStub* stub_; 197 ::grpc::string method_; 198 bool fail_fast_; 199 }; 200 201 // Represents state associated with one streaming RPC call. 202 // Similarly to above, we extract the methods of StreamingRPCState that don't 203 // need to be templated into this abstract class. 204 // Currently, *StreamingRPCState does not support client closing the call as 205 // there is no use case for it - current clients keep the streaming call open 206 // as long as possible. If/when the need arises, support can be added 207 // by calling GenericClientAsyncReaderWriter::WritesDone with a new tag 208 // TagType::kClientFinished and handling the completion in a new callback. 209 class UntypedStreamingRPCState : public core::RefCounted { 210 public: 211 virtual void CallStarted(bool ok) = 0; 212 virtual void RequestWriteCompleted(bool ok) = 0; 213 virtual void ResponseReadCompleted(bool ok) = 0; 214 virtual void CallFinished(bool ok) = 0; 215 216 virtual string DebugString() const = 0; 217 218 class Tag : public GrpcClientCQTag { 219 public: 220 // One enum value per supported callback. 221 enum class TagType { 222 kCallStarted, 223 kRequestWriteCompleted, 224 kResponseReadCommpleted, 225 kCallFinished, 226 }; 227 228 Tag(UntypedStreamingRPCState* streaming_state, Tag::TagType type); 229 230 // Calls the callback associated with this tag and Unrefs 231 // `this->streaming_state_`. 232 void OnCompleted(bool ok) override; 233 234 private: 235 // OnCompleted() consumes on reference each time it is called. 236 UntypedStreamingRPCState* const streaming_state_; 237 const Tag::TagType type_; 238 }; 239 }; 240 241 const char* ToString(UntypedStreamingRPCState::Tag::TagType tag_type); 242 243 // Represents a single request/response exchange between client and the server. 244 // A single streaming call contains a sequence of exchanges. Besides the 245 // messages, exchange contains: 246 // - the user callback to invoke when exchange completes (response is received 247 // or an error occurs). 248 // - The current state of the exchange. 249 class Exchange { 250 public: 251 enum class State { 252 kExchangeCreated, 253 kRequestWriteIssued, 254 kRequestWriteCompleted, 255 kResponseReadIssued, 256 }; 257 Exchange(const::grpc::ByteBuffer & request_buf,protobuf::Message * response,StatusCallback cb,string debug_string)258 Exchange(const ::grpc::ByteBuffer& request_buf, protobuf::Message* response, 259 StatusCallback cb, string debug_string) 260 : state_(State::kExchangeCreated), 261 request_buf_(request_buf), 262 response_(response), 263 cb_(std::move(cb)), 264 debug_string_(std::move(debug_string)) {} 265 request_buf()266 const ::grpc::ByteBuffer& request_buf() { return request_buf_; } response_buf()267 ::grpc::ByteBuffer* response_buf() { return &response_buf_; } 268 MarkRequestWriteIssued()269 void MarkRequestWriteIssued() { 270 DCHECK(state_ == State::kExchangeCreated); 271 state_ = State::kRequestWriteIssued; 272 } MarkRequestWriteCompleted()273 void MarkRequestWriteCompleted() { 274 DCHECK(state_ == State::kRequestWriteIssued); 275 state_ = State::kRequestWriteCompleted; 276 } MarkResponseReadIssued()277 void MarkResponseReadIssued() { 278 DCHECK(state_ == State::kRequestWriteCompleted); 279 state_ = State::kResponseReadIssued; 280 } 281 282 // If `status` is success, completes this exchange by parsing the 283 // response_buf_ and invoking cb_ with Status::OK(). Else, invokes the 284 // callback with `status`. 285 void Complete(Status status); 286 state()287 const State& state() const { return state_; } 288 289 string DebugString() const; 290 291 private: 292 State state_; 293 ::grpc::ByteBuffer request_buf_; 294 ::grpc::ByteBuffer response_buf_; 295 protobuf::Message* response_; 296 StatusCallback cb_; 297 string debug_string_; 298 }; 299 300 const char* ToString(Exchange::State s); 301 302 std::ostream& operator<<(std::ostream& os, const Exchange::State& state); 303 304 // Represents a queue of exchanges. 305 // When a client sends a new request a new exchange is created and added to the 306 // end of the queue. Completed exchanges are popped from the front of the queue. 307 // An explicit exchange queue is needed to brdige the client, which can send new 308 // requests at any time, with gRPC infrastructure, which can handle a single 309 // read and a single write request at a time. 310 // 311 // As the exchange progresses (request sending initiated, request sending 312 // completed, response reading initiated) the queue helps to make sure that the 313 // right operation is issued on the right exchange at the right time. 314 // 315 // To satisfy gRPC constraints, the states of exchanges must be as follows 316 // starting from the front of the queue: 317 // - 0 or 1 exchange in kResponseReadIssued state 318 // - 0 or more exchanges in kRequestWriteCompleted state 319 // - 0 or 1 exchange in kRequestWriteIssued state 320 // - 0 or more exchanges in kExchangeCreated state 321 // 322 // Thread-compatible. 323 class ExchangeQueue { 324 public: 325 // Creates a new exchange and adds it to the end of the queue. 326 void Emplace(const ::grpc::ByteBuffer& request_buf, 327 protobuf::Message* response, StatusCallback cb, 328 std::string debug_string); 329 330 // Returns an exchange for which we can initiate request writing, if any. 331 // Returns nullptr if there is no such exchange. 332 Exchange* GetReadyForRequestWriting(); 333 334 // Returns an exchange for which we can initiate response reading, if any. 335 // Returns nullptr if there is no such exchange. 336 Exchange* GetReadyForResponseReading(); 337 338 // Changes the state of the exchange that is current in kRequestWriteIssued 339 // state to kRequestWriteCompleted state. 340 // REQUIRES: There is an exhange in kRequestWriteIssued state. 341 void MarkRequestWriteCompleted(); 342 343 // Returns the exchange at the front of the queue. 344 // REQUIRES: ExchangeQueue is not empty. 345 Exchange& GetFront(); 346 347 // Removes the exchange at the front of the queue. 348 // REQUIRES: ExchangeQueue is not empty. 349 void PopFront(); 350 351 // Returns a string containing addresses and states of all exchanges in this 352 // queue. 353 string DebugString() const; 354 355 // Swaps the contents of this and `other`. 356 void Swap(ExchangeQueue* other); 357 358 // Completes all exchanges in this with `status`. 359 void CompleteAll(Status status); 360 CallStarted()361 void CallStarted() { call_started_ = true; } 362 363 private: 364 // Does nothing by default. Turn on VLOG(5) to enable. 365 // Checks that this ExchangeQueue is in a valid state. 366 // Kills the process if not. 367 void CheckInvariants(); 368 369 // We can't process any exchanges until the call has started. 370 bool call_started_ = false; 371 372 // std::queue is based on std::deque by default. std::deque provides 373 // fairly strong iterator stability. 374 std::deque<Exchange> exchanges_; 375 }; // namespace tensorflow 376 377 // Represents state associated with one streaming RPC call. 378 // Thread-safe 379 template <class Response> 380 class StreamingRPCState : public UntypedStreamingRPCState { 381 public: 382 // Default behavior is to set fail_fast = False and handle timeouts 383 // manually. StreamingRPCState(std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call,const std::shared_ptr<::grpc::ClientContext> & context)384 StreamingRPCState(std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call, 385 const std::shared_ptr<::grpc::ClientContext>& context) 386 : context_(context), call_(std::move(call)), call_state_(State::kActive) { 387 Ref(); 388 VLOG(3) << "Created new StreamingRPCState " << this; 389 VLOG(3) << "StreamingRPCState(" << this << ") calling grpc::StartCall"; 390 call_->StartCall(&call_started_tag_); 391 } 392 ~StreamingRPCState()393 ~StreamingRPCState() override { 394 VLOG(3) << "Destructing StreamingRPCState " << this; 395 } 396 397 // Attempts to send the next request. `done` is invoked when 398 // `response` has been filled with the data from the server, or if there 399 // is an error. `done` can be invoked before SendNextRequest returns. 400 // Return `true` if the call is alive and the `done` callback has or 401 // will be invoked. If the call is dead, returns `false`. `done` callback 402 // will not be invoked in this case. 403 // REQUIRES: The call has been started, i.e. WaitForCallStarted() has 404 // returned. SendNextRequest(const protobuf::Message & request,Response * response,const StatusCallback & done)405 bool SendNextRequest(const protobuf::Message& request, Response* response, 406 const StatusCallback& done) { 407 ::grpc::ByteBuffer request_buf; 408 ::grpc::Status s = GrpcMaybeUnparseProto(request, &request_buf); 409 if (!s.ok()) { 410 Status status = FromGrpcStatus(s); 411 LOG(ERROR) << "GrpcMaybeUnparseProto returned with non-ok status: " 412 << status.ToString(); 413 done(status); 414 return true; 415 } 416 417 mutex_lock l(mu_); 418 if (call_state_ != State::kActive) { 419 // `done` is not invoked intentionally. 420 return false; 421 } 422 if (VLOG_IS_ON(3)) { 423 // If vlog 3 is enabled, include first 100 chars of request as debug 424 // string. 425 exchanges_.Emplace(request_buf, response, done, 426 request.ShortDebugString().substr(0, 100)); 427 } else { 428 exchanges_.Emplace(request_buf, response, done, ""); 429 } 430 MaybeIssueRequestWriteLocked(); 431 return true; 432 } 433 CallStarted(bool ok)434 void CallStarted(bool ok) override { 435 VLOG(3) << "StreamingRPCState(" << this << ")::CallStarted(ok=" << ok 436 << ")"; 437 mutex_lock l(mu_); 438 if (!ok) { 439 call_state_ = State::kDone; 440 return; 441 } 442 exchanges_.CallStarted(); 443 // Now that the call has started, we can write our first request, if any. 444 MaybeIssueRequestWriteLocked(); 445 } 446 RequestWriteCompleted(bool ok)447 void RequestWriteCompleted(bool ok) override { 448 VLOG(3) << "StreamingRPCState(" << this 449 << ")::RequestWriteCompleted(ok=" << ok << ")"; 450 mu_.lock(); 451 if (call_state_ != State::kActive) { 452 mu_.unlock(); 453 return; 454 } 455 if (!ok) { 456 // unlocks mu_ 457 MarkDoneAndCompleteExchanges(errors::Internal( 458 "Not ok value returned by CompletionQueue when attempting streaming " 459 "rpc write. Probably because the completion queue has been shut " 460 "down or the connection went down. ", 461 context_->debug_error_string())); 462 return; 463 } 464 465 exchanges_.MarkRequestWriteCompleted(); 466 MaybeIssueResponseReadLocked(); 467 MaybeIssueRequestWriteLocked(); 468 mu_.unlock(); 469 } 470 ResponseReadCompleted(bool ok)471 void ResponseReadCompleted(bool ok) override { 472 VLOG(3) << "StreamingRPCState(" << this 473 << ")::ResponseReadCompleted(ok=" << ok << ")"; 474 mu_.lock(); 475 if (call_state_ != State::kActive) { 476 mu_.unlock(); 477 return; 478 } 479 if (!ok) { 480 IssueCallFinishLocked(); 481 mu_.unlock(); 482 return; 483 } 484 485 // Complete the exchange without holding the lock because user's 486 // callback can call back into this RPC code resulting in a deadlock. 487 // No other thread can pop this exchange while we release the lock because 488 // this is the only method that pops exchanges and it is called from a 489 // single thread that waits on completion queue events. 490 Exchange* e; 491 e = &exchanges_.GetFront(); 492 mu_.unlock(); 493 494 e->Complete(Status::OK()); 495 496 { 497 mutex_lock l(mu_); 498 exchanges_.PopFront(); 499 MaybeIssueResponseReadLocked(); 500 } 501 } 502 CallFinished(bool ok)503 void CallFinished(bool ok) override { 504 VLOG(3) << "StreamingRPCState(" << this << ")::CallFinished(ok=" << ok 505 << ")"; 506 mu_.lock(); 507 DCHECK(call_state_ != State::kActive); 508 if (call_state_ != State::kFinishing) { 509 mu_.unlock(); 510 return; 511 } 512 513 Status s = FromGrpcStatus(call_status_); 514 if (s.ok() && !ok) { 515 s.Update( 516 errors::Internal("GRPC status is okay but CompletionQueueStatus is " 517 "not. This should never happen.", 518 context_->debug_error_string())); 519 } 520 // unlocks mu_ 521 MarkDoneAndCompleteExchanges(s); 522 } 523 DebugString()524 string DebugString() const override { 525 mutex_lock l(mu_); 526 return exchanges_.DebugString(); 527 } 528 529 private: 530 enum class State { 531 kActive, 532 kFinishing, 533 kDone, 534 }; 535 MarkDoneAndCompleteExchanges(Status status)536 void MarkDoneAndCompleteExchanges(Status status) EXCLUSIVE_LOCKS_REQUIRED(mu_) 537 UNLOCK_FUNCTION(mu_) { 538 call_state_ = State::kDone; 539 VLOG(2) << "Ending gRPC stremaing call on the client side due to " 540 << status.ToString(); 541 // Swap the exchanges_ into a temporary ExchangeQueue so that we can 542 // complete all exchanges without holding mu_ in case user callback 543 // reach back into this. This should be impossible now, but safer for 544 // the future. 545 ExchangeQueue queue; 546 exchanges_.Swap(&queue); 547 mu_.unlock(); 548 queue.CompleteAll(status); 549 } 550 MaybeIssueRequestWriteLocked()551 void MaybeIssueRequestWriteLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { 552 Exchange* exchange = exchanges_.GetReadyForRequestWriting(); 553 if (exchange == nullptr) { 554 // There are no queued exchanges, there is already an outstanding write, 555 // or there are no just created exchanges. 556 return; 557 } 558 exchange->MarkRequestWriteIssued(); 559 Ref(); 560 VLOG(3) << "StreamingRPCState(" << this << ") calling grpc::Write"; 561 call_->Write(exchange->request_buf(), &request_write_completed_tag_); 562 } 563 MaybeIssueResponseReadLocked()564 void MaybeIssueResponseReadLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { 565 Exchange* exchange = exchanges_.GetReadyForResponseReading(); 566 if (exchange == nullptr) { 567 return; 568 } 569 exchange->MarkResponseReadIssued(); 570 Ref(); 571 VLOG(3) << "StreamingRPCState(" << this << ") calling grpc::Read"; 572 call_->Read(exchange->response_buf(), &response_read_completed_tag_); 573 } 574 IssueCallFinishLocked()575 void IssueCallFinishLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { 576 call_state_ = State::kFinishing; 577 Ref(); 578 VLOG(3) << "StreamingRPCState(" << this << ") calling grpc::Finish"; 579 // We call finish in response to completed (with error) response reading tag 580 // on some exchange. We let this exchange hang in ResponseReadIssued state. 581 // ExchangeQueue makes sure that there is at most one exchange in this 582 // state. So, no new reads will be issued. 583 call_->Finish(&call_status_, &finished_tag_); 584 } 585 586 // Holds state for a single request/response exchange between the client 587 // and the server. 588 typedef typename UntypedStreamingRPCState::Tag Tag; 589 590 // Order of context_ and call_ is important because context_ must outlive 591 // call_. 592 const std::shared_ptr<const ::grpc::ClientContext> context_; 593 std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call_; 594 595 mutable mutex mu_; 596 ExchangeQueue exchanges_ GUARDED_BY(mu_); 597 State call_state_ GUARDED_BY(mu_); 598 ::grpc::Status call_status_ GUARDED_BY(mu_); 599 600 // We can get away with having single instances of these tags per 601 // StreamingRPCState because we make sure (as gRPC requires) that 602 // there is at most one outstanding Read and at most one outstanding Write 603 // in the completion queue. 604 // Tags are immutable. No need to guard them. 605 Tag call_started_tag_{this, Tag::TagType::kCallStarted}; 606 Tag request_write_completed_tag_{this, Tag::TagType::kRequestWriteCompleted}; 607 Tag response_read_completed_tag_{this, Tag::TagType::kResponseReadCommpleted}; 608 Tag finished_tag_{this, Tag::TagType::kCallFinished}; 609 }; 610 611 // Creates streaming calls and dispatches requests to them. 612 // In the common case, the client would create a StreamingRPCDispatcher for 613 // each bidirectional streaming RPC it might want to make. The first time, it 614 // calls SendNextRequest, a streaming call is initiated and the request is 615 // sent within this call. Initiation of the call blocks the client. If there are 616 // no errors, subsequent calls to SendNextRequest would use the already active 617 // call. If there was an error, the call object will be destroyed after all 618 // the callbacks for outstanding requests have been invoked. The next call to 619 // SendNextRequest will initiate a new call. 620 // 621 // Callbacks that are part of the same call, are invoked in the order they were 622 // provided, but callbacks across calls (a failed and a new one) can be invoked 623 // in any order. 624 // 625 // Thread-safe. 626 template <class Response> 627 class StreamingRPCDispatcher { 628 public: StreamingRPCDispatcher(::grpc::GenericStub * stub,::grpc::CompletionQueue * cq,const::grpc::string & method)629 StreamingRPCDispatcher(::grpc::GenericStub* stub, ::grpc::CompletionQueue* cq, 630 const ::grpc::string& method) 631 : stub_(stub), cq_(cq), method_(method) {} 632 633 // Attempts to send the next request. If there is no active streaming call, 634 // starts one and sends the request on top of it. `done` is invoked when 635 // `response` has been filled with the data from the server, or if there 636 // is an error. `done` can be invoked before SendNextRequest returns. SendNextRequest(const protobuf::Message & request,Response * response,StatusCallback done)637 void SendNextRequest(const protobuf::Message& request, Response* response, 638 StatusCallback done) { 639 mutex_lock l(mu_); 640 if (state_ == nullptr) { 641 CreateStreamingState(); 642 } 643 644 bool is_call_alive = state_->SendNextRequest(request, response, done); 645 if (is_call_alive) { 646 return; 647 } 648 649 // The attempt to send failed because the call was dead, create a new 650 // call and try again. When the call is dead SendNextRequest does not call 651 // `done`. 652 CreateStreamingState(); 653 654 is_call_alive = state_->SendNextRequest(request, response, done); 655 if (!is_call_alive) { 656 // Consider retrying to create and start a call few more times. 657 done(errors::Unknown("gRPC call failed right after it was created")); 658 } 659 } 660 661 // Request to cancel the current streaming call. Non-blocking. CancelCall()662 void CancelCall() { 663 mutex_lock l(mu_); 664 if (state_ == nullptr) { 665 return; 666 } 667 context_->TryCancel(); 668 state_ = nullptr; 669 } 670 671 private: CreateStreamingState()672 void CreateStreamingState() EXCLUSIVE_LOCKS_REQUIRED(mu_) { 673 // ClientContext cannot be reused across calls. 674 context_ = std::make_shared<::grpc::ClientContext>(); 675 // Don't immediately fail StartCall if the channel is not ready. Wait for 676 // the channel to become ready. 677 context_->set_wait_for_ready(true); 678 679 std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call = 680 stub_->PrepareCall(context_.get(), method_, cq_); 681 682 state_.reset(new StreamingRPCState<Response>(std::move(call), context_)); 683 } 684 685 mutable mutex mu_; 686 687 // Both are thread-safe 688 ::grpc::GenericStub* const stub_; 689 ::grpc::CompletionQueue* const cq_; 690 691 // Does not need synchronization since it is constant. 692 const ::grpc::string method_; 693 694 std::shared_ptr<::grpc::ClientContext> context_ GUARDED_BY(mu_); 695 core::RefCountPtr<StreamingRPCState<Response>> state_ GUARDED_BY(mu_); 696 }; 697 698 } // namespace tensorflow 699 700 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_ 701