1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_ 18 19 #include <queue> 20 #include <utility> 21 22 #include "grpcpp/generic/generic_stub.h" 23 #include "grpcpp/grpcpp.h" 24 #include "tensorflow/core/distributed_runtime/call_options.h" 25 #include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h" 26 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" 27 #include "tensorflow/core/distributed_runtime/tensor_coding.h" 28 #include "tensorflow/core/lib/core/refcount.h" 29 #include "tensorflow/core/lib/core/status.h" 30 #include "tensorflow/core/lib/core/threadpool.h" 31 #include "tensorflow/core/lib/strings/strcat.h" 32 #include "tensorflow/core/platform/mutex.h" 33 #include "tensorflow/core/platform/notification.h" 34 #include "tensorflow/core/util/env_var.h" 35 36 namespace tensorflow { 37 38 // Object allocated per active RPC. 39 // Manage the state of a single asynchronous RPC request. If `max_retries` 40 // is greater than 0, the request will be retried for any transient failures. 41 template <class Response> 42 class RPCState : public GrpcClientCQTag { 43 public: 44 RPCState(::grpc::GenericStub* stub, ::grpc::CompletionQueue* cq, 45 const ::grpc::string& method, const protobuf::Message& request, 46 Response* response, StatusCallback done, CallOptions* call_opts, 47 thread::ThreadPool* threadpool, int32 max_retries = 0, 48 bool fail_fast = true, const string* target = nullptr) 49 : RPCState( 50 stub, cq, method, request, response, std::move(done), call_opts, 51 threadpool, 52 // 1) If GRPC_FAIL_FAST is set to 'true' or 'false', 53 // fail_fast=$GRPC_FAIL_FAST. See b/141948186. 54 // 2) Otherwise if GRPC_FAIL_FAST is set to 'use_caller', use the 55 // fail_fast from the caller. See b/140260119. 56 // 57 // Current default for PLATFORM_GOOGLE: use caller fail_fast; 58 // Current default for open source: fail_fast=false. 59 // 60 // NOTE: Callers mostly set fail_fast=true to prevent job hanging 61 // on worker task failures, except a few cases such as GetStatus 62 // in cluster initialization and collective param resolution. 63 [fail_fast, &done]() -> bool { 64 string fail_fast_env; 65 #if defined(PLATFORM_GOOGLE) 66 TF_CHECK_OK(ReadStringFromEnvVar("GRPC_FAIL_FAST", "use_caller", 67 &fail_fast_env)); 68 #else 69 TF_CHECK_OK(ReadStringFromEnvVar("GRPC_FAIL_FAST", "false", 70 &fail_fast_env)); 71 #endif // PLATFORM_GOOGLE 72 string fail_fast_env_lower = absl::AsciiStrToLower(fail_fast_env); 73 if (fail_fast_env_lower == "true") { 74 return true; 75 } else if (fail_fast_env_lower == "use_caller") { 76 return fail_fast; 77 } else if (fail_fast_env_lower == "false") { 78 return false; 79 } else { 80 string error_message = strings::StrCat( 81 "Invalid GRPC_FAIL_FAST config: ", fail_fast_env); 82 LOG(WARNING) << error_message; 83 done(errors::InvalidArgument(error_message)); 84 return false; 85 } 86 }(), 87 (call_opts != nullptr ? call_opts->GetTimeout() : 0), max_retries, 88 target) { 89 } 90 91 template <typename Request> RPCState(::grpc::GenericStub * stub,::grpc::CompletionQueue * cq,const::grpc::string & method,const Request & request,Response * response,StatusCallback done,CallOptions * call_opts,thread::ThreadPool * threadpool,bool fail_fast,int64 timeout_in_ms,int32 max_retries,const string * target)92 RPCState(::grpc::GenericStub* stub, ::grpc::CompletionQueue* cq, 93 const ::grpc::string& method, const Request& request, 94 Response* response, StatusCallback done, CallOptions* call_opts, 95 thread::ThreadPool* threadpool, bool fail_fast, int64 timeout_in_ms, 96 int32 max_retries, const string* target) 97 : call_opts_(call_opts), 98 threadpool_(threadpool), 99 done_(std::move(done)), 100 timeout_in_ms_(timeout_in_ms), 101 max_retries_(max_retries), 102 cq_(cq), 103 stub_(stub), 104 method_(method), 105 fail_fast_(fail_fast), 106 target_(target) { 107 response_ = response; 108 ::grpc::Status s = GrpcMaybeUnparseProto(request, &request_buf_); 109 if (!s.ok()) { 110 LOG(ERROR) << "GrpcMaybeUnparseProto returned with non-ok status: " 111 << s.error_message(); 112 // Skip retry logic if we fail to parse our request. 113 done_(FromGrpcStatus(s)); 114 delete this; 115 return; 116 } 117 StartCall(); 118 } 119 StartCall()120 void StartCall() { 121 context_.reset(new ::grpc::ClientContext()); 122 context_->set_wait_for_ready(!fail_fast_); 123 if (timeout_in_ms_ > 0) { 124 context_->set_deadline( 125 gpr_time_from_millis(timeout_in_ms_, GPR_TIMESPAN)); 126 } 127 if (call_opts_) { 128 call_opts_->SetCancelCallback([this]() { context_->TryCancel(); }); 129 } 130 131 VLOG(2) << "Starting call: " << method_; 132 133 call_ = stub_->PrepareUnaryCall(context_.get(), method_, request_buf_, cq_); 134 call_->StartCall(); 135 call_->Finish(&response_buf_, &status_, this); 136 } 137 OnCompleted(bool ok)138 void OnCompleted(bool ok) override { 139 if (call_opts_) { 140 call_opts_->ClearCancelCallback(); 141 } 142 143 VLOG(2) << "Completed call: " << method_; 144 145 Status s = FromGrpcStatus(status_); 146 if (s.ok() && !ok) { 147 // Since this function is only being used for processing the response 148 // to Finish for client-side unary calls, ok should never be false 149 s.Update( 150 errors::Internal("GRPC status is okay but CompletionQueueStatus is " 151 "not. This should never happen.")); 152 } 153 154 if (s.ok()) { 155 if (threadpool_) { 156 // Run parse and callback in another thread, returning this 157 // one to service more RPCs. 158 threadpool_->Schedule([this]() { ParseAndCallDone(); }); 159 } else { 160 ParseAndCallDone(); 161 } 162 return; 163 } 164 165 VLOG(1) << method_ << " returned with non-ok status: " << s 166 << " Retries: " << num_retries_ << " Max: " << max_retries_ << "\n" 167 << context_->debug_error_string(); 168 // Retry if we have any attempts left 169 if (++num_retries_ <= max_retries_ && 170 (errors::IsUnavailable(s) || errors::IsUnknown(s))) { 171 response_buf_.Clear(); 172 VLOG(1) << "Retrying call for " << method_ << "Retry: " << num_retries_ 173 << " of " << max_retries_; 174 // TODO(b/139945426) Allow user to configure the retry backoff time. 175 StartCall(); 176 } else { 177 // Attach additional GRPC error information if any to the final status 178 string error_msg = s.error_message(); 179 strings::StrAppend(&error_msg, "\nAdditional GRPC error information"); 180 if (target_) { 181 strings::StrAppend(&error_msg, " from remote target ", *target_); 182 } 183 strings::StrAppend(&error_msg, ":\n:", context_->debug_error_string()); 184 s = Status(s.code(), error_msg); 185 // Always treat gRPC cancellation as a derived error. This ensures that 186 // other error types are preferred during status aggregation. (gRPC 187 // cancellation messages do not contain the original status message). 188 if (s.code() == tensorflow::error::Code::CANCELLED) { 189 s = StatusGroup::MakeDerived(s); 190 } 191 192 done_(s); 193 delete this; 194 } 195 } 196 ParseAndCallDone()197 void ParseAndCallDone() { 198 Status s; 199 if (!GrpcMaybeParseProto(&response_buf_, response_)) { 200 s.Update(errors::Internal("could not parse rpc response")); 201 } 202 done_(s); 203 delete this; 204 } 205 206 private: 207 CallOptions* call_opts_; 208 std::unique_ptr<::grpc::ClientContext> context_; 209 thread::ThreadPool* threadpool_; 210 std::unique_ptr<::grpc::GenericClientAsyncResponseReader> call_; 211 Response* response_; 212 ::grpc::ByteBuffer request_buf_; 213 ::grpc::ByteBuffer response_buf_; 214 ::grpc::Status status_; 215 StatusCallback done_; 216 int64 timeout_in_ms_; 217 218 size_t num_retries_ = 0; 219 size_t max_retries_; 220 221 ::grpc::CompletionQueue* cq_; 222 ::grpc::GenericStub* stub_; 223 ::grpc::string method_; 224 bool fail_fast_; 225 const string* target_; 226 }; 227 228 // Represents state associated with one streaming RPC call. 229 // Similarly to above, we extract the methods of StreamingRPCState that don't 230 // need to be templated into this abstract class. 231 // Currently, *StreamingRPCState does not support client closing the call as 232 // there is no use case for it - current clients keep the streaming call open 233 // as long as possible. If/when the need arises, support can be added 234 // by calling GenericClientAsyncReaderWriter::WritesDone with a new tag 235 // TagType::kClientFinished and handling the completion in a new callback. 236 class UntypedStreamingRPCState : public core::RefCounted { 237 public: 238 virtual void CallStarted(bool ok) = 0; 239 virtual void RequestWriteCompleted(bool ok) = 0; 240 virtual void ResponseReadCompleted(bool ok) = 0; 241 virtual void CallFinished(bool ok) = 0; 242 243 virtual string DebugString() const = 0; 244 245 class Tag : public GrpcClientCQTag { 246 public: 247 // One enum value per supported callback. 248 enum class TagType { 249 kCallStarted, 250 kRequestWriteCompleted, 251 kResponseReadCompleted, 252 kCallFinished, 253 }; 254 255 Tag(UntypedStreamingRPCState* streaming_state, Tag::TagType type); 256 257 // Calls the callback associated with this tag and Unrefs 258 // `this->streaming_state_`. 259 void OnCompleted(bool ok) override; 260 261 private: 262 // OnCompleted() consumes on reference each time it is called. 263 UntypedStreamingRPCState* const streaming_state_; 264 const Tag::TagType type_; 265 }; 266 }; 267 268 const char* ToString(UntypedStreamingRPCState::Tag::TagType tag_type); 269 270 // Represents a single request/response exchange between client and the server. 271 // A single streaming call contains a sequence of exchanges. Besides the 272 // messages, exchange contains: 273 // - the user callback to invoke when exchange completes (response is received 274 // or an error occurs). 275 // - The current state of the exchange. 276 class Exchange { 277 public: 278 enum class State { 279 kExchangeCreated, 280 kRequestWriteIssued, 281 kRequestWriteCompleted, 282 kResponseReadIssued, 283 }; 284 Exchange(const::grpc::ByteBuffer & request_buf,protobuf::Message * response,StatusCallback cb,string debug_string)285 Exchange(const ::grpc::ByteBuffer& request_buf, protobuf::Message* response, 286 StatusCallback cb, string debug_string) 287 : state_(State::kExchangeCreated), 288 request_buf_(request_buf), 289 response_(response), 290 cb_(std::move(cb)), 291 debug_string_(std::move(debug_string)) {} 292 request_buf()293 const ::grpc::ByteBuffer& request_buf() { return request_buf_; } response_buf()294 ::grpc::ByteBuffer* response_buf() { return &response_buf_; } 295 MarkRequestWriteIssued()296 void MarkRequestWriteIssued() { 297 DCHECK(state_ == State::kExchangeCreated); 298 state_ = State::kRequestWriteIssued; 299 } MarkRequestWriteCompleted()300 void MarkRequestWriteCompleted() { 301 DCHECK(state_ == State::kRequestWriteIssued); 302 state_ = State::kRequestWriteCompleted; 303 } MarkResponseReadIssued()304 void MarkResponseReadIssued() { 305 DCHECK(state_ == State::kRequestWriteCompleted); 306 state_ = State::kResponseReadIssued; 307 } 308 309 // If `status` is success, completes this exchange by parsing the 310 // response_buf_ and invoking cb_ with Status::OK(). Else, invokes the 311 // callback with `status`. 312 void Complete(Status status); 313 state()314 const State& state() const { return state_; } 315 316 string DebugString() const; 317 318 private: 319 State state_; 320 ::grpc::ByteBuffer request_buf_; 321 ::grpc::ByteBuffer response_buf_; 322 protobuf::Message* response_; 323 StatusCallback cb_; 324 string debug_string_; 325 }; 326 327 const char* ToString(Exchange::State s); 328 329 std::ostream& operator<<(std::ostream& os, const Exchange::State& state); 330 331 // Represents a queue of exchanges. 332 // When a client sends a new request a new exchange is created and added to the 333 // end of the queue. Completed exchanges are popped from the front of the queue. 334 // An explicit exchange queue is needed to brdige the client, which can send new 335 // requests at any time, with gRPC infrastructure, which can handle a single 336 // read and a single write request at a time. 337 // 338 // As the exchange progresses (request sending initiated, request sending 339 // completed, response reading initiated) the queue helps to make sure that the 340 // right operation is issued on the right exchange at the right time. 341 // 342 // To satisfy gRPC constraints, the states of exchanges must be as follows 343 // starting from the front of the queue: 344 // - 0 or 1 exchange in kResponseReadIssued state 345 // - 0 or more exchanges in kRequestWriteCompleted state 346 // - 0 or 1 exchange in kRequestWriteIssued state 347 // - 0 or more exchanges in kExchangeCreated state 348 // 349 // Thread-compatible. 350 class ExchangeQueue { 351 public: 352 // Creates a new exchange and adds it to the end of the queue. 353 void Emplace(const ::grpc::ByteBuffer& request_buf, 354 protobuf::Message* response, StatusCallback cb, 355 std::string debug_string); 356 357 // Returns an exchange for which we can initiate request writing, if any. 358 // Returns nullptr if there is no such exchange. 359 Exchange* GetReadyForRequestWriting(); 360 361 // Returns an exchange for which we can initiate response reading, if any. 362 // Returns nullptr if there is no such exchange. 363 Exchange* GetReadyForResponseReading(); 364 365 // Changes the state of the exchange that is current in kRequestWriteIssued 366 // state to kRequestWriteCompleted state. 367 // REQUIRES: There is an exchange in kRequestWriteIssued state. 368 void MarkRequestWriteCompleted(); 369 370 // Returns the exchange at the front of the queue. 371 // REQUIRES: ExchangeQueue is not empty. 372 Exchange& GetFront(); 373 374 // Removes the exchange at the front of the queue. 375 // REQUIRES: ExchangeQueue is not empty. 376 void PopFront(); 377 378 // Returns a string containing addresses and states of all exchanges in this 379 // queue. 380 string DebugString() const; 381 382 // Swaps the contents of this and `other`. 383 void Swap(ExchangeQueue* other); 384 385 // Completes all exchanges in this with `status`. 386 void CompleteAll(Status status); 387 CallStarted()388 void CallStarted() { call_started_ = true; } 389 390 private: 391 // Does nothing by default. Turn on VLOG(5) to enable. 392 // Checks that this ExchangeQueue is in a valid state. 393 // Kills the process if not. 394 void CheckInvariants(); 395 396 // We can't process any exchanges until the call has started. 397 bool call_started_ = false; 398 399 // std::queue is based on std::deque by default. std::deque provides 400 // fairly strong iterator stability. 401 std::deque<Exchange> exchanges_; 402 }; // namespace tensorflow 403 404 // Represents state associated with one streaming RPC call. 405 // Thread-safe 406 template <class Response> 407 class StreamingRPCState : public UntypedStreamingRPCState { 408 public: 409 // Default behavior is to set fail_fast = False and handle timeouts 410 // manually. StreamingRPCState(std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call,const std::shared_ptr<::grpc::ClientContext> & context)411 StreamingRPCState(std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call, 412 const std::shared_ptr<::grpc::ClientContext>& context) 413 : context_(context), call_(std::move(call)), call_state_(State::kActive) { 414 Ref(); 415 VLOG(3) << "Created new StreamingRPCState " << this; 416 VLOG(3) << "StreamingRPCState(" << this << ") calling grpc::StartCall"; 417 call_->StartCall(&call_started_tag_); 418 } 419 ~StreamingRPCState()420 ~StreamingRPCState() override { 421 VLOG(3) << "Destructing StreamingRPCState " << this; 422 } 423 424 // Attempts to send the next request. `done` is invoked when 425 // `response` has been filled with the data from the server, or if there 426 // is an error. `done` can be invoked before SendNextRequest returns. 427 // Return `true` if the call is alive and the `done` callback has or 428 // will be invoked. If the call is dead, returns `false`. `done` callback 429 // will not be invoked in this case. 430 // REQUIRES: The call has been started, i.e. WaitForCallStarted() has 431 // returned. SendNextRequest(const protobuf::Message & request,Response * response,const StatusCallback & done)432 bool SendNextRequest(const protobuf::Message& request, Response* response, 433 const StatusCallback& done) { 434 ::grpc::ByteBuffer request_buf; 435 ::grpc::Status s = GrpcMaybeUnparseProto(request, &request_buf); 436 if (!s.ok()) { 437 Status status = FromGrpcStatus(s); 438 LOG(ERROR) << "GrpcMaybeUnparseProto returned with non-ok status: " 439 << status.ToString(); 440 done(status); 441 return true; 442 } 443 444 mutex_lock l(mu_); 445 if (call_state_ != State::kActive) { 446 // `done` is not invoked intentionally. 447 return false; 448 } 449 if (VLOG_IS_ON(3)) { 450 // If vlog 3 is enabled, include first 100 chars of request as debug 451 // string. 452 exchanges_.Emplace(request_buf, response, done, 453 request.ShortDebugString().substr(0, 100)); 454 } else { 455 exchanges_.Emplace(request_buf, response, done, ""); 456 } 457 MaybeIssueRequestWriteLocked(); 458 return true; 459 } 460 CallStarted(bool ok)461 void CallStarted(bool ok) override { 462 VLOG(3) << "StreamingRPCState(" << this << ")::CallStarted(ok=" << ok 463 << ")"; 464 mutex_lock l(mu_); 465 if (!ok) { 466 call_state_ = State::kDone; 467 return; 468 } 469 exchanges_.CallStarted(); 470 // Now that the call has started, we can write our first request, if any. 471 MaybeIssueRequestWriteLocked(); 472 } 473 RequestWriteCompleted(bool ok)474 void RequestWriteCompleted(bool ok) override { 475 VLOG(3) << "StreamingRPCState(" << this 476 << ")::RequestWriteCompleted(ok=" << ok << ")"; 477 mu_.lock(); 478 if (call_state_ != State::kActive) { 479 mu_.unlock(); 480 return; 481 } 482 exchanges_.MarkRequestWriteCompleted(); 483 // Issue ResponseRead regardless of OK status on completing RequestWrite. 484 // If the underlying completion queue is in Not-OK status due to previous 485 // request failuress (i.e., `ok` from `Next` call on completion queue is 486 // False), delay the error in ResponseRead so we can get the remote error 487 // message from response buffer. 488 MaybeIssueResponseReadLocked(); 489 490 if (ok) { 491 MaybeIssueRequestWriteLocked(); 492 } 493 mu_.unlock(); 494 } 495 ResponseReadCompleted(bool ok)496 void ResponseReadCompleted(bool ok) override { 497 VLOG(3) << "StreamingRPCState(" << this 498 << ")::ResponseReadCompleted(ok=" << ok << ")"; 499 mu_.lock(); 500 if (call_state_ != State::kActive) { 501 mu_.unlock(); 502 return; 503 } 504 if (!ok) { 505 IssueCallFinishLocked(); 506 mu_.unlock(); 507 return; 508 } 509 510 // Complete the exchange without holding the lock because user's 511 // callback can call back into this RPC code resulting in a deadlock. 512 // No other thread can pop this exchange while we release the lock because 513 // this is the only method that pops exchanges and it is called from a 514 // single thread that waits on completion queue events. 515 Exchange* e; 516 e = &exchanges_.GetFront(); 517 mu_.unlock(); 518 519 e->Complete(Status::OK()); 520 521 { 522 mutex_lock l(mu_); 523 exchanges_.PopFront(); 524 MaybeIssueResponseReadLocked(); 525 } 526 } 527 CallFinished(bool ok)528 void CallFinished(bool ok) override { 529 VLOG(3) << "StreamingRPCState(" << this << ")::CallFinished(ok=" << ok 530 << ")"; 531 mu_.lock(); 532 DCHECK(call_state_ != State::kActive); 533 if (call_state_ != State::kFinishing) { 534 mu_.unlock(); 535 return; 536 } 537 538 Status s = FromGrpcStatus(call_status_); 539 if (s.ok() && !ok) { 540 s.Update( 541 errors::Internal("GRPC status is okay but CompletionQueueStatus is " 542 "not. This should never happen.", 543 context_->debug_error_string())); 544 } 545 // unlocks mu_ 546 MarkDoneAndCompleteExchanges(s); 547 } 548 DebugString()549 string DebugString() const override { 550 mutex_lock l(mu_); 551 return exchanges_.DebugString(); 552 } 553 554 private: 555 enum class State { 556 kActive, 557 kFinishing, 558 kDone, 559 }; 560 MarkDoneAndCompleteExchanges(Status status)561 void MarkDoneAndCompleteExchanges(Status status) 562 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) TF_UNLOCK_FUNCTION(mu_) { 563 call_state_ = State::kDone; 564 VLOG(2) << "Ending gRPC streaming call on the client side due to " 565 << status.ToString(); 566 // Swap the exchanges_ into a temporary ExchangeQueue so that we can 567 // complete all exchanges without holding mu_ in case user callback 568 // reach back into this. This should be impossible now, but safer for 569 // the future. 570 ExchangeQueue queue; 571 exchanges_.Swap(&queue); 572 mu_.unlock(); 573 queue.CompleteAll(status); 574 } 575 MaybeIssueRequestWriteLocked()576 void MaybeIssueRequestWriteLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { 577 Exchange* exchange = exchanges_.GetReadyForRequestWriting(); 578 if (exchange == nullptr) { 579 // There are no queued exchanges, there is already an outstanding write, 580 // or there are no just created exchanges. 581 return; 582 } 583 exchange->MarkRequestWriteIssued(); 584 Ref(); 585 VLOG(3) << "StreamingRPCState(" << this << ") calling grpc::Write"; 586 call_->Write(exchange->request_buf(), &request_write_completed_tag_); 587 } 588 MaybeIssueResponseReadLocked()589 void MaybeIssueResponseReadLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { 590 Exchange* exchange = exchanges_.GetReadyForResponseReading(); 591 if (exchange == nullptr) { 592 return; 593 } 594 exchange->MarkResponseReadIssued(); 595 Ref(); 596 VLOG(3) << "StreamingRPCState(" << this << ") calling grpc::Read"; 597 call_->Read(exchange->response_buf(), &response_read_completed_tag_); 598 } 599 IssueCallFinishLocked()600 void IssueCallFinishLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { 601 call_state_ = State::kFinishing; 602 Ref(); 603 VLOG(3) << "StreamingRPCState(" << this << ") calling grpc::Finish"; 604 // We call finish in response to completed (with error) response reading tag 605 // on some exchange. We let this exchange hang in ResponseReadIssued state. 606 // ExchangeQueue makes sure that there is at most one exchange in this 607 // state. So, no new reads will be issued. 608 call_->Finish(&call_status_, &finished_tag_); 609 } 610 611 // Holds state for a single request/response exchange between the client 612 // and the server. 613 typedef typename UntypedStreamingRPCState::Tag Tag; 614 615 // Order of context_ and call_ is important because context_ must outlive 616 // call_. 617 const std::shared_ptr<const ::grpc::ClientContext> context_; 618 std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call_; 619 620 mutable mutex mu_; 621 ExchangeQueue exchanges_ TF_GUARDED_BY(mu_); 622 State call_state_ TF_GUARDED_BY(mu_); 623 ::grpc::Status call_status_ TF_GUARDED_BY(mu_); 624 625 // We can get away with having single instances of these tags per 626 // StreamingRPCState because we make sure (as gRPC requires) that 627 // there is at most one outstanding Read and at most one outstanding Write 628 // in the completion queue. 629 // Tags are immutable. No need to guard them. 630 Tag call_started_tag_{this, Tag::TagType::kCallStarted}; 631 Tag request_write_completed_tag_{this, Tag::TagType::kRequestWriteCompleted}; 632 Tag response_read_completed_tag_{this, Tag::TagType::kResponseReadCompleted}; 633 Tag finished_tag_{this, Tag::TagType::kCallFinished}; 634 }; 635 636 // Creates streaming calls and dispatches requests to them. 637 // In the common case, the client would create a StreamingRPCDispatcher for 638 // each bidirectional streaming RPC it might want to make. The first time, it 639 // calls SendNextRequest, a streaming call is initiated and the request is 640 // sent within this call. Initiation of the call blocks the client. If there are 641 // no errors, subsequent calls to SendNextRequest would use the already active 642 // call. If there was an error, the call object will be destroyed after all 643 // the callbacks for outstanding requests have been invoked. The next call to 644 // SendNextRequest will initiate a new call. 645 // 646 // Callbacks that are part of the same call, are invoked in the order they were 647 // provided, but callbacks across calls (a failed and a new one) can be invoked 648 // in any order. 649 // 650 // Thread-safe. 651 template <class Response> 652 class StreamingRPCDispatcher { 653 public: StreamingRPCDispatcher(::grpc::GenericStub * stub,::grpc::CompletionQueue * cq,const::grpc::string & method)654 StreamingRPCDispatcher(::grpc::GenericStub* stub, ::grpc::CompletionQueue* cq, 655 const ::grpc::string& method) 656 : stub_(stub), cq_(cq), method_(method) {} 657 658 // Attempts to send the next request. If there is no active streaming call, 659 // starts one and sends the request on top of it. `done` is invoked when 660 // `response` has been filled with the data from the server, or if there 661 // is an error. `done` can be invoked before SendNextRequest returns. SendNextRequest(const protobuf::Message & request,Response * response,StatusCallback done)662 void SendNextRequest(const protobuf::Message& request, Response* response, 663 StatusCallback done) { 664 mutex_lock l(mu_); 665 if (state_ == nullptr) { 666 CreateStreamingState(); 667 } 668 669 bool is_call_alive = state_->SendNextRequest(request, response, done); 670 if (is_call_alive) { 671 return; 672 } 673 674 // The attempt to send failed because the call was dead, create a new 675 // call and try again. When the call is dead SendNextRequest does not call 676 // `done`. 677 CreateStreamingState(); 678 679 is_call_alive = state_->SendNextRequest(request, response, done); 680 if (!is_call_alive) { 681 // Consider retrying to create and start a call few more times. 682 done(errors::Unknown("gRPC call failed right after it was created")); 683 } 684 } 685 686 // Request to cancel the current streaming call. Non-blocking. CancelCall()687 void CancelCall() { 688 mutex_lock l(mu_); 689 if (state_ == nullptr) { 690 return; 691 } 692 context_->TryCancel(); 693 state_ = nullptr; 694 } 695 696 private: CreateStreamingState()697 void CreateStreamingState() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { 698 // ClientContext cannot be reused across calls. 699 context_ = std::make_shared<::grpc::ClientContext>(); 700 // Don't immediately fail StartCall if the channel is not ready. Wait for 701 // the channel to become ready. 702 context_->set_wait_for_ready(true); 703 704 std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call = 705 stub_->PrepareCall(context_.get(), method_, cq_); 706 707 state_.reset(new StreamingRPCState<Response>(std::move(call), context_)); 708 } 709 710 mutable mutex mu_; 711 712 // Both are thread-safe 713 ::grpc::GenericStub* const stub_; 714 ::grpc::CompletionQueue* const cq_; 715 716 // Does not need synchronization since it is constant. 717 const ::grpc::string method_; 718 719 std::shared_ptr<::grpc::ClientContext> context_ TF_GUARDED_BY(mu_); 720 core::RefCountPtr<StreamingRPCState<Response>> state_ TF_GUARDED_BY(mu_); 721 }; 722 723 } // namespace tensorflow 724 725 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_ 726