• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_
17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_
18 
19 #include <utility>
20 
21 #include "grpcpp/generic/generic_stub.h"
22 #include "grpcpp/grpcpp.h"
23 
24 #include "tensorflow/core/distributed_runtime/call_options.h"
25 #include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
26 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
27 #include "tensorflow/core/distributed_runtime/tensor_coding.h"
28 #include "tensorflow/core/lib/core/threadpool.h"
29 #include "tensorflow/core/lib/strings/strcat.h"
30 #include "tensorflow/core/platform/notification.h"
31 
32 namespace tensorflow {
33 
34 // Object allocated per active RPC.
35 // Manage the state of a single asynchronous RPC request.  If `max_retries`
36 // is greater than 0, the request will be retried for any transient failures
37 // as long as the overall deadline has not elapsed.
38 template <class Response>
39 class RPCState : public GrpcClientCQTag {
40  public:
41   // Default behavior is to set fail_fast = False and handle timeouts manually.
42   RPCState(::grpc::GenericStub* stub, ::grpc::CompletionQueue* cq,
43            const ::grpc::string& method, const protobuf::Message& request,
44            Response* response, StatusCallback done, CallOptions* call_opts,
45            thread::ThreadPool* threadpool, int32 max_retries = 0)
46       : RPCState(stub, cq, method, request, response, std::move(done),
47                  call_opts, threadpool, /*fail_fast=*/false,
48                  /*timeout_in_ms=*/0, max_retries) {}
49 
50   template <typename Request>
RPCState(::grpc::GenericStub * stub,::grpc::CompletionQueue * cq,const::grpc::string & method,const Request & request,Response * response,StatusCallback done,CallOptions * call_opts,thread::ThreadPool * threadpool,bool fail_fast,int64 timeout_in_ms,int32 max_retries)51   RPCState(::grpc::GenericStub* stub, ::grpc::CompletionQueue* cq,
52            const ::grpc::string& method, const Request& request,
53            Response* response, StatusCallback done, CallOptions* call_opts,
54            thread::ThreadPool* threadpool, bool fail_fast, int64 timeout_in_ms,
55            int32 max_retries)
56       : call_opts_(call_opts),
57         threadpool_(threadpool),
58         done_(std::move(done)),
59         cq_(cq),
60         stub_(stub),
61         method_(method),
62         max_retries_(max_retries),
63         timeout_in_ms_(timeout_in_ms),
64         fail_fast_(fail_fast) {
65     response_ = response;
66     ::grpc::Status s = GrpcMaybeUnparseProto(request, &request_buf_);
67     if (!s.ok()) {
68       LOG(ERROR) << "GrpcMaybeUnparseProto returned with non-ok status: "
69                  << s.error_message();
70       // Skip retry logic if we fail to parse our request.
71       done_(FromGrpcStatus(s));
72       delete this;
73       return;
74     }
75     StartCall();
76   }
77 
StartCall()78   void StartCall() {
79     context_.reset(new ::grpc::ClientContext());
80     context_->set_fail_fast(fail_fast_);
81 
82     if (timeout_in_ms_ > 0) {
83       context_->set_deadline(
84           gpr_time_from_millis(timeout_in_ms_, GPR_TIMESPAN));
85     }
86     if (call_opts_) {
87       call_opts_->SetCancelCallback([this]() { context_->TryCancel(); });
88     }
89 
90     VLOG(2) << "Starting call: " << method_;
91 
92     call_ = std::move(
93         stub_->PrepareUnaryCall(context_.get(), method_, request_buf_, cq_));
94     call_->StartCall();
95     call_->Finish(&response_buf_, &status_, this);
96   }
97 
OnCompleted(bool ok)98   void OnCompleted(bool ok) override {
99     if (call_opts_) {
100       call_opts_->ClearCancelCallback();
101     }
102     Status s = FromGrpcStatus(status_);
103     if (s.ok() && !ok) {
104       // Since this function is only being used for processing the response
105       // to Finish for client-side unary calls, ok should never be false
106       s.Update(errors::Internal("unexpected ok value at rpc completion"));
107     }
108 
109     if (s.ok()) {
110       if (threadpool_) {
111         // Run parse and callback in another thread, returning this
112         // one to service more RPCs.
113         threadpool_->Schedule([this]() { ParseAndCallDone(); });
114       } else {
115         ParseAndCallDone();
116       }
117       return;
118     }
119 
120     VLOG(1) << method_ << " returned with non-ok status: " << s
121             << " Retries: " << num_retries_ << " Max: " << max_retries_ << "\n"
122             << context_->debug_error_string();
123     // Retry if we have any attempts left
124     if (++num_retries_ <= max_retries_ &&
125         (errors::IsUnavailable(s) || errors::IsUnknown(s))) {
126       response_buf_.Clear();
127       VLOG(1) << "Retrying call for " << method_ << "Retry: " << num_retries_
128               << " of " << max_retries_;
129       StartCall();
130     } else {
131       // Attach additional GRPC error information if any to the final status
132       s = Status(s.code(),
133                  strings::StrCat(s.error_message(),
134                                  "\nAdditional GRPC error information:\n",
135                                  context_->debug_error_string()));
136       done_(s);
137       delete this;
138     }
139   }
140 
ParseAndCallDone()141   void ParseAndCallDone() {
142     Status s;
143     if (!GrpcMaybeParseProto(&response_buf_, response_)) {
144       s.Update(errors::Internal("could not parse rpc response"));
145     }
146     done_(s);
147     delete this;
148   }
149 
150  private:
151   CallOptions* call_opts_;
152   std::unique_ptr<::grpc::ClientContext> context_;
153   thread::ThreadPool* threadpool_;
154   std::unique_ptr<::grpc::GenericClientAsyncResponseReader> call_;
155   Response* response_;
156   ::grpc::ByteBuffer request_buf_;
157   ::grpc::ByteBuffer response_buf_;
158   ::grpc::Status status_;
159   StatusCallback done_;
160   int64 timeout_in_ms_;
161 
162   size_t num_retries_ = 0;
163   size_t max_retries_;
164 
165   ::grpc::CompletionQueue* cq_;
166   ::grpc::GenericStub* stub_;
167   ::grpc::string method_;
168   bool fail_fast_;
169 };
170 
171 }  // namespace tensorflow
172 
173 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_
174