• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <string>
18 #include <vector>
19 
20 #include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
21 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
22 #include "tensorflow/core/framework/op_kernel.h"
23 #include "tensorflow/core/framework/tensor_types.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/core/refcount.h"
26 #include "tensorflow/core/util/rpc/call_container.h"
27 #include "tensorflow/core/util/rpc/rpc_factory.h"
28 
29 #include "tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h"
30 
31 namespace tensorflow {
32 
33 namespace internal {
34 class GrpcCall {
35  public:
GrpcCall(CallContainer<GrpcCall> * container,int index,bool try_rpc,const string * request_msg,string * response_msg,int32 * status_code,string * status_message)36   explicit GrpcCall(CallContainer<GrpcCall>* container, int index, bool try_rpc,
37                     const string* request_msg, string* response_msg,
38                     int32* status_code, string* status_message)
39       : container_(container),
40         index_(index),
41         try_rpc_(try_rpc),
42         request_msg_(request_msg),
43         response_msg_(response_msg),
44         status_code_(status_code),
45         status_message_(status_message) {}
46 
StartCancel()47   void StartCancel() { call_opts_.StartCancel(); }
48 
Done(const Status & s)49   void Done(const Status& s) {
50     DCHECK(container_ != nullptr);
51     if (!s.ok() && try_rpc_) {
52       DCHECK(status_code_ != nullptr);
53       DCHECK(status_message_ != nullptr);
54       *status_code_ = s.code();
55       *status_message_ = s.error_message();
56     }
57     container_->Done(s, index_);
58   }
59 
call_opts()60   CallOptions* call_opts() { return &call_opts_; }
index()61   int index() { return index_; }
request() const62   const string& request() const { return *request_msg_; }
response() const63   string* response() const { return response_msg_; }
64 
65  private:
66   CallContainer<GrpcCall>* const container_;
67   const int index_;
68   bool try_rpc_;
69   CallOptions call_opts_;
70   const string* request_msg_;
71   string* response_msg_;
72   int* status_code_;
73   string* status_message_;
74 };
75 
76 }  // namespace internal
77 
78 using internal::GrpcCall;
79 
GrpcRPCFactory(OpKernelConstruction * ctx,bool fail_fast,int64 timeout_in_ms)80 GrpcRPCFactory::GrpcRPCFactory(OpKernelConstruction* ctx, bool fail_fast,
81                                int64 timeout_in_ms)
82     : RPCFactory(), fail_fast_(fail_fast), timeout_in_ms_(timeout_in_ms) {
83   // TODO(ebrevdo): Investigate possible performance improvements by
84   // replacing this thread with a threadpool.
85   polling_thread_ =
86       ctx->env()->StartThread(ThreadOptions(), "rpc_op_grpc_factory", [this]() {
87         void* tag;
88         bool ok;
89         while (completion_queue_.Next(&tag, &ok)) {
90           GrpcClientCQTag* callback_tag = static_cast<GrpcClientCQTag*>(tag);
91           callback_tag->OnCompleted(ok);
92         }
93       });
94 }
95 
~GrpcRPCFactory()96 GrpcRPCFactory::~GrpcRPCFactory() {
97   // The amount of time we wait depends on several parameters, including:
98   //   - the value of the fail_fast attribute.
99   //   - the timeout option of the rpc call in the proto declaration.
100   //   - the network roundtrip time and service's execution time.
101   //
102   // If a connection is made but the service doesn't ever respond, and
103   // there is no timeout option set for this rpc call, then it is
104   // possible the RPC request will wait forever.
105   //
106   completion_queue_.Shutdown();
107   delete polling_thread_;
108 }
109 
Call(OpKernelContext * ctx,int64 num_elements,const Tensor & address_t,const Tensor & method_t,const Tensor & request_t,const bool try_rpc,Tensor * response_t,Tensor * status_code_t,Tensor * status_message_t,AsyncOpKernel::DoneCallback done)110 void GrpcRPCFactory::Call(OpKernelContext* ctx, int64 num_elements,
111                           const Tensor& address_t, const Tensor& method_t,
112                           const Tensor& request_t, const bool try_rpc,
113                           Tensor* response_t, Tensor* status_code_t,
114                           Tensor* status_message_t,
115                           AsyncOpKernel::DoneCallback done) {
116   if (try_rpc) {
117     // In this case status_code will never be set in the response,
118     // so we just set it to OK.
119     DCHECK(status_code_t != nullptr);
120     status_code_t->flat<int32>().setConstant(
121         static_cast<int>(errors::Code::OK));
122   }
123 
124   CallContainer<GrpcCall>::CreateCallFn create_call_fn =
125       [this, &request_t, &try_rpc, response_t, status_code_t, status_message_t](
126           CallContainer<GrpcCall>* container, int index) {
127         CreateCall(request_t, try_rpc, index, container, response_t,
128                    status_code_t, status_message_t);
129       };
130 
131   CallContainer<GrpcCall>::StartCallFn start_call_fn =
132       [this, &address_t, &method_t](GrpcCall* call) {
133         StartCall(address_t, method_t, call);
134       };
135 
136   // This object will delete itself when done.
137   new CallContainer<GrpcCall>(ctx, num_elements, fail_fast_, try_rpc,
138                               std::move(done), std::move(create_call_fn),
139                               std::move(start_call_fn));
140 }
141 
GetOrCreateStubForAddress(const string & address)142 ::grpc::GenericStub* GrpcRPCFactory::GetOrCreateStubForAddress(
143     const string& address) {
144   mutex_lock lock(mu_);
145 
146   auto stub = stubs_.find(address);
147   if (stub != stubs_.end()) return stub->second.get();
148 
149   ChannelPtr channel = CreateChannelForAddress(address);
150   auto* created = new ::grpc::GenericStub(channel);
151   stubs_[address].reset(created);
152   return created;
153 }
154 
CreateChannelForAddress(const string & address)155 GrpcRPCFactory::ChannelPtr GrpcRPCFactory::CreateChannelForAddress(
156     const string& address) {
157   ::grpc::ChannelArguments args;
158   args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, std::numeric_limits<int32>::max());
159 
160   // Set a standard backoff timeout of 1s instead of the
161   // (sometimes default) 20s.
162   args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 1000);
163   return ::grpc::CreateCustomChannel(
164       /*target=*/address, ::grpc::InsecureChannelCredentials(), args);
165 }
166 
CreateCall(const Tensor & request_t,const bool try_rpc,int index,CallContainer<GrpcCall> * container,Tensor * response_t,Tensor * status_code_t,Tensor * status_message_t)167 void GrpcRPCFactory::CreateCall(const Tensor& request_t, const bool try_rpc,
168                                 int index, CallContainer<GrpcCall>* container,
169                                 Tensor* response_t, Tensor* status_code_t,
170                                 Tensor* status_message_t) {
171   auto request = request_t.flat<string>();
172   auto get_request_ptr = [&request](int64 ix) -> const string* {
173     return (request.size() > 1) ? &(request(ix)) : &(request(0));
174   };
175   auto response = response_t->flat<string>();
176   int32* status_code_ptr = nullptr;
177   string* status_message_ptr = nullptr;
178   if (try_rpc) {
179     status_code_ptr = status_code_t->flat<int32>().data();
180     status_message_ptr = status_message_t->flat<string>().data();
181   }
182   container->RegisterCall(container, index, try_rpc, get_request_ptr(index),
183                           &response(index),
184                           (try_rpc) ? &status_code_ptr[index] : nullptr,
185                           (try_rpc) ? &status_message_ptr[index] : nullptr);
186 }
187 
StartCall(const Tensor & address_t,const Tensor & method_t,GrpcCall * call)188 void GrpcRPCFactory::StartCall(const Tensor& address_t, const Tensor& method_t,
189                                GrpcCall* call) {
190   auto address = address_t.flat<string>();
191   auto method = method_t.flat<string>();
192   // Stubs are maintained by the GrpcRPCFactory class and will be
193   // deleted when the class is destroyed.
194   ::grpc::GenericStub* singleton_stub = nullptr;
195   if (address.size() == 1) {
196     singleton_stub = GetOrCreateStubForAddress(address(0));
197   }
198   auto get_stub = [&address, this,
199                    singleton_stub](int64 ix) -> ::grpc::GenericStub* {
200     return (address.size() > 1) ? GetOrCreateStubForAddress(address(ix))
201                                 : singleton_stub;
202   };
203   auto get_method_ptr = [&method](int64 ix) -> const string* {
204     return (method.size() > 1) ? &(method(ix)) : &(method(0));
205   };
206 
207   int index = call->index();
208   // This object will delete itself when done.
209   new RPCState<string>(
210       get_stub(index), &completion_queue_, *get_method_ptr(index),
211       call->request(), call->response(),
212       /*done=*/[call](const Status& s) { call->Done(s); }, call->call_opts(),
213       nullptr /*threadpool*/, fail_fast_, timeout_in_ms_, 0 /* max_retries */);
214 }
215 
216 }  // namespace tensorflow
217