/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_

#include <queue>
#include <utility>

#include "grpcpp/generic/generic_stub.h"
#include "grpcpp/grpcpp.h"
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/distributed_runtime/tensor_coding.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/util/env_var.h"

namespace tensorflow {

// Object allocated per active RPC.
// Manages the state of a single asynchronous RPC request. If `max_retries`
// is greater than 0, the request will be retried for any transient failures.
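//
// A hypothetical usage sketch (the method string and `MyResponse` are
// assumptions for illustration, not part of this header). The object owns
// itself and is deleted after `done` has been invoked:
//
//   auto* rpc = new RPCState<MyResponse>(
//       stub, cq, "/my.Service/MyMethod", request, &response,
//       std::move(done), /*call_opts=*/nullptr, /*threadpool=*/nullptr,
//       /*max_retries=*/3);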
template <class Response>
class RPCState : public GrpcClientCQTag {
 public:
  RPCState(::grpc::GenericStub* stub, ::grpc::CompletionQueue* cq,
           const ::grpc::string& method, const protobuf::Message& request,
           Response* response, StatusCallback done, CallOptions* call_opts,
           thread::ThreadPool* threadpool, int32_t max_retries = 0,
           bool fail_fast = true, const string* target = nullptr)
      : RPCState(
            stub, cq, method, request, response, std::move(done), call_opts,
            threadpool,
            // 1) If GRPC_FAIL_FAST is set to 'true' or 'false',
            // fail_fast=$GRPC_FAIL_FAST. See b/141948186.
            // 2) Otherwise if GRPC_FAIL_FAST is set to 'use_caller', use the
            // fail_fast from the caller. See b/140260119.
            //
            // Current default: use caller's fail_fast argument.
            //
            // NOTE: Callers mostly set fail_fast=true to prevent job hanging
            // on worker task failures, except a few cases such as GetStatus
            // in cluster initialization and collective param resolution.
            [fail_fast, &done]() -> bool {
              string fail_fast_env;
              TF_CHECK_OK(ReadStringFromEnvVar("GRPC_FAIL_FAST", "use_caller",
                                               &fail_fast_env));
              string fail_fast_env_lower = absl::AsciiStrToLower(fail_fast_env);
              if (fail_fast_env_lower == "true") {
                return true;
              } else if (fail_fast_env_lower == "use_caller") {
                return fail_fast;
              } else if (fail_fast_env_lower == "false") {
                return false;
              } else {
                string error_message = strings::StrCat(
                    "Invalid GRPC_FAIL_FAST config: ", fail_fast_env);
                LOG(WARNING) << error_message;
                done(errors::InvalidArgument(error_message));
                return false;
              }
            }(),
            (call_opts != nullptr ? call_opts->GetTimeout() : 0), max_retries,
            target) {}
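
  // In summary, the effective fail_fast value resolves as follows (a sketch
  // of the lambda above, not an exhaustive spec):
  //   GRPC_FAIL_FAST=true        -> true
  //   GRPC_FAIL_FAST=false       -> false
  //   GRPC_FAIL_FAST=use_caller  -> caller's fail_fast argument (the default)
  //   anything else              -> false, and `done` is invoked with
  //                                 InvalidArgument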

  template <typename Request>
  RPCState(::grpc::GenericStub* stub, ::grpc::CompletionQueue* cq,
           const ::grpc::string& method, const Request& request,
           Response* response, StatusCallback done, CallOptions* call_opts,
           thread::ThreadPool* threadpool, bool fail_fast,
           int64_t timeout_in_ms, int32_t max_retries, const string* target)
      : call_opts_(call_opts),
        threadpool_(threadpool),
        done_(std::move(done)),
        timeout_in_ms_(timeout_in_ms),
        max_retries_(max_retries),
        cq_(cq),
        stub_(stub),
        method_(method),
        fail_fast_(fail_fast),
        target_(target) {
    response_ = response;
    ::grpc::Status s = GrpcMaybeUnparseProto(request, &request_buf_);
    if (!s.ok()) {
      LOG(ERROR) << "GrpcMaybeUnparseProto returned with non-ok status: "
                 << s.error_message();
      // Skip retry logic if we fail to parse our request.
      done_(FromGrpcStatus(s));
      delete this;
      return;
    }
    StartCall();
  }

  void StartCall() {
    context_.reset(new ::grpc::ClientContext());
    context_->set_wait_for_ready(!fail_fast_);
    if (timeout_in_ms_ > 0) {
      context_->set_deadline(
          gpr_time_from_millis(timeout_in_ms_, GPR_TIMESPAN));
    }
    if (call_opts_) {
      call_opts_->SetCancelCallback([this]() { context_->TryCancel(); });
    }

    VLOG(2) << "Starting call: " << method_;

    call_ = stub_->PrepareUnaryCall(context_.get(), method_, request_buf_, cq_);
    call_->StartCall();
    call_->Finish(&response_buf_, &status_, this);
  }

  void OnCompleted(bool ok) override {
    if (call_opts_) {
      call_opts_->ClearCancelCallback();
    }

    VLOG(2) << "Completed call: " << method_;

    Status s = FromGrpcStatus(status_);
    if (s.ok() && !ok) {
      // Since this function is only being used for processing the response
      // to Finish for client-side unary calls, ok should never be false
      s.Update(
          errors::Internal("GRPC status is okay but CompletionQueueStatus is "
                           "not.  This should never happen."));
    }

    if (s.ok()) {
      if (threadpool_) {
        // Run parse and callback in another thread, returning this
        // one to service more RPCs.
        threadpool_->Schedule([this]() { ParseAndCallDone(); });
      } else {
        ParseAndCallDone();
      }
      return;
    }

    VLOG(1) << method_ << " returned with non-ok status: " << s
            << " Retries: " << num_retries_ << " Max: " << max_retries_ << "\n"
            << context_->debug_error_string();
    // Retry if we have any attempts left
    if (++num_retries_ <= max_retries_ &&
        (errors::IsUnavailable(s) || errors::IsUnknown(s))) {
      response_buf_.Clear();
      VLOG(1) << "Retrying call for " << method_ << ". Retry: " << num_retries_
              << " of " << max_retries_;

      ComputeRetryBackoffMs(/*min_backoff_ms=*/1, /*max_backoff_ms=*/10000);
      int64_t backoff_us = retry_backoff_ms_ * 1000;
      Env::Default()->SchedClosureAfter(/*micros=*/backoff_us,
                                        [this]() { StartCall(); });
    } else {
      // Attach additional GRPC error information, if any, to the final status.
      string error_msg = s.error_message();
      strings::StrAppend(&error_msg, "\nAdditional GRPC error information");
      if (target_) {
        strings::StrAppend(&error_msg, " from remote target ", *target_);
      }
      strings::StrAppend(&error_msg, ":\n:", context_->debug_error_string());
      s = Status(s.code(), error_msg);
      // Always treat gRPC cancellation as a derived error. This ensures that
      // other error types are preferred during status aggregation. (gRPC
      // cancellation messages do not contain the original status message).
      if (s.code() == tensorflow::error::Code::CANCELLED) {
        s = StatusGroup::MakeDerived(s);
      }

      done_(s);
      delete this;
    }
  }

  void ParseAndCallDone() {
    Status s;
    if (!GrpcMaybeParseProto(&response_buf_, response_)) {
      s.Update(errors::Internal("could not parse rpc response"));
    }
    done_(s);
    delete this;
  }

 private:
  void ComputeRetryBackoffMs(int min_backoff_ms, int max_backoff_ms) {
    constexpr float kBackoffBase = 1.3;
    if (retry_backoff_ms_ < 0) {
      retry_backoff_ms_ = min_backoff_ms;
    } else {
      retry_backoff_ms_ *= kBackoffBase;
      if (retry_backoff_ms_ > max_backoff_ms) {
        retry_backoff_ms_ = max_backoff_ms;
      }
    }
  }
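
  // With min_backoff_ms=1 and max_backoff_ms=10000 (the values passed in
  // OnCompleted above), successive retries back off as roughly
  // 1, 1.3, 1.69, 2.2, ... ms, growing by a factor of 1.3 per attempt and
  // capped at 10 seconds.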

  CallOptions* call_opts_;
  std::unique_ptr<::grpc::ClientContext> context_;
  thread::ThreadPool* threadpool_;
  std::unique_ptr<::grpc::GenericClientAsyncResponseReader> call_;
  Response* response_;
  ::grpc::ByteBuffer request_buf_;
  ::grpc::ByteBuffer response_buf_;
  ::grpc::Status status_;
  StatusCallback done_;
  int64 timeout_in_ms_;

  size_t num_retries_ = 0;
  size_t max_retries_;
  double retry_backoff_ms_ = -1;

  ::grpc::CompletionQueue* cq_;
  ::grpc::GenericStub* stub_;
  ::grpc::string method_;
  bool fail_fast_;
  const string* target_;
};

// Represents state associated with one streaming RPC call.
// Similarly to above, we extract the methods of StreamingRPCState that don't
// need to be templated into this abstract class.
// Currently, *StreamingRPCState does not support the client closing the call,
// as there is no use case for it - current clients keep the streaming call
// open as long as possible. If/when the need arises, support can be added
// by calling GenericClientAsyncReaderWriter::WritesDone with a new tag
// TagType::kClientFinished and handling the completion in a new callback.
class UntypedStreamingRPCState : public core::RefCounted {
 public:
  virtual void CallStarted(bool ok) = 0;
  virtual void RequestWriteCompleted(bool ok) = 0;
  virtual void ResponseReadCompleted(bool ok) = 0;
  virtual void CallFinished(bool ok) = 0;

  virtual string DebugString() const = 0;

  class Tag : public GrpcClientCQTag {
   public:
    // One enum value per supported callback.
    enum class TagType {
      kCallStarted,
      kRequestWriteCompleted,
      kResponseReadCompleted,
      kCallFinished,
    };

    Tag(UntypedStreamingRPCState* streaming_state, Tag::TagType type);

    // Calls the callback associated with this tag and Unrefs
    // `this->streaming_state_`.
    void OnCompleted(bool ok) override;

   private:
    // OnCompleted() consumes one reference each time it is called.
    UntypedStreamingRPCState* const streaming_state_;
    const Tag::TagType type_;
  };
};

const char* ToString(UntypedStreamingRPCState::Tag::TagType tag_type);

// Represents a single request/response exchange between the client and the
// server. A single streaming call contains a sequence of exchanges. Besides
// the messages, an exchange contains:
//  - the user callback to invoke when the exchange completes (the response
//    is received or an error occurs).
//  - the current state of the exchange.
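//
// A typical exchange moves through the states below in order:
//   kExchangeCreated -> kRequestWriteIssued -> kRequestWriteCompleted
//     -> kResponseReadIssued -> popped from the queue once Complete() runs.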
class Exchange {
 public:
  enum class State {
    kExchangeCreated,
    kRequestWriteIssued,
    kRequestWriteCompleted,
    kResponseReadIssued,
  };

  Exchange(const ::grpc::ByteBuffer& request_buf, protobuf::Message* response,
           StatusCallback cb, string debug_string)
      : state_(State::kExchangeCreated),
        request_buf_(request_buf),
        response_(response),
        cb_(std::move(cb)),
        debug_string_(std::move(debug_string)) {}

  const ::grpc::ByteBuffer& request_buf() { return request_buf_; }
  ::grpc::ByteBuffer* response_buf() { return &response_buf_; }

  void MarkRequestWriteIssued() {
    DCHECK(state_ == State::kExchangeCreated);
    state_ = State::kRequestWriteIssued;
  }
  void MarkRequestWriteCompleted() {
    DCHECK(state_ == State::kRequestWriteIssued);
    state_ = State::kRequestWriteCompleted;
  }
  void MarkResponseReadIssued() {
    DCHECK(state_ == State::kRequestWriteCompleted);
    state_ = State::kResponseReadIssued;
  }

  // If `status` is success, completes this exchange by parsing the
  // response_buf_ and invoking cb_ with Status::OK(). Else, invokes the
  // callback with `status`.
  void Complete(Status status);

  const State& state() const { return state_; }

  string DebugString() const;

 private:
  State state_;
  ::grpc::ByteBuffer request_buf_;
  ::grpc::ByteBuffer response_buf_;
  protobuf::Message* response_;
  StatusCallback cb_;
  string debug_string_;
};

const char* ToString(Exchange::State s);

std::ostream& operator<<(std::ostream& os, const Exchange::State& state);

// Represents a queue of exchanges.
// When a client sends a new request, a new exchange is created and added to
// the end of the queue. Completed exchanges are popped from the front of the
// queue. An explicit exchange queue is needed to bridge the client, which can
// send new requests at any time, with the gRPC infrastructure, which can
// handle a single read and a single write request at a time.
//
// As the exchange progresses (request sending initiated, request sending
// completed, response reading initiated) the queue helps to make sure that the
// right operation is issued on the right exchange at the right time.
//
// To satisfy gRPC constraints, the states of exchanges must be as follows
// starting from the front of the queue:
//  - 0 or 1 exchange in kResponseReadIssued state
//  - 0 or more exchanges in kRequestWriteCompleted state
//  - 0 or 1 exchange in kRequestWriteIssued state
//  - 0 or more exchanges in kExchangeCreated state
//
// Thread-compatible.
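//
// For example, a valid queue snapshot (front to back) might look like:
//   [kResponseReadIssued, kRequestWriteCompleted, kRequestWriteCompleted,
//    kRequestWriteIssued, kExchangeCreated, kExchangeCreated]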
class ExchangeQueue {
 public:
  // Creates a new exchange and adds it to the end of the queue.
  void Emplace(const ::grpc::ByteBuffer& request_buf,
               protobuf::Message* response, StatusCallback cb,
               std::string debug_string);

  // Returns an exchange for which we can initiate request writing, if any.
  // Returns nullptr if there is no such exchange.
  Exchange* GetReadyForRequestWriting();

  // Returns an exchange for which we can initiate response reading, if any.
  // Returns nullptr if there is no such exchange.
  Exchange* GetReadyForResponseReading();

  // Changes the state of the exchange that is currently in
  // kRequestWriteIssued state to kRequestWriteCompleted state.
  // REQUIRES: There is an exchange in kRequestWriteIssued state.
  void MarkRequestWriteCompleted();

  // Returns the exchange at the front of the queue.
  // REQUIRES: ExchangeQueue is not empty.
  Exchange& GetFront();

  // Removes the exchange at the front of the queue.
  // REQUIRES: ExchangeQueue is not empty.
  void PopFront();

  // Returns a string containing addresses and states of all exchanges in this
  // queue.
  string DebugString() const;

  // Swaps the contents of this and `other`.
  void Swap(ExchangeQueue* other);

  // Completes all exchanges in this queue with `status`.
  void CompleteAll(Status status);

  void CallStarted() { call_started_ = true; }

 private:
  // Checks that this ExchangeQueue is in a valid state.
  // Kills the process if not.
  // Does nothing by default; turn on VLOG(5) to enable the checks.
  void CheckInvariants();

  // We can't process any exchanges until the call has started.
  bool call_started_ = false;

  // std::queue is based on std::deque by default. std::deque provides
  // fairly strong iterator stability.
  std::deque<Exchange> exchanges_;
};

// Represents state associated with one streaming RPC call.
// Thread-safe.
template <class Response>
class StreamingRPCState : public UntypedStreamingRPCState {
 public:
  // Default behavior is to set fail_fast = false and handle timeouts
  // manually.
  StreamingRPCState(std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call,
                    const std::shared_ptr<::grpc::ClientContext>& context)
      : context_(context), call_(std::move(call)), call_state_(State::kActive) {
    Ref();
    VLOG(3) << "Created new StreamingRPCState " << this;
    VLOG(3) << "StreamingRPCState(" << this << ") calling grpc::StartCall";
    call_->StartCall(&call_started_tag_);
  }

  ~StreamingRPCState() override {
    VLOG(3) << "Destructing StreamingRPCState " << this;
  }

  // Attempts to send the next request. `done` is invoked when
  // `response` has been filled with the data from the server, or if there
  // is an error. `done` can be invoked before SendNextRequest returns.
  // Returns `true` if the call is alive and the `done` callback has been or
  // will be invoked. If the call is dead, returns `false`, and the `done`
  // callback will not be invoked.
  // REQUIRES: The call has been started, i.e. WaitForCallStarted() has
  // returned.
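  //
  // A hypothetical caller-side sketch of this contract:
  //
  //   if (!state->SendNextRequest(request, &response, done)) {
  //     // The call is dead and `done` was not invoked; start a new call
  //     // before retrying.
  //   }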
  bool SendNextRequest(const protobuf::Message& request, Response* response,
                       const StatusCallback& done) {
    ::grpc::ByteBuffer request_buf;
    ::grpc::Status s = GrpcMaybeUnparseProto(request, &request_buf);
    if (!s.ok()) {
      Status status = FromGrpcStatus(s);
      LOG(ERROR) << "GrpcMaybeUnparseProto returned with non-ok status: "
                 << status.ToString();
      done(status);
      return true;
    }

    mutex_lock l(mu_);
    if (call_state_ != State::kActive) {
      // `done` is not invoked intentionally.
      return false;
    }
    if (VLOG_IS_ON(3)) {
      // If vlog 3 is enabled, include first 100 chars of request as debug
      // string.
      exchanges_.Emplace(request_buf, response, done,
                         request.ShortDebugString().substr(0, 100));
    } else {
      exchanges_.Emplace(request_buf, response, done, "");
    }
    MaybeIssueRequestWriteLocked();
    return true;
  }

  void CallStarted(bool ok) override {
    VLOG(3) << "StreamingRPCState(" << this << ")::CallStarted(ok=" << ok
            << ")";
    mutex_lock l(mu_);
    if (!ok) {
      call_state_ = State::kDone;
      return;
    }
    exchanges_.CallStarted();
    // Now that the call has started, we can write our first request, if any.
    MaybeIssueRequestWriteLocked();
  }

  void RequestWriteCompleted(bool ok) override {
    VLOG(3) << "StreamingRPCState(" << this
            << ")::RequestWriteCompleted(ok=" << ok << ")";
    mu_.lock();
    if (call_state_ != State::kActive) {
      mu_.unlock();
      return;
    }
    exchanges_.MarkRequestWriteCompleted();
    // Issue ResponseRead regardless of OK status on completing RequestWrite.
    // If the underlying completion queue is in not-OK status due to previous
    // request failures (i.e., `ok` from the `Next` call on the completion
    // queue is false), delay the error to ResponseRead so we can get the
    // remote error message from the response buffer.
    MaybeIssueResponseReadLocked();

    if (ok) {
      MaybeIssueRequestWriteLocked();
    }
    mu_.unlock();
  }

  void ResponseReadCompleted(bool ok) override {
    VLOG(3) << "StreamingRPCState(" << this
            << ")::ResponseReadCompleted(ok=" << ok << ")";
    mu_.lock();
    if (call_state_ != State::kActive) {
      mu_.unlock();
      return;
    }
    if (!ok) {
      IssueCallFinishLocked();
      mu_.unlock();
      return;
    }

    // Complete the exchange without holding the lock because user's
    // callback can call back into this RPC code resulting in a deadlock.
    // No other thread can pop this exchange while we release the lock because
    // this is the only method that pops exchanges and it is called from a
    // single thread that waits on completion queue events.
    Exchange* e;
    e = &exchanges_.GetFront();
    mu_.unlock();

    e->Complete(Status::OK());

    {
      mutex_lock l(mu_);
      exchanges_.PopFront();
      MaybeIssueResponseReadLocked();
    }
  }

  void CallFinished(bool ok) override {
    VLOG(3) << "StreamingRPCState(" << this << ")::CallFinished(ok=" << ok
            << ")";
    mu_.lock();
    DCHECK(call_state_ != State::kActive);
    if (call_state_ != State::kFinishing) {
      mu_.unlock();
      return;
    }

    Status s = FromGrpcStatus(call_status_);
    if (s.ok() && !ok) {
      s.Update(
          errors::Internal("GRPC status is okay but CompletionQueueStatus is "
                           "not.  This should never happen.",
                           context_->debug_error_string()));
    }
    // unlocks mu_
    MarkDoneAndCompleteExchanges(s);
  }

  string DebugString() const override {
    mutex_lock l(mu_);
    return exchanges_.DebugString();
  }

 private:
  enum class State {
    kActive,
    kFinishing,
    kDone,
  };

  void MarkDoneAndCompleteExchanges(Status status)
      TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) TF_UNLOCK_FUNCTION(mu_) {
    call_state_ = State::kDone;
    VLOG(2) << "Ending gRPC streaming call on the client side due to "
            << status.ToString();
    // Swap the exchanges_ into a temporary ExchangeQueue so that we can
    // complete all exchanges without holding mu_, in case a user callback
    // reaches back into this class. This should be impossible now, but it is
    // safer for the future.
    ExchangeQueue queue;
    exchanges_.Swap(&queue);
    mu_.unlock();
    queue.CompleteAll(status);
  }

  void MaybeIssueRequestWriteLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    Exchange* exchange = exchanges_.GetReadyForRequestWriting();
    if (exchange == nullptr) {
      // The queue is empty, a write is already outstanding, or no exchange
      // is in the kExchangeCreated state.
      return;
    }
    exchange->MarkRequestWriteIssued();
    Ref();
    VLOG(3) << "StreamingRPCState(" << this << ") calling grpc::Write";
    call_->Write(exchange->request_buf(), &request_write_completed_tag_);
  }

  void MaybeIssueResponseReadLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    Exchange* exchange = exchanges_.GetReadyForResponseReading();
    if (exchange == nullptr) {
      return;
    }
    exchange->MarkResponseReadIssued();
    Ref();
    VLOG(3) << "StreamingRPCState(" << this << ") calling grpc::Read";
    call_->Read(exchange->response_buf(), &response_read_completed_tag_);
  }

  void IssueCallFinishLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    call_state_ = State::kFinishing;
    Ref();
    VLOG(3) << "StreamingRPCState(" << this << ") calling grpc::Finish";
    // We call Finish in response to a response-reading tag that completed
    // with an error on some exchange. We let that exchange hang in the
    // ResponseReadIssued state. ExchangeQueue makes sure that there is at
    // most one exchange in this state, so no new reads will be issued.
    call_->Finish(&call_status_, &finished_tag_);
  }

  // Holds state for a single request/response exchange between the client
  // and the server.
  typedef typename UntypedStreamingRPCState::Tag Tag;

  // Order of context_ and call_ is important because context_ must outlive
  // call_.
  const std::shared_ptr<const ::grpc::ClientContext> context_;
  std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call_;

  mutable mutex mu_;
  ExchangeQueue exchanges_ TF_GUARDED_BY(mu_);
  State call_state_ TF_GUARDED_BY(mu_);
  ::grpc::Status call_status_ TF_GUARDED_BY(mu_);

  // We can get away with having single instances of these tags per
  // StreamingRPCState because we make sure (as gRPC requires) that
  // there is at most one outstanding Read and at most one outstanding Write
  // in the completion queue.
  // Tags are immutable. No need to guard them.
  Tag call_started_tag_{this, Tag::TagType::kCallStarted};
  Tag request_write_completed_tag_{this, Tag::TagType::kRequestWriteCompleted};
  Tag response_read_completed_tag_{this, Tag::TagType::kResponseReadCompleted};
  Tag finished_tag_{this, Tag::TagType::kCallFinished};
};

// Creates streaming calls and dispatches requests to them.
// In the common case, the client would create a StreamingRPCDispatcher for
// each bidirectional streaming RPC it might want to make. The first time it
// calls SendNextRequest, a streaming call is initiated and the request is
// sent within this call. Initiation of the call blocks the client. If there
// are no errors, subsequent calls to SendNextRequest use the already active
// call. If there was an error, the call object will be destroyed after all
// the callbacks for outstanding requests have been invoked. The next call to
// SendNextRequest will initiate a new call.
//
// Callbacks that are part of the same call are invoked in the order they were
// provided, but callbacks across calls (a failed one and a new one) can be
// invoked in any order.
//
// Thread-safe.
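//
// A hypothetical usage sketch (the `stub`, `cq`, and method-string values are
// assumptions for illustration, not part of this header):
//
//   StreamingRPCDispatcher<MyResponse> dispatcher(stub, cq,
//                                                 "/my.Service/MyMethod");
//   dispatcher.SendNextRequest(request, &response,
//                              [](const Status& s) { /* handle result */ });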
template <class Response>
class StreamingRPCDispatcher {
 public:
  StreamingRPCDispatcher(::grpc::GenericStub* stub, ::grpc::CompletionQueue* cq,
                         const ::grpc::string& method)
      : stub_(stub), cq_(cq), method_(method) {}

  // Attempts to send the next request. If there is no active streaming call,
  // starts one and sends the request on top of it. `done` is invoked when
  // `response` has been filled with the data from the server, or if there
  // is an error. `done` can be invoked before SendNextRequest returns.
  void SendNextRequest(const protobuf::Message& request, Response* response,
                       StatusCallback done) {
    mutex_lock l(mu_);
    if (state_ == nullptr) {
      CreateStreamingState();
    }

    bool is_call_alive = state_->SendNextRequest(request, response, done);
    if (is_call_alive) {
      return;
    }

    // The attempt to send failed because the call was dead; create a new
    // call and try again. When the call is dead, SendNextRequest does not
    // invoke `done`.
    CreateStreamingState();

    is_call_alive = state_->SendNextRequest(request, response, done);
    if (!is_call_alive) {
      // Consider retrying to create and start a call a few more times.
      done(errors::Unknown("gRPC call failed right after it was created"));
    }
  }

  // Request to cancel the current streaming call. Non-blocking.
  void CancelCall() {
    mutex_lock l(mu_);
    if (state_ == nullptr) {
      return;
    }
    context_->TryCancel();
    state_ = nullptr;
  }

 private:
  void CreateStreamingState() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    // ClientContext cannot be reused across calls.
    context_ = std::make_shared<::grpc::ClientContext>();
    // Don't immediately fail StartCall if the channel is not ready. Wait for
    // the channel to become ready.
    context_->set_wait_for_ready(true);

    std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call =
        stub_->PrepareCall(context_.get(), method_, cq_);

    state_.reset(new StreamingRPCState<Response>(std::move(call), context_));
  }

  mutable mutex mu_;

  // Both are thread-safe
  ::grpc::GenericStub* const stub_;
  ::grpc::CompletionQueue* const cq_;

  // Does not need synchronization since it is constant.
  const ::grpc::string method_;

  std::shared_ptr<::grpc::ClientContext> context_ TF_GUARDED_BY(mu_);
  core::RefCountPtr<StreamingRPCState<Response>> state_ TF_GUARDED_BY(mu_);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_