1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_ 18 19 #include <queue> 20 #include <utility> 21 22 #include "grpcpp/generic/generic_stub.h" 23 #include "grpcpp/grpcpp.h" 24 #include "tensorflow/core/distributed_runtime/call_options.h" 25 #include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h" 26 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" 27 #include "tensorflow/core/distributed_runtime/tensor_coding.h" 28 #include "tensorflow/core/lib/core/refcount.h" 29 #include "tensorflow/core/lib/core/status.h" 30 #include "tensorflow/core/lib/core/threadpool.h" 31 #include "tensorflow/core/lib/strings/strcat.h" 32 #include "tensorflow/core/platform/mutex.h" 33 #include "tensorflow/core/platform/notification.h" 34 #include "tensorflow/core/util/env_var.h" 35 36 namespace tensorflow { 37 38 // Object allocated per active RPC. 39 // Manage the state of a single asynchronous RPC request. If `max_retries` 40 // is greater than 0, the request will be retried for any transient failures. 41 template <class Response> 42 class RPCState : public GrpcClientCQTag { 43 public: 44 RPCState(::grpc::GenericStub* stub, ::grpc::CompletionQueue* cq, 45 const ::grpc::string& method, const protobuf::Message& request, 46 Response* response, StatusCallback done, CallOptions* call_opts, 47 thread::ThreadPool* threadpool, int32_t max_retries = 0, 48 bool fail_fast = true, const string* target = nullptr) 49 : RPCState( 50 stub, cq, method, request, response, std::move(done), call_opts, 51 threadpool, 52 // 1) If GRPC_FAIL_FAST is set to 'true' or 'false', 53 // fail_fast=$GRPC_FAIL_FAST. See b/141948186. 54 // 2) Otherwise if GRPC_FAIL_FAST is set to 'use_caller', use the 55 // fail_fast from the caller. See b/140260119. 56 // 57 // Current default: use caller's fail_fast argument. 58 // 59 // NOTE: Callers mostly set fail_fast=true to prevent job hanging 60 // on worker task failures, except a few cases such as GetStatus 61 // in cluster initialization and collective param resolution. 62 [fail_fast, &done]() -> bool { 63 string fail_fast_env; 64 TF_CHECK_OK(ReadStringFromEnvVar("GRPC_FAIL_FAST", "use_caller", 65 &fail_fast_env)); 66 string fail_fast_env_lower = absl::AsciiStrToLower(fail_fast_env); 67 if (fail_fast_env_lower == "true") { 68 return true; 69 } else if (fail_fast_env_lower == "use_caller") { 70 return fail_fast; 71 } else if (fail_fast_env_lower == "false") { 72 return false; 73 } else { 74 string error_message = strings::StrCat( 75 "Invalid GRPC_FAIL_FAST config: ", fail_fast_env); 76 LOG(WARNING) << error_message; 77 done(errors::InvalidArgument(error_message)); 78 return false; 79 } 80 }(), 81 (call_opts != nullptr ? call_opts->GetTimeout() : 0), max_retries, 82 target) {} 83 84 template <typename Request> RPCState(::grpc::GenericStub * stub,::grpc::CompletionQueue * cq,const::grpc::string & method,const Request & request,Response * response,StatusCallback done,CallOptions * call_opts,thread::ThreadPool * threadpool,bool fail_fast,int64_t timeout_in_ms,int32_t max_retries,const string * target)85 RPCState(::grpc::GenericStub* stub, ::grpc::CompletionQueue* cq, 86 const ::grpc::string& method, const Request& request, 87 Response* response, StatusCallback done, CallOptions* call_opts, 88 thread::ThreadPool* threadpool, bool fail_fast, 89 int64_t timeout_in_ms, int32_t max_retries, const string* target) 90 : call_opts_(call_opts), 91 threadpool_(threadpool), 92 done_(std::move(done)), 93 timeout_in_ms_(timeout_in_ms), 94 max_retries_(max_retries), 95 cq_(cq), 96 stub_(stub), 97 method_(method), 98 fail_fast_(fail_fast), 99 target_(target) { 100 response_ = response; 101 ::grpc::Status s = GrpcMaybeUnparseProto(request, &request_buf_); 102 if (!s.ok()) { 103 LOG(ERROR) << "GrpcMaybeUnparseProto returned with non-ok status: " 104 << s.error_message(); 105 // Skip retry logic if we fail to parse our request. 106 done_(FromGrpcStatus(s)); 107 delete this; 108 return; 109 } 110 StartCall(); 111 } 112 StartCall()113 void StartCall() { 114 context_.reset(new ::grpc::ClientContext()); 115 context_->set_wait_for_ready(!fail_fast_); 116 if (timeout_in_ms_ > 0) { 117 context_->set_deadline( 118 gpr_time_from_millis(timeout_in_ms_, GPR_TIMESPAN)); 119 } 120 if (call_opts_) { 121 call_opts_->SetCancelCallback([this]() { context_->TryCancel(); }); 122 } 123 124 VLOG(2) << "Starting call: " << method_; 125 126 call_ = stub_->PrepareUnaryCall(context_.get(), method_, request_buf_, cq_); 127 call_->StartCall(); 128 call_->Finish(&response_buf_, &status_, this); 129 } 130 OnCompleted(bool ok)131 void OnCompleted(bool ok) override { 132 if (call_opts_) { 133 call_opts_->ClearCancelCallback(); 134 } 135 136 VLOG(2) << "Completed call: " << method_; 137 138 Status s = FromGrpcStatus(status_); 139 if (s.ok() && !ok) { 140 // Since this function is only being used for processing the response 141 // to Finish for client-side unary calls, ok should never be false 142 s.Update( 143 errors::Internal("GRPC status is okay but CompletionQueueStatus is " 144 "not. This should never happen.")); 145 } 146 147 if (s.ok()) { 148 if (threadpool_) { 149 // Run parse and callback in another thread, returning this 150 // one to service more RPCs. 151 threadpool_->Schedule([this]() { ParseAndCallDone(); }); 152 } else { 153 ParseAndCallDone(); 154 } 155 return; 156 } 157 158 VLOG(1) << method_ << " returned with non-ok status: " << s 159 << " Retries: " << num_retries_ << " Max: " << max_retries_ << "\n" 160 << context_->debug_error_string(); 161 // Retry if we have any attempts left 162 if (++num_retries_ <= max_retries_ && 163 (errors::IsUnavailable(s) || errors::IsUnknown(s))) { 164 response_buf_.Clear(); 165 VLOG(1) << "Retrying call for " << method_ << "Retry: " << num_retries_ 166 << " of " << max_retries_; 167 168 ComputeRetryBackoffMs(/*min_backoff_ms=*/1, /*max_backoff_ms=*/10000); 169 int64_t backoff_us = retry_backoff_ms_ * 1000; 170 Env::Default()->SchedClosureAfter(/*micros=*/backoff_us, 171 [this]() { StartCall(); }); 172 } else { 173 // Attach additional GRPC error information if any to the final status 174 string error_msg = s.error_message(); 175 strings::StrAppend(&error_msg, "\nAdditional GRPC error information"); 176 if (target_) { 177 strings::StrAppend(&error_msg, " from remote target ", *target_); 178 } 179 strings::StrAppend(&error_msg, ":\n:", context_->debug_error_string()); 180 s = Status(s.code(), error_msg); 181 // Always treat gRPC cancellation as a derived error. This ensures that 182 // other error types are preferred during status aggregation. (gRPC 183 // cancellation messages do not contain the original status message). 184 if (s.code() == tensorflow::error::Code::CANCELLED) { 185 s = StatusGroup::MakeDerived(s); 186 } 187 188 done_(s); 189 delete this; 190 } 191 } 192 ParseAndCallDone()193 void ParseAndCallDone() { 194 Status s; 195 if (!GrpcMaybeParseProto(&response_buf_, response_)) { 196 s.Update(errors::Internal("could not parse rpc response")); 197 } 198 done_(s); 199 delete this; 200 } 201 202 private: ComputeRetryBackoffMs(int min_backoff_ms,int max_backoff_ms)203 void ComputeRetryBackoffMs(int min_backoff_ms, int max_backoff_ms) { 204 constexpr float kBackoffBase = 1.3; 205 if (retry_backoff_ms_ < 0) { 206 retry_backoff_ms_ = min_backoff_ms; 207 } else { 208 retry_backoff_ms_ *= kBackoffBase; 209 if (retry_backoff_ms_ > max_backoff_ms) { 210 retry_backoff_ms_ = max_backoff_ms; 211 } 212 } 213 } 214 215 CallOptions* call_opts_; 216 std::unique_ptr<::grpc::ClientContext> context_; 217 thread::ThreadPool* threadpool_; 218 std::unique_ptr<::grpc::GenericClientAsyncResponseReader> call_; 219 Response* response_; 220 ::grpc::ByteBuffer request_buf_; 221 ::grpc::ByteBuffer response_buf_; 222 ::grpc::Status status_; 223 StatusCallback done_; 224 int64 timeout_in_ms_; 225 226 size_t num_retries_ = 0; 227 size_t max_retries_; 228 double retry_backoff_ms_ = -1; 229 230 ::grpc::CompletionQueue* cq_; 231 ::grpc::GenericStub* stub_; 232 ::grpc::string method_; 233 bool fail_fast_; 234 const string* target_; 235 }; 236 237 // Represents state associated with one streaming RPC call. 238 // Similarly to above, we extract the methods of StreamingRPCState that don't 239 // need to be templated into this abstract class. 240 // Currently, *StreamingRPCState does not support client closing the call as 241 // there is no use case for it - current clients keep the streaming call open 242 // as long as possible. If/when the need arises, support can be added 243 // by calling GenericClientAsyncReaderWriter::WritesDone with a new tag 244 // TagType::kClientFinished and handling the completion in a new callback. 245 class UntypedStreamingRPCState : public core::RefCounted { 246 public: 247 virtual void CallStarted(bool ok) = 0; 248 virtual void RequestWriteCompleted(bool ok) = 0; 249 virtual void ResponseReadCompleted(bool ok) = 0; 250 virtual void CallFinished(bool ok) = 0; 251 252 virtual string DebugString() const = 0; 253 254 class Tag : public GrpcClientCQTag { 255 public: 256 // One enum value per supported callback. 257 enum class TagType { 258 kCallStarted, 259 kRequestWriteCompleted, 260 kResponseReadCompleted, 261 kCallFinished, 262 }; 263 264 Tag(UntypedStreamingRPCState* streaming_state, Tag::TagType type); 265 266 // Calls the callback associated with this tag and Unrefs 267 // `this->streaming_state_`. 268 void OnCompleted(bool ok) override; 269 270 private: 271 // OnCompleted() consumes on reference each time it is called. 272 UntypedStreamingRPCState* const streaming_state_; 273 const Tag::TagType type_; 274 }; 275 }; 276 277 const char* ToString(UntypedStreamingRPCState::Tag::TagType tag_type); 278 279 // Represents a single request/response exchange between client and the server. 280 // A single streaming call contains a sequence of exchanges. Besides the 281 // messages, exchange contains: 282 // - the user callback to invoke when exchange completes (response is received 283 // or an error occurs). 284 // - The current state of the exchange. 285 class Exchange { 286 public: 287 enum class State { 288 kExchangeCreated, 289 kRequestWriteIssued, 290 kRequestWriteCompleted, 291 kResponseReadIssued, 292 }; 293 Exchange(const::grpc::ByteBuffer & request_buf,protobuf::Message * response,StatusCallback cb,string debug_string)294 Exchange(const ::grpc::ByteBuffer& request_buf, protobuf::Message* response, 295 StatusCallback cb, string debug_string) 296 : state_(State::kExchangeCreated), 297 request_buf_(request_buf), 298 response_(response), 299 cb_(std::move(cb)), 300 debug_string_(std::move(debug_string)) {} 301 request_buf()302 const ::grpc::ByteBuffer& request_buf() { return request_buf_; } response_buf()303 ::grpc::ByteBuffer* response_buf() { return &response_buf_; } 304 MarkRequestWriteIssued()305 void MarkRequestWriteIssued() { 306 DCHECK(state_ == State::kExchangeCreated); 307 state_ = State::kRequestWriteIssued; 308 } MarkRequestWriteCompleted()309 void MarkRequestWriteCompleted() { 310 DCHECK(state_ == State::kRequestWriteIssued); 311 state_ = State::kRequestWriteCompleted; 312 } MarkResponseReadIssued()313 void MarkResponseReadIssued() { 314 DCHECK(state_ == State::kRequestWriteCompleted); 315 state_ = State::kResponseReadIssued; 316 } 317 318 // If `status` is success, completes this exchange by parsing the 319 // response_buf_ and invoking cb_ with Status::OK(). Else, invokes the 320 // callback with `status`. 321 void Complete(Status status); 322 state()323 const State& state() const { return state_; } 324 325 string DebugString() const; 326 327 private: 328 State state_; 329 ::grpc::ByteBuffer request_buf_; 330 ::grpc::ByteBuffer response_buf_; 331 protobuf::Message* response_; 332 StatusCallback cb_; 333 string debug_string_; 334 }; 335 336 const char* ToString(Exchange::State s); 337 338 std::ostream& operator<<(std::ostream& os, const Exchange::State& state); 339 340 // Represents a queue of exchanges. 341 // When a client sends a new request a new exchange is created and added to the 342 // end of the queue. Completed exchanges are popped from the front of the queue. 343 // An explicit exchange queue is needed to brdige the client, which can send new 344 // requests at any time, with gRPC infrastructure, which can handle a single 345 // read and a single write request at a time. 346 // 347 // As the exchange progresses (request sending initiated, request sending 348 // completed, response reading initiated) the queue helps to make sure that the 349 // right operation is issued on the right exchange at the right time. 350 // 351 // To satisfy gRPC constraints, the states of exchanges must be as follows 352 // starting from the front of the queue: 353 // - 0 or 1 exchange in kResponseReadIssued state 354 // - 0 or more exchanges in kRequestWriteCompleted state 355 // - 0 or 1 exchange in kRequestWriteIssued state 356 // - 0 or more exchanges in kExchangeCreated state 357 // 358 // Thread-compatible. 359 class ExchangeQueue { 360 public: 361 // Creates a new exchange and adds it to the end of the queue. 362 void Emplace(const ::grpc::ByteBuffer& request_buf, 363 protobuf::Message* response, StatusCallback cb, 364 std::string debug_string); 365 366 // Returns an exchange for which we can initiate request writing, if any. 367 // Returns nullptr if there is no such exchange. 368 Exchange* GetReadyForRequestWriting(); 369 370 // Returns an exchange for which we can initiate response reading, if any. 371 // Returns nullptr if there is no such exchange. 372 Exchange* GetReadyForResponseReading(); 373 374 // Changes the state of the exchange that is current in kRequestWriteIssued 375 // state to kRequestWriteCompleted state. 376 // REQUIRES: There is an exchange in kRequestWriteIssued state. 377 void MarkRequestWriteCompleted(); 378 379 // Returns the exchange at the front of the queue. 380 // REQUIRES: ExchangeQueue is not empty. 381 Exchange& GetFront(); 382 383 // Removes the exchange at the front of the queue. 384 // REQUIRES: ExchangeQueue is not empty. 385 void PopFront(); 386 387 // Returns a string containing addresses and states of all exchanges in this 388 // queue. 389 string DebugString() const; 390 391 // Swaps the contents of this and `other`. 392 void Swap(ExchangeQueue* other); 393 394 // Completes all exchanges in this with `status`. 395 void CompleteAll(Status status); 396 CallStarted()397 void CallStarted() { call_started_ = true; } 398 399 private: 400 // Does nothing by default. Turn on VLOG(5) to enable. 401 // Checks that this ExchangeQueue is in a valid state. 402 // Kills the process if not. 403 void CheckInvariants(); 404 405 // We can't process any exchanges until the call has started. 406 bool call_started_ = false; 407 408 // std::queue is based on std::deque by default. std::deque provides 409 // fairly strong iterator stability. 410 std::deque<Exchange> exchanges_; 411 }; // namespace tensorflow 412 413 // Represents state associated with one streaming RPC call. 414 // Thread-safe 415 template <class Response> 416 class StreamingRPCState : public UntypedStreamingRPCState { 417 public: 418 // Default behavior is to set fail_fast = False and handle timeouts 419 // manually. StreamingRPCState(std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call,const std::shared_ptr<::grpc::ClientContext> & context)420 StreamingRPCState(std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call, 421 const std::shared_ptr<::grpc::ClientContext>& context) 422 : context_(context), call_(std::move(call)), call_state_(State::kActive) { 423 Ref(); 424 VLOG(3) << "Created new StreamingRPCState " << this; 425 VLOG(3) << "StreamingRPCState(" << this << ") calling grpc::StartCall"; 426 call_->StartCall(&call_started_tag_); 427 } 428 ~StreamingRPCState()429 ~StreamingRPCState() override { 430 VLOG(3) << "Destructing StreamingRPCState " << this; 431 } 432 433 // Attempts to send the next request. `done` is invoked when 434 // `response` has been filled with the data from the server, or if there 435 // is an error. `done` can be invoked before SendNextRequest returns. 436 // Return `true` if the call is alive and the `done` callback has or 437 // will be invoked. If the call is dead, returns `false`. `done` callback 438 // will not be invoked in this case. 439 // REQUIRES: The call has been started, i.e. WaitForCallStarted() has 440 // returned. SendNextRequest(const protobuf::Message & request,Response * response,const StatusCallback & done)441 bool SendNextRequest(const protobuf::Message& request, Response* response, 442 const StatusCallback& done) { 443 ::grpc::ByteBuffer request_buf; 444 ::grpc::Status s = GrpcMaybeUnparseProto(request, &request_buf); 445 if (!s.ok()) { 446 Status status = FromGrpcStatus(s); 447 LOG(ERROR) << "GrpcMaybeUnparseProto returned with non-ok status: " 448 << status.ToString(); 449 done(status); 450 return true; 451 } 452 453 mutex_lock l(mu_); 454 if (call_state_ != State::kActive) { 455 // `done` is not invoked intentionally. 456 return false; 457 } 458 if (VLOG_IS_ON(3)) { 459 // If vlog 3 is enabled, include first 100 chars of request as debug 460 // string. 461 exchanges_.Emplace(request_buf, response, done, 462 request.ShortDebugString().substr(0, 100)); 463 } else { 464 exchanges_.Emplace(request_buf, response, done, ""); 465 } 466 MaybeIssueRequestWriteLocked(); 467 return true; 468 } 469 CallStarted(bool ok)470 void CallStarted(bool ok) override { 471 VLOG(3) << "StreamingRPCState(" << this << ")::CallStarted(ok=" << ok 472 << ")"; 473 mutex_lock l(mu_); 474 if (!ok) { 475 call_state_ = State::kDone; 476 return; 477 } 478 exchanges_.CallStarted(); 479 // Now that the call has started, we can write our first request, if any. 480 MaybeIssueRequestWriteLocked(); 481 } 482 RequestWriteCompleted(bool ok)483 void RequestWriteCompleted(bool ok) override { 484 VLOG(3) << "StreamingRPCState(" << this 485 << ")::RequestWriteCompleted(ok=" << ok << ")"; 486 mu_.lock(); 487 if (call_state_ != State::kActive) { 488 mu_.unlock(); 489 return; 490 } 491 exchanges_.MarkRequestWriteCompleted(); 492 // Issue ResponseRead regardless of OK status on completing RequestWrite. 493 // If the underlying completion queue is in Not-OK status due to previous 494 // request failuress (i.e., `ok` from `Next` call on completion queue is 495 // False), delay the error in ResponseRead so we can get the remote error 496 // message from response buffer. 497 MaybeIssueResponseReadLocked(); 498 499 if (ok) { 500 MaybeIssueRequestWriteLocked(); 501 } 502 mu_.unlock(); 503 } 504 ResponseReadCompleted(bool ok)505 void ResponseReadCompleted(bool ok) override { 506 VLOG(3) << "StreamingRPCState(" << this 507 << ")::ResponseReadCompleted(ok=" << ok << ")"; 508 mu_.lock(); 509 if (call_state_ != State::kActive) { 510 mu_.unlock(); 511 return; 512 } 513 if (!ok) { 514 IssueCallFinishLocked(); 515 mu_.unlock(); 516 return; 517 } 518 519 // Complete the exchange without holding the lock because user's 520 // callback can call back into this RPC code resulting in a deadlock. 521 // No other thread can pop this exchange while we release the lock because 522 // this is the only method that pops exchanges and it is called from a 523 // single thread that waits on completion queue events. 524 Exchange* e; 525 e = &exchanges_.GetFront(); 526 mu_.unlock(); 527 528 e->Complete(Status::OK()); 529 530 { 531 mutex_lock l(mu_); 532 exchanges_.PopFront(); 533 MaybeIssueResponseReadLocked(); 534 } 535 } 536 CallFinished(bool ok)537 void CallFinished(bool ok) override { 538 VLOG(3) << "StreamingRPCState(" << this << ")::CallFinished(ok=" << ok 539 << ")"; 540 mu_.lock(); 541 DCHECK(call_state_ != State::kActive); 542 if (call_state_ != State::kFinishing) { 543 mu_.unlock(); 544 return; 545 } 546 547 Status s = FromGrpcStatus(call_status_); 548 if (s.ok() && !ok) { 549 s.Update( 550 errors::Internal("GRPC status is okay but CompletionQueueStatus is " 551 "not. This should never happen.", 552 context_->debug_error_string())); 553 } 554 // unlocks mu_ 555 MarkDoneAndCompleteExchanges(s); 556 } 557 DebugString()558 string DebugString() const override { 559 mutex_lock l(mu_); 560 return exchanges_.DebugString(); 561 } 562 563 private: 564 enum class State { 565 kActive, 566 kFinishing, 567 kDone, 568 }; 569 MarkDoneAndCompleteExchanges(Status status)570 void MarkDoneAndCompleteExchanges(Status status) 571 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) TF_UNLOCK_FUNCTION(mu_) { 572 call_state_ = State::kDone; 573 VLOG(2) << "Ending gRPC streaming call on the client side due to " 574 << status.ToString(); 575 // Swap the exchanges_ into a temporary ExchangeQueue so that we can 576 // complete all exchanges without holding mu_ in case user callback 577 // reach back into this. This should be impossible now, but safer for 578 // the future. 579 ExchangeQueue queue; 580 exchanges_.Swap(&queue); 581 mu_.unlock(); 582 queue.CompleteAll(status); 583 } 584 MaybeIssueRequestWriteLocked()585 void MaybeIssueRequestWriteLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { 586 Exchange* exchange = exchanges_.GetReadyForRequestWriting(); 587 if (exchange == nullptr) { 588 // There are no queued exchanges, there is already an outstanding write, 589 // or there are no just created exchanges. 590 return; 591 } 592 exchange->MarkRequestWriteIssued(); 593 Ref(); 594 VLOG(3) << "StreamingRPCState(" << this << ") calling grpc::Write"; 595 call_->Write(exchange->request_buf(), &request_write_completed_tag_); 596 } 597 MaybeIssueResponseReadLocked()598 void MaybeIssueResponseReadLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { 599 Exchange* exchange = exchanges_.GetReadyForResponseReading(); 600 if (exchange == nullptr) { 601 return; 602 } 603 exchange->MarkResponseReadIssued(); 604 Ref(); 605 VLOG(3) << "StreamingRPCState(" << this << ") calling grpc::Read"; 606 call_->Read(exchange->response_buf(), &response_read_completed_tag_); 607 } 608 IssueCallFinishLocked()609 void IssueCallFinishLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { 610 call_state_ = State::kFinishing; 611 Ref(); 612 VLOG(3) << "StreamingRPCState(" << this << ") calling grpc::Finish"; 613 // We call finish in response to completed (with error) response reading tag 614 // on some exchange. We let this exchange hang in ResponseReadIssued state. 615 // ExchangeQueue makes sure that there is at most one exchange in this 616 // state. So, no new reads will be issued. 617 call_->Finish(&call_status_, &finished_tag_); 618 } 619 620 // Holds state for a single request/response exchange between the client 621 // and the server. 622 typedef typename UntypedStreamingRPCState::Tag Tag; 623 624 // Order of context_ and call_ is important because context_ must outlive 625 // call_. 626 const std::shared_ptr<const ::grpc::ClientContext> context_; 627 std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call_; 628 629 mutable mutex mu_; 630 ExchangeQueue exchanges_ TF_GUARDED_BY(mu_); 631 State call_state_ TF_GUARDED_BY(mu_); 632 ::grpc::Status call_status_ TF_GUARDED_BY(mu_); 633 634 // We can get away with having single instances of these tags per 635 // StreamingRPCState because we make sure (as gRPC requires) that 636 // there is at most one outstanding Read and at most one outstanding Write 637 // in the completion queue. 638 // Tags are immutable. No need to guard them. 639 Tag call_started_tag_{this, Tag::TagType::kCallStarted}; 640 Tag request_write_completed_tag_{this, Tag::TagType::kRequestWriteCompleted}; 641 Tag response_read_completed_tag_{this, Tag::TagType::kResponseReadCompleted}; 642 Tag finished_tag_{this, Tag::TagType::kCallFinished}; 643 }; 644 645 // Creates streaming calls and dispatches requests to them. 646 // In the common case, the client would create a StreamingRPCDispatcher for 647 // each bidirectional streaming RPC it might want to make. The first time, it 648 // calls SendNextRequest, a streaming call is initiated and the request is 649 // sent within this call. Initiation of the call blocks the client. If there are 650 // no errors, subsequent calls to SendNextRequest would use the already active 651 // call. If there was an error, the call object will be destroyed after all 652 // the callbacks for outstanding requests have been invoked. The next call to 653 // SendNextRequest will initiate a new call. 654 // 655 // Callbacks that are part of the same call, are invoked in the order they were 656 // provided, but callbacks across calls (a failed and a new one) can be invoked 657 // in any order. 658 // 659 // Thread-safe. 660 template <class Response> 661 class StreamingRPCDispatcher { 662 public: StreamingRPCDispatcher(::grpc::GenericStub * stub,::grpc::CompletionQueue * cq,const::grpc::string & method)663 StreamingRPCDispatcher(::grpc::GenericStub* stub, ::grpc::CompletionQueue* cq, 664 const ::grpc::string& method) 665 : stub_(stub), cq_(cq), method_(method) {} 666 667 // Attempts to send the next request. If there is no active streaming call, 668 // starts one and sends the request on top of it. `done` is invoked when 669 // `response` has been filled with the data from the server, or if there 670 // is an error. `done` can be invoked before SendNextRequest returns. SendNextRequest(const protobuf::Message & request,Response * response,StatusCallback done)671 void SendNextRequest(const protobuf::Message& request, Response* response, 672 StatusCallback done) { 673 mutex_lock l(mu_); 674 if (state_ == nullptr) { 675 CreateStreamingState(); 676 } 677 678 bool is_call_alive = state_->SendNextRequest(request, response, done); 679 if (is_call_alive) { 680 return; 681 } 682 683 // The attempt to send failed because the call was dead, create a new 684 // call and try again. When the call is dead SendNextRequest does not call 685 // `done`. 686 CreateStreamingState(); 687 688 is_call_alive = state_->SendNextRequest(request, response, done); 689 if (!is_call_alive) { 690 // Consider retrying to create and start a call few more times. 691 done(errors::Unknown("gRPC call failed right after it was created")); 692 } 693 } 694 695 // Request to cancel the current streaming call. Non-blocking. CancelCall()696 void CancelCall() { 697 mutex_lock l(mu_); 698 if (state_ == nullptr) { 699 return; 700 } 701 context_->TryCancel(); 702 state_ = nullptr; 703 } 704 705 private: CreateStreamingState()706 void CreateStreamingState() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { 707 // ClientContext cannot be reused across calls. 708 context_ = std::make_shared<::grpc::ClientContext>(); 709 // Don't immediately fail StartCall if the channel is not ready. Wait for 710 // the channel to become ready. 711 context_->set_wait_for_ready(true); 712 713 std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call = 714 stub_->PrepareCall(context_.get(), method_, cq_); 715 716 state_.reset(new StreamingRPCState<Response>(std::move(call), context_)); 717 } 718 719 mutable mutex mu_; 720 721 // Both are thread-safe 722 ::grpc::GenericStub* const stub_; 723 ::grpc::CompletionQueue* const cq_; 724 725 // Does not need synchronization since it is constant. 726 const ::grpc::string method_; 727 728 std::shared_ptr<::grpc::ClientContext> context_ TF_GUARDED_BY(mu_); 729 core::RefCountPtr<StreamingRPCState<Response>> state_ TF_GUARDED_BY(mu_); 730 }; 731 732 } // namespace tensorflow 733 734 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_ 735