1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_ 18 19 #include <utility> 20 21 #include "grpcpp/generic/generic_stub.h" 22 #include "grpcpp/grpcpp.h" 23 24 #include "tensorflow/core/distributed_runtime/call_options.h" 25 #include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h" 26 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" 27 #include "tensorflow/core/distributed_runtime/tensor_coding.h" 28 #include "tensorflow/core/lib/core/threadpool.h" 29 #include "tensorflow/core/lib/strings/strcat.h" 30 #include "tensorflow/core/platform/notification.h" 31 32 namespace tensorflow { 33 34 // Object allocated per active RPC. 35 // Manage the state of a single asynchronous RPC request. If `max_retries` 36 // is greater than 0, the request will be retried for any transient failures 37 // as long as the overall deadline has not elapsed. 38 template <class Response> 39 class RPCState : public GrpcClientCQTag { 40 public: 41 // Default behavior is to set fail_fast = False and handle timeouts manually. 42 RPCState(::grpc::GenericStub* stub, ::grpc::CompletionQueue* cq, 43 const ::grpc::string& method, const protobuf::Message& request, 44 Response* response, StatusCallback done, CallOptions* call_opts, 45 thread::ThreadPool* threadpool, int32 max_retries = 0) 46 : RPCState(stub, cq, method, request, response, std::move(done), 47 call_opts, threadpool, /*fail_fast=*/false, 48 /*timeout_in_ms=*/0, max_retries) {} 49 50 template <typename Request> RPCState(::grpc::GenericStub * stub,::grpc::CompletionQueue * cq,const::grpc::string & method,const Request & request,Response * response,StatusCallback done,CallOptions * call_opts,thread::ThreadPool * threadpool,bool fail_fast,int64 timeout_in_ms,int32 max_retries)51 RPCState(::grpc::GenericStub* stub, ::grpc::CompletionQueue* cq, 52 const ::grpc::string& method, const Request& request, 53 Response* response, StatusCallback done, CallOptions* call_opts, 54 thread::ThreadPool* threadpool, bool fail_fast, int64 timeout_in_ms, 55 int32 max_retries) 56 : call_opts_(call_opts), 57 threadpool_(threadpool), 58 done_(std::move(done)), 59 cq_(cq), 60 stub_(stub), 61 method_(method), 62 max_retries_(max_retries), 63 timeout_in_ms_(timeout_in_ms), 64 fail_fast_(fail_fast) { 65 response_ = response; 66 ::grpc::Status s = GrpcMaybeUnparseProto(request, &request_buf_); 67 if (!s.ok()) { 68 LOG(ERROR) << "GrpcMaybeUnparseProto returned with non-ok status: " 69 << s.error_message(); 70 // Skip retry logic if we fail to parse our request. 71 done_(FromGrpcStatus(s)); 72 delete this; 73 return; 74 } 75 StartCall(); 76 } 77 StartCall()78 void StartCall() { 79 context_.reset(new ::grpc::ClientContext()); 80 context_->set_fail_fast(fail_fast_); 81 82 if (timeout_in_ms_ > 0) { 83 context_->set_deadline( 84 gpr_time_from_millis(timeout_in_ms_, GPR_TIMESPAN)); 85 } 86 if (call_opts_) { 87 call_opts_->SetCancelCallback([this]() { context_->TryCancel(); }); 88 } 89 90 VLOG(2) << "Starting call: " << method_; 91 92 call_ = std::move( 93 stub_->PrepareUnaryCall(context_.get(), method_, request_buf_, cq_)); 94 call_->StartCall(); 95 call_->Finish(&response_buf_, &status_, this); 96 } 97 OnCompleted(bool ok)98 void OnCompleted(bool ok) override { 99 if (call_opts_) { 100 call_opts_->ClearCancelCallback(); 101 } 102 Status s = FromGrpcStatus(status_); 103 if (s.ok() && !ok) { 104 // Since this function is only being used for processing the response 105 // to Finish for client-side unary calls, ok should never be false 106 s.Update(errors::Internal("unexpected ok value at rpc completion")); 107 } 108 109 if (s.ok()) { 110 if (threadpool_) { 111 // Run parse and callback in another thread, returning this 112 // one to service more RPCs. 113 threadpool_->Schedule([this]() { ParseAndCallDone(); }); 114 } else { 115 ParseAndCallDone(); 116 } 117 return; 118 } 119 120 VLOG(1) << method_ << " returned with non-ok status: " << s 121 << " Retries: " << num_retries_ << " Max: " << max_retries_ << "\n" 122 << context_->debug_error_string(); 123 // Retry if we have any attempts left 124 if (++num_retries_ <= max_retries_ && 125 (errors::IsUnavailable(s) || errors::IsUnknown(s))) { 126 response_buf_.Clear(); 127 VLOG(1) << "Retrying call for " << method_ << "Retry: " << num_retries_ 128 << " of " << max_retries_; 129 StartCall(); 130 } else { 131 // Attach additional GRPC error information if any to the final status 132 s = Status(s.code(), 133 strings::StrCat(s.error_message(), 134 "\nAdditional GRPC error information:\n", 135 context_->debug_error_string())); 136 done_(s); 137 delete this; 138 } 139 } 140 ParseAndCallDone()141 void ParseAndCallDone() { 142 Status s; 143 if (!GrpcMaybeParseProto(&response_buf_, response_)) { 144 s.Update(errors::Internal("could not parse rpc response")); 145 } 146 done_(s); 147 delete this; 148 } 149 150 private: 151 CallOptions* call_opts_; 152 std::unique_ptr<::grpc::ClientContext> context_; 153 thread::ThreadPool* threadpool_; 154 std::unique_ptr<::grpc::GenericClientAsyncResponseReader> call_; 155 Response* response_; 156 ::grpc::ByteBuffer request_buf_; 157 ::grpc::ByteBuffer response_buf_; 158 ::grpc::Status status_; 159 StatusCallback done_; 160 int64 timeout_in_ms_; 161 162 size_t num_retries_ = 0; 163 size_t max_retries_; 164 165 ::grpc::CompletionQueue* cq_; 166 ::grpc::GenericStub* stub_; 167 ::grpc::string method_; 168 bool fail_fast_; 169 }; 170 171 } // namespace tensorflow 172 173 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_STATE_H_ 174