• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_
17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_
18 
19 #include <queue>
20 #include <utility>
21 
22 #include "grpcpp/generic/generic_stub.h"
23 #include "grpcpp/grpcpp.h"
24 #include "tensorflow/core/distributed_runtime/call_options.h"
25 #include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
26 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
27 #include "tensorflow/core/distributed_runtime/tensor_coding.h"
28 #include "tensorflow/core/lib/core/refcount.h"
29 #include "tensorflow/core/lib/core/status.h"
30 #include "tensorflow/core/lib/core/threadpool.h"
31 #include "tensorflow/core/lib/strings/strcat.h"
32 #include "tensorflow/core/platform/mutex.h"
33 #include "tensorflow/core/platform/notification.h"
34 #include "tensorflow/core/util/env_var.h"
35 
36 namespace tensorflow {
37 
38 // Object allocated per active RPC.
39 // Manage the state of a single asynchronous RPC request.  If `max_retries`
40 // is greater than 0, the request will be retried for any transient failures.
41 template <class Response>
42 class RPCState : public GrpcClientCQTag {
43  public:
44   RPCState(::grpc::GenericStub* stub, ::grpc::CompletionQueue* cq,
45            const ::grpc::string& method, const protobuf::Message& request,
46            Response* response, StatusCallback done, CallOptions* call_opts,
47            thread::ThreadPool* threadpool, int32 max_retries = 0,
48            bool fail_fast = true)
49       : RPCState(
50             stub, cq, method, request, response, std::move(done), call_opts,
51             threadpool,
52             // 1) If GRPC_FAIL_FAST is specified, fail_fast=$GRPC_FAIL_FAST.
53             // See b/141948186.
54             // 2) Otherwise, if the platform is Google, use the fail_fast from
55             // the caller. See b/140260119.
56             // 3) Otherwise, use fail_fast=false.
57             [fail_fast]() -> bool {
58               bool x;
59 #if defined(PLATFORM_GOOGLE)
60               TF_CHECK_OK(ReadBoolFromEnvVar("GRPC_FAIL_FAST", fail_fast, &x));
61 #else
62               TF_CHECK_OK(ReadBoolFromEnvVar("GRPC_FAIL_FAST", false, &x));
63 #endif  // PLATFORM_GOOGLE
64               return x;
65             }(),
66             /*timeout_in_ms=*/0, max_retries) {
67   }
68 
69   template <typename Request>
RPCState(::grpc::GenericStub * stub,::grpc::CompletionQueue * cq,const::grpc::string & method,const Request & request,Response * response,StatusCallback done,CallOptions * call_opts,thread::ThreadPool * threadpool,bool fail_fast,int64 timeout_in_ms,int32 max_retries)70   RPCState(::grpc::GenericStub* stub, ::grpc::CompletionQueue* cq,
71            const ::grpc::string& method, const Request& request,
72            Response* response, StatusCallback done, CallOptions* call_opts,
73            thread::ThreadPool* threadpool, bool fail_fast, int64 timeout_in_ms,
74            int32 max_retries)
75       : call_opts_(call_opts),
76         threadpool_(threadpool),
77         done_(std::move(done)),
78         timeout_in_ms_(timeout_in_ms),
79         max_retries_(max_retries),
80         cq_(cq),
81         stub_(stub),
82         method_(method),
83         fail_fast_(fail_fast) {
84     response_ = response;
85     ::grpc::Status s = GrpcMaybeUnparseProto(request, &request_buf_);
86     if (!s.ok()) {
87       LOG(ERROR) << "GrpcMaybeUnparseProto returned with non-ok status: "
88                  << s.error_message();
89       // Skip retry logic if we fail to parse our request.
90       done_(FromGrpcStatus(s));
91       delete this;
92       return;
93     }
94     StartCall();
95   }
96 
StartCall()97   void StartCall() {
98     context_.reset(new ::grpc::ClientContext());
99     context_->set_wait_for_ready(!fail_fast_);
100     if (timeout_in_ms_ > 0) {
101       context_->set_deadline(
102           gpr_time_from_millis(timeout_in_ms_, GPR_TIMESPAN));
103     }
104     if (call_opts_) {
105       call_opts_->SetCancelCallback([this]() { context_->TryCancel(); });
106     }
107 
108     VLOG(2) << "Starting call: " << method_;
109 
110     call_ = stub_->PrepareUnaryCall(context_.get(), method_, request_buf_, cq_);
111     call_->StartCall();
112     call_->Finish(&response_buf_, &status_, this);
113   }
114 
OnCompleted(bool ok)115   void OnCompleted(bool ok) override {
116     if (call_opts_) {
117       call_opts_->ClearCancelCallback();
118     }
119 
120     VLOG(2) << "Completed call: " << method_;
121 
122     Status s = FromGrpcStatus(status_);
123     if (s.ok() && !ok) {
124       // Since this function is only being used for processing the response
125       // to Finish for client-side unary calls, ok should never be false
126       s.Update(
127           errors::Internal("GRPC status is okay but CompletionQueueStatus is "
128                            "not.  This should never happen."));
129     }
130 
131     if (s.ok()) {
132       if (threadpool_) {
133         // Run parse and callback in another thread, returning this
134         // one to service more RPCs.
135         threadpool_->Schedule([this]() { ParseAndCallDone(); });
136       } else {
137         ParseAndCallDone();
138       }
139       return;
140     }
141 
142     VLOG(1) << method_ << " returned with non-ok status: " << s
143             << " Retries: " << num_retries_ << " Max: " << max_retries_ << "\n"
144             << context_->debug_error_string();
145     // Retry if we have any attempts left
146     if (++num_retries_ <= max_retries_ &&
147         (errors::IsUnavailable(s) || errors::IsUnknown(s))) {
148       response_buf_.Clear();
149       VLOG(1) << "Retrying call for " << method_ << "Retry: " << num_retries_
150               << " of " << max_retries_;
151       // TODO(b/139945426) Allow user to configure the retry backoff time.
152       StartCall();
153     } else {
154       // Attach additional GRPC error information if any to the final status
155       s = Status(s.code(),
156                  strings::StrCat(s.error_message(),
157                                  "\nAdditional GRPC error information:\n",
158                                  context_->debug_error_string()));
159       // Always treat gRPC cancellation as a derived error. This ensures that
160       // other error types are preferred during status aggregation. (gRPC
161       // cancellation messages do not contain the original status message).
162       if (s.code() == tensorflow::error::Code::CANCELLED) {
163         s = StatusGroup::MakeDerived(s);
164       }
165 
166       done_(s);
167       delete this;
168     }
169   }
170 
ParseAndCallDone()171   void ParseAndCallDone() {
172     Status s;
173     if (!GrpcMaybeParseProto(&response_buf_, response_)) {
174       s.Update(errors::Internal("could not parse rpc response"));
175     }
176     done_(s);
177     delete this;
178   }
179 
180  private:
181   CallOptions* call_opts_;
182   std::unique_ptr<::grpc::ClientContext> context_;
183   thread::ThreadPool* threadpool_;
184   std::unique_ptr<::grpc::GenericClientAsyncResponseReader> call_;
185   Response* response_;
186   ::grpc::ByteBuffer request_buf_;
187   ::grpc::ByteBuffer response_buf_;
188   ::grpc::Status status_;
189   StatusCallback done_;
190   int64 timeout_in_ms_;
191 
192   size_t num_retries_ = 0;
193   size_t max_retries_;
194 
195   ::grpc::CompletionQueue* cq_;
196   ::grpc::GenericStub* stub_;
197   ::grpc::string method_;
198   bool fail_fast_;
199 };
200 
201 // Represents state associated with one streaming RPC call.
202 // Similarly to above, we extract the methods of StreamingRPCState that don't
203 // need to be templated into this abstract class.
204 // Currently, *StreamingRPCState does not support client closing the call as
205 // there is no use case for it - current clients keep the streaming call open
206 // as long as possible. If/when the need arises, support can be added
207 // by calling GenericClientAsyncReaderWriter::WritesDone with a new tag
208 // TagType::kClientFinished and handling the completion in a new callback.
209 class UntypedStreamingRPCState : public core::RefCounted {
210  public:
211   virtual void CallStarted(bool ok) = 0;
212   virtual void RequestWriteCompleted(bool ok) = 0;
213   virtual void ResponseReadCompleted(bool ok) = 0;
214   virtual void CallFinished(bool ok) = 0;
215 
216   virtual string DebugString() const = 0;
217 
218   class Tag : public GrpcClientCQTag {
219    public:
220     // One enum value per supported callback.
221     enum class TagType {
222       kCallStarted,
223       kRequestWriteCompleted,
224       kResponseReadCommpleted,
225       kCallFinished,
226     };
227 
228     Tag(UntypedStreamingRPCState* streaming_state, Tag::TagType type);
229 
230     // Calls the callback associated with this tag and Unrefs
231     // `this->streaming_state_`.
232     void OnCompleted(bool ok) override;
233 
234    private:
235     // OnCompleted() consumes on reference each time it is called.
236     UntypedStreamingRPCState* const streaming_state_;
237     const Tag::TagType type_;
238   };
239 };
240 
241 const char* ToString(UntypedStreamingRPCState::Tag::TagType tag_type);
242 
243 // Represents a single request/response exchange between client and the server.
244 // A single streaming call contains a sequence of exchanges. Besides the
245 // messages, exchange contains:
246 //  - the user callback to invoke when exchange completes (response is received
247 //    or an error occurs).
248 //  - The current state of the exchange.
249 class Exchange {
250  public:
251   enum class State {
252     kExchangeCreated,
253     kRequestWriteIssued,
254     kRequestWriteCompleted,
255     kResponseReadIssued,
256   };
257 
Exchange(const::grpc::ByteBuffer & request_buf,protobuf::Message * response,StatusCallback cb,string debug_string)258   Exchange(const ::grpc::ByteBuffer& request_buf, protobuf::Message* response,
259            StatusCallback cb, string debug_string)
260       : state_(State::kExchangeCreated),
261         request_buf_(request_buf),
262         response_(response),
263         cb_(std::move(cb)),
264         debug_string_(std::move(debug_string)) {}
265 
request_buf()266   const ::grpc::ByteBuffer& request_buf() { return request_buf_; }
response_buf()267   ::grpc::ByteBuffer* response_buf() { return &response_buf_; }
268 
MarkRequestWriteIssued()269   void MarkRequestWriteIssued() {
270     DCHECK(state_ == State::kExchangeCreated);
271     state_ = State::kRequestWriteIssued;
272   }
MarkRequestWriteCompleted()273   void MarkRequestWriteCompleted() {
274     DCHECK(state_ == State::kRequestWriteIssued);
275     state_ = State::kRequestWriteCompleted;
276   }
MarkResponseReadIssued()277   void MarkResponseReadIssued() {
278     DCHECK(state_ == State::kRequestWriteCompleted);
279     state_ = State::kResponseReadIssued;
280   }
281 
282   // If `status` is success, completes this exchange by parsing the
283   // response_buf_ and invoking cb_ with Status::OK(). Else, invokes the
284   // callback with `status`.
285   void Complete(Status status);
286 
state()287   const State& state() const { return state_; }
288 
289   string DebugString() const;
290 
291  private:
292   State state_;
293   ::grpc::ByteBuffer request_buf_;
294   ::grpc::ByteBuffer response_buf_;
295   protobuf::Message* response_;
296   StatusCallback cb_;
297   string debug_string_;
298 };
299 
300 const char* ToString(Exchange::State s);
301 
302 std::ostream& operator<<(std::ostream& os, const Exchange::State& state);
303 
304 // Represents a queue of exchanges.
305 // When a client sends a new request a new exchange is created and added to the
306 // end of the queue. Completed exchanges are popped from the front of the queue.
307 // An explicit exchange queue is needed to brdige the client, which can send new
308 // requests at any time, with gRPC infrastructure, which can handle a single
309 // read and a single write request at a time.
310 //
311 // As the exchange progresses (request sending initiated, request sending
312 // completed, response reading initiated) the queue helps to make sure that the
313 // right operation is issued on the right exchange at the right time.
314 //
315 // To satisfy gRPC constraints, the states of exchanges must be as follows
316 // starting from the front of the queue:
317 //  - 0 or 1 exchange in kResponseReadIssued state
318 //  - 0 or more exchanges in kRequestWriteCompleted state
319 //  - 0 or 1 exchange in kRequestWriteIssued state
320 //  - 0 or more exchanges in kExchangeCreated state
321 //
322 // Thread-compatible.
323 class ExchangeQueue {
324  public:
325   // Creates a new exchange and adds it to the end of the queue.
326   void Emplace(const ::grpc::ByteBuffer& request_buf,
327                protobuf::Message* response, StatusCallback cb,
328                std::string debug_string);
329 
330   // Returns an exchange for which we can initiate request writing, if any.
331   // Returns nullptr if there is no such exchange.
332   Exchange* GetReadyForRequestWriting();
333 
334   // Returns an exchange for which we can initiate response reading, if any.
335   // Returns nullptr if there is no such exchange.
336   Exchange* GetReadyForResponseReading();
337 
338   // Changes the state of the exchange that is current in kRequestWriteIssued
339   // state to kRequestWriteCompleted state.
340   // REQUIRES: There is an exhange in kRequestWriteIssued state.
341   void MarkRequestWriteCompleted();
342 
343   // Returns the exchange at the front of the queue.
344   // REQUIRES: ExchangeQueue is not empty.
345   Exchange& GetFront();
346 
347   // Removes the exchange at the front of the queue.
348   // REQUIRES: ExchangeQueue is not empty.
349   void PopFront();
350 
351   // Returns a string containing addresses and states of all exchanges in this
352   // queue.
353   string DebugString() const;
354 
355   // Swaps the contents of this and `other`.
356   void Swap(ExchangeQueue* other);
357 
358   // Completes all exchanges in this with `status`.
359   void CompleteAll(Status status);
360 
CallStarted()361   void CallStarted() { call_started_ = true; }
362 
363  private:
364   // Does nothing by default. Turn on VLOG(5) to enable.
365   // Checks that this ExchangeQueue is in a valid state.
366   // Kills the process if not.
367   void CheckInvariants();
368 
369   // We can't process any exchanges until the call has started.
370   bool call_started_ = false;
371 
372   // std::queue is based on std::deque by default. std::deque provides
373   // fairly strong iterator stability.
374   std::deque<Exchange> exchanges_;
375 };  // namespace tensorflow
376 
377 // Represents state associated with one streaming RPC call.
378 // Thread-safe
379 template <class Response>
380 class StreamingRPCState : public UntypedStreamingRPCState {
381  public:
382   // Default behavior is to set fail_fast = False and handle timeouts
383   // manually.
StreamingRPCState(std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call,const std::shared_ptr<::grpc::ClientContext> & context)384   StreamingRPCState(std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call,
385                     const std::shared_ptr<::grpc::ClientContext>& context)
386       : context_(context), call_(std::move(call)), call_state_(State::kActive) {
387     Ref();
388     VLOG(3) << "Created new StreamingRPCState " << this;
389     VLOG(3) << "StreamingRPCState(" << this << ") calling grpc::StartCall";
390     call_->StartCall(&call_started_tag_);
391   }
392 
~StreamingRPCState()393   ~StreamingRPCState() override {
394     VLOG(3) << "Destructing StreamingRPCState " << this;
395   }
396 
397   // Attempts to send the next request. `done` is invoked when
398   // `response` has been filled with the data from the server, or if there
399   // is an error. `done` can be invoked before SendNextRequest returns.
400   // Return `true` if the call is alive and the `done` callback has or
401   // will be invoked. If the call is dead, returns `false`. `done` callback
402   // will not be invoked in this case.
403   // REQUIRES: The call has been started, i.e. WaitForCallStarted() has
404   // returned.
SendNextRequest(const protobuf::Message & request,Response * response,const StatusCallback & done)405   bool SendNextRequest(const protobuf::Message& request, Response* response,
406                        const StatusCallback& done) {
407     ::grpc::ByteBuffer request_buf;
408     ::grpc::Status s = GrpcMaybeUnparseProto(request, &request_buf);
409     if (!s.ok()) {
410       Status status = FromGrpcStatus(s);
411       LOG(ERROR) << "GrpcMaybeUnparseProto returned with non-ok status: "
412                  << status.ToString();
413       done(status);
414       return true;
415     }
416 
417     mutex_lock l(mu_);
418     if (call_state_ != State::kActive) {
419       // `done` is not invoked intentionally.
420       return false;
421     }
422     if (VLOG_IS_ON(3)) {
423       // If vlog 3 is enabled, include first 100 chars of request as debug
424       // string.
425       exchanges_.Emplace(request_buf, response, done,
426                          request.ShortDebugString().substr(0, 100));
427     } else {
428       exchanges_.Emplace(request_buf, response, done, "");
429     }
430     MaybeIssueRequestWriteLocked();
431     return true;
432   }
433 
CallStarted(bool ok)434   void CallStarted(bool ok) override {
435     VLOG(3) << "StreamingRPCState(" << this << ")::CallStarted(ok=" << ok
436             << ")";
437     mutex_lock l(mu_);
438     if (!ok) {
439       call_state_ = State::kDone;
440       return;
441     }
442     exchanges_.CallStarted();
443     // Now that the call has started, we can write our first request, if any.
444     MaybeIssueRequestWriteLocked();
445   }
446 
RequestWriteCompleted(bool ok)447   void RequestWriteCompleted(bool ok) override {
448     VLOG(3) << "StreamingRPCState(" << this
449             << ")::RequestWriteCompleted(ok=" << ok << ")";
450     mu_.lock();
451     if (call_state_ != State::kActive) {
452       mu_.unlock();
453       return;
454     }
455     if (!ok) {
456       // unlocks mu_
457       MarkDoneAndCompleteExchanges(errors::Internal(
458           "Not ok value returned by CompletionQueue when attempting streaming "
459           "rpc write. Probably because the completion queue has been shut "
460           "down or the connection went down. ",
461           context_->debug_error_string()));
462       return;
463     }
464 
465     exchanges_.MarkRequestWriteCompleted();
466     MaybeIssueResponseReadLocked();
467     MaybeIssueRequestWriteLocked();
468     mu_.unlock();
469   }
470 
ResponseReadCompleted(bool ok)471   void ResponseReadCompleted(bool ok) override {
472     VLOG(3) << "StreamingRPCState(" << this
473             << ")::ResponseReadCompleted(ok=" << ok << ")";
474     mu_.lock();
475     if (call_state_ != State::kActive) {
476       mu_.unlock();
477       return;
478     }
479     if (!ok) {
480       IssueCallFinishLocked();
481       mu_.unlock();
482       return;
483     }
484 
485     // Complete the exchange without holding the lock because user's
486     // callback can call back into this RPC code resulting in a deadlock.
487     // No other thread can pop this exchange while we release the lock because
488     // this is the only method that pops exchanges and it is called from a
489     // single thread that waits on completion queue events.
490     Exchange* e;
491     e = &exchanges_.GetFront();
492     mu_.unlock();
493 
494     e->Complete(Status::OK());
495 
496     {
497       mutex_lock l(mu_);
498       exchanges_.PopFront();
499       MaybeIssueResponseReadLocked();
500     }
501   }
502 
CallFinished(bool ok)503   void CallFinished(bool ok) override {
504     VLOG(3) << "StreamingRPCState(" << this << ")::CallFinished(ok=" << ok
505             << ")";
506     mu_.lock();
507     DCHECK(call_state_ != State::kActive);
508     if (call_state_ != State::kFinishing) {
509       mu_.unlock();
510       return;
511     }
512 
513     Status s = FromGrpcStatus(call_status_);
514     if (s.ok() && !ok) {
515       s.Update(
516           errors::Internal("GRPC status is okay but CompletionQueueStatus is "
517                            "not.  This should never happen.",
518                            context_->debug_error_string()));
519     }
520     // unlocks mu_
521     MarkDoneAndCompleteExchanges(s);
522   }
523 
DebugString()524   string DebugString() const override {
525     mutex_lock l(mu_);
526     return exchanges_.DebugString();
527   }
528 
529  private:
530   enum class State {
531     kActive,
532     kFinishing,
533     kDone,
534   };
535 
MarkDoneAndCompleteExchanges(Status status)536   void MarkDoneAndCompleteExchanges(Status status) EXCLUSIVE_LOCKS_REQUIRED(mu_)
537       UNLOCK_FUNCTION(mu_) {
538     call_state_ = State::kDone;
539     VLOG(2) << "Ending gRPC stremaing call on the client side due to "
540             << status.ToString();
541     // Swap the exchanges_ into a temporary ExchangeQueue so that we can
542     // complete all exchanges without holding mu_ in case user callback
543     // reach back into this. This should be impossible now, but safer for
544     // the future.
545     ExchangeQueue queue;
546     exchanges_.Swap(&queue);
547     mu_.unlock();
548     queue.CompleteAll(status);
549   }
550 
MaybeIssueRequestWriteLocked()551   void MaybeIssueRequestWriteLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
552     Exchange* exchange = exchanges_.GetReadyForRequestWriting();
553     if (exchange == nullptr) {
554       // There are no queued exchanges, there is already an outstanding write,
555       // or there are no just created exchanges.
556       return;
557     }
558     exchange->MarkRequestWriteIssued();
559     Ref();
560     VLOG(3) << "StreamingRPCState(" << this << ") calling grpc::Write";
561     call_->Write(exchange->request_buf(), &request_write_completed_tag_);
562   }
563 
MaybeIssueResponseReadLocked()564   void MaybeIssueResponseReadLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
565     Exchange* exchange = exchanges_.GetReadyForResponseReading();
566     if (exchange == nullptr) {
567       return;
568     }
569     exchange->MarkResponseReadIssued();
570     Ref();
571     VLOG(3) << "StreamingRPCState(" << this << ") calling grpc::Read";
572     call_->Read(exchange->response_buf(), &response_read_completed_tag_);
573   }
574 
IssueCallFinishLocked()575   void IssueCallFinishLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
576     call_state_ = State::kFinishing;
577     Ref();
578     VLOG(3) << "StreamingRPCState(" << this << ") calling grpc::Finish";
579     // We call finish in response to completed (with error) response reading tag
580     // on some exchange. We let this exchange hang in ResponseReadIssued state.
581     // ExchangeQueue makes sure that there is at most one exchange in this
582     // state. So, no new reads will be issued.
583     call_->Finish(&call_status_, &finished_tag_);
584   }
585 
586   // Holds state for a single request/response exchange between the client
587   // and the server.
588   typedef typename UntypedStreamingRPCState::Tag Tag;
589 
590   // Order of context_ and call_ is important because context_ must outlive
591   // call_.
592   const std::shared_ptr<const ::grpc::ClientContext> context_;
593   std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call_;
594 
595   mutable mutex mu_;
596   ExchangeQueue exchanges_ GUARDED_BY(mu_);
597   State call_state_ GUARDED_BY(mu_);
598   ::grpc::Status call_status_ GUARDED_BY(mu_);
599 
600   // We can get away with having single instances of these tags per
601   // StreamingRPCState because we make sure (as gRPC requires) that
602   // there is at most one outstanding Read and at most one outstanding Write
603   // in the completion queue.
604   // Tags are immutable. No need to guard them.
605   Tag call_started_tag_{this, Tag::TagType::kCallStarted};
606   Tag request_write_completed_tag_{this, Tag::TagType::kRequestWriteCompleted};
607   Tag response_read_completed_tag_{this, Tag::TagType::kResponseReadCommpleted};
608   Tag finished_tag_{this, Tag::TagType::kCallFinished};
609 };
610 
611 // Creates streaming calls and dispatches requests to them.
612 // In the common case, the client would create a StreamingRPCDispatcher for
613 // each bidirectional streaming RPC it might want to make. The first time, it
614 // calls SendNextRequest, a streaming call is initiated and the request is
615 // sent within this call. Initiation of the call blocks the client. If there are
616 // no errors, subsequent calls to SendNextRequest would use the already active
617 // call. If there was an error, the call object will be destroyed after all
618 // the callbacks for outstanding requests have been invoked. The next call to
619 // SendNextRequest will initiate a new call.
620 //
621 // Callbacks that are part of the same call, are invoked in the order they were
622 // provided, but callbacks across calls (a failed and a new one) can be invoked
623 // in any order.
624 //
625 // Thread-safe.
626 template <class Response>
627 class StreamingRPCDispatcher {
628  public:
StreamingRPCDispatcher(::grpc::GenericStub * stub,::grpc::CompletionQueue * cq,const::grpc::string & method)629   StreamingRPCDispatcher(::grpc::GenericStub* stub, ::grpc::CompletionQueue* cq,
630                          const ::grpc::string& method)
631       : stub_(stub), cq_(cq), method_(method) {}
632 
633   // Attempts to send the next request. If there is no active streaming call,
634   // starts one and sends the request on top of it. `done` is invoked when
635   // `response` has been filled with the data from the server, or if there
636   // is an error. `done` can be invoked before SendNextRequest returns.
SendNextRequest(const protobuf::Message & request,Response * response,StatusCallback done)637   void SendNextRequest(const protobuf::Message& request, Response* response,
638                        StatusCallback done) {
639     mutex_lock l(mu_);
640     if (state_ == nullptr) {
641       CreateStreamingState();
642     }
643 
644     bool is_call_alive = state_->SendNextRequest(request, response, done);
645     if (is_call_alive) {
646       return;
647     }
648 
649     // The attempt to send failed because the call was dead, create a new
650     // call and try again. When the call is dead SendNextRequest does not call
651     // `done`.
652     CreateStreamingState();
653 
654     is_call_alive = state_->SendNextRequest(request, response, done);
655     if (!is_call_alive) {
656       // Consider retrying to create and start a call few more times.
657       done(errors::Unknown("gRPC call failed right after it was created"));
658     }
659   }
660 
661   // Request to cancel the current streaming call. Non-blocking.
CancelCall()662   void CancelCall() {
663     mutex_lock l(mu_);
664     if (state_ == nullptr) {
665       return;
666     }
667     context_->TryCancel();
668     state_ = nullptr;
669   }
670 
671  private:
CreateStreamingState()672   void CreateStreamingState() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
673     // ClientContext cannot be reused across calls.
674     context_ = std::make_shared<::grpc::ClientContext>();
675     // Don't immediately fail StartCall if the channel is not ready. Wait for
676     // the channel to become ready.
677     context_->set_wait_for_ready(true);
678 
679     std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call =
680         stub_->PrepareCall(context_.get(), method_, cq_);
681 
682     state_.reset(new StreamingRPCState<Response>(std::move(call), context_));
683   }
684 
685   mutable mutex mu_;
686 
687   // Both are thread-safe
688   ::grpc::GenericStub* const stub_;
689   ::grpc::CompletionQueue* const cq_;
690 
691   // Does not need synchronization since it is constant.
692   const ::grpc::string method_;
693 
694   std::shared_ptr<::grpc::ClientContext> context_ GUARDED_BY(mu_);
695   core::RefCountPtr<StreamingRPCState<Response>> state_ GUARDED_BY(mu_);
696 };
697 
698 }  // namespace tensorflow
699 
700 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_
701