1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17 #include <string>
18 #include <vector>
19
20 #include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
21 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
22 #include "tensorflow/core/framework/op_kernel.h"
23 #include "tensorflow/core/framework/tensor_types.h"
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/core/refcount.h"
26 #include "tensorflow/core/util/rpc/call_container.h"
27 #include "tensorflow/core/util/rpc/rpc_factory.h"
28
29 #include "tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h"
30
31 namespace tensorflow {
32
33 namespace internal {
34 class GrpcCall {
35 public:
GrpcCall(CallContainer<GrpcCall> * container,int index,bool try_rpc,const string * request_msg,string * response_msg,int32 * status_code,string * status_message)36 explicit GrpcCall(CallContainer<GrpcCall>* container, int index, bool try_rpc,
37 const string* request_msg, string* response_msg,
38 int32* status_code, string* status_message)
39 : container_(container),
40 index_(index),
41 try_rpc_(try_rpc),
42 request_msg_(request_msg),
43 response_msg_(response_msg),
44 status_code_(status_code),
45 status_message_(status_message) {}
46
StartCancel()47 void StartCancel() { call_opts_.StartCancel(); }
48
Done(const Status & s)49 void Done(const Status& s) {
50 DCHECK(container_ != nullptr);
51 if (!s.ok() && try_rpc_) {
52 DCHECK(status_code_ != nullptr);
53 DCHECK(status_message_ != nullptr);
54 *status_code_ = s.code();
55 *status_message_ = s.error_message();
56 }
57 container_->Done(s, index_);
58 }
59
call_opts()60 CallOptions* call_opts() { return &call_opts_; }
index()61 int index() { return index_; }
request() const62 const string& request() const { return *request_msg_; }
response() const63 string* response() const { return response_msg_; }
64
65 private:
66 CallContainer<GrpcCall>* const container_;
67 const int index_;
68 bool try_rpc_;
69 CallOptions call_opts_;
70 const string* request_msg_;
71 string* response_msg_;
72 int* status_code_;
73 string* status_message_;
74 };
75
76 } // namespace internal
77
78 using internal::GrpcCall;
79
GrpcRPCFactory(OpKernelConstruction * ctx,bool fail_fast,int64 timeout_in_ms)80 GrpcRPCFactory::GrpcRPCFactory(OpKernelConstruction* ctx, bool fail_fast,
81 int64 timeout_in_ms)
82 : RPCFactory(), fail_fast_(fail_fast), timeout_in_ms_(timeout_in_ms) {
83 // TODO(ebrevdo): Investigate possible performance improvements by
84 // replacing this thread with a threadpool.
85 polling_thread_ =
86 ctx->env()->StartThread(ThreadOptions(), "rpc_op_grpc_factory", [this]() {
87 void* tag;
88 bool ok;
89 while (completion_queue_.Next(&tag, &ok)) {
90 GrpcClientCQTag* callback_tag = static_cast<GrpcClientCQTag*>(tag);
91 callback_tag->OnCompleted(ok);
92 }
93 });
94 }
95
~GrpcRPCFactory()96 GrpcRPCFactory::~GrpcRPCFactory() {
97 // The amount of time we wait depends on several parameters, including:
98 // - the value of the fail_fast attribute.
99 // - the timeout option of the rpc call in the proto declaration.
100 // - the network roundtrip time and service's execution time.
101 //
102 // If a connection is made but the service doesn't ever respond, and
103 // there is no timeout option set for this rpc call, then it is
104 // possible the RPC request will wait forever.
105 //
106 completion_queue_.Shutdown();
107 delete polling_thread_;
108 }
109
Call(OpKernelContext * ctx,int64 num_elements,const Tensor & address_t,const Tensor & method_t,const Tensor & request_t,const bool try_rpc,Tensor * response_t,Tensor * status_code_t,Tensor * status_message_t,AsyncOpKernel::DoneCallback done)110 void GrpcRPCFactory::Call(OpKernelContext* ctx, int64 num_elements,
111 const Tensor& address_t, const Tensor& method_t,
112 const Tensor& request_t, const bool try_rpc,
113 Tensor* response_t, Tensor* status_code_t,
114 Tensor* status_message_t,
115 AsyncOpKernel::DoneCallback done) {
116 if (try_rpc) {
117 // In this case status_code will never be set in the response,
118 // so we just set it to OK.
119 DCHECK(status_code_t != nullptr);
120 status_code_t->flat<int32>().setConstant(
121 static_cast<int>(errors::Code::OK));
122 }
123
124 CallContainer<GrpcCall>::CreateCallFn create_call_fn =
125 [this, &request_t, &try_rpc, response_t, status_code_t, status_message_t](
126 CallContainer<GrpcCall>* container, int index) {
127 CreateCall(request_t, try_rpc, index, container, response_t,
128 status_code_t, status_message_t);
129 };
130
131 CallContainer<GrpcCall>::StartCallFn start_call_fn =
132 [this, &address_t, &method_t](GrpcCall* call) {
133 StartCall(address_t, method_t, call);
134 };
135
136 // This object will delete itself when done.
137 new CallContainer<GrpcCall>(ctx, num_elements, fail_fast_, try_rpc,
138 std::move(done), std::move(create_call_fn),
139 std::move(start_call_fn));
140 }
141
GetOrCreateStubForAddress(const string & address)142 ::grpc::GenericStub* GrpcRPCFactory::GetOrCreateStubForAddress(
143 const string& address) {
144 mutex_lock lock(mu_);
145
146 auto stub = stubs_.find(address);
147 if (stub != stubs_.end()) return stub->second.get();
148
149 ChannelPtr channel = CreateChannelForAddress(address);
150 auto* created = new ::grpc::GenericStub(channel);
151 stubs_[address].reset(created);
152 return created;
153 }
154
CreateChannelForAddress(const string & address)155 GrpcRPCFactory::ChannelPtr GrpcRPCFactory::CreateChannelForAddress(
156 const string& address) {
157 ::grpc::ChannelArguments args;
158 args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, std::numeric_limits<int32>::max());
159
160 // Set a standard backoff timeout of 1s instead of the
161 // (sometimes default) 20s.
162 args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 1000);
163 return ::grpc::CreateCustomChannel(
164 /*target=*/address, ::grpc::InsecureChannelCredentials(), args);
165 }
166
CreateCall(const Tensor & request_t,const bool try_rpc,int index,CallContainer<GrpcCall> * container,Tensor * response_t,Tensor * status_code_t,Tensor * status_message_t)167 void GrpcRPCFactory::CreateCall(const Tensor& request_t, const bool try_rpc,
168 int index, CallContainer<GrpcCall>* container,
169 Tensor* response_t, Tensor* status_code_t,
170 Tensor* status_message_t) {
171 auto request = request_t.flat<string>();
172 auto get_request_ptr = [&request](int64 ix) -> const string* {
173 return (request.size() > 1) ? &(request(ix)) : &(request(0));
174 };
175 auto response = response_t->flat<string>();
176 int32* status_code_ptr = nullptr;
177 string* status_message_ptr = nullptr;
178 if (try_rpc) {
179 status_code_ptr = status_code_t->flat<int32>().data();
180 status_message_ptr = status_message_t->flat<string>().data();
181 }
182 container->RegisterCall(container, index, try_rpc, get_request_ptr(index),
183 &response(index),
184 (try_rpc) ? &status_code_ptr[index] : nullptr,
185 (try_rpc) ? &status_message_ptr[index] : nullptr);
186 }
187
StartCall(const Tensor & address_t,const Tensor & method_t,GrpcCall * call)188 void GrpcRPCFactory::StartCall(const Tensor& address_t, const Tensor& method_t,
189 GrpcCall* call) {
190 auto address = address_t.flat<string>();
191 auto method = method_t.flat<string>();
192 // Stubs are maintained by the GrpcRPCFactory class and will be
193 // deleted when the class is destroyed.
194 ::grpc::GenericStub* singleton_stub = nullptr;
195 if (address.size() == 1) {
196 singleton_stub = GetOrCreateStubForAddress(address(0));
197 }
198 auto get_stub = [&address, this,
199 singleton_stub](int64 ix) -> ::grpc::GenericStub* {
200 return (address.size() > 1) ? GetOrCreateStubForAddress(address(ix))
201 : singleton_stub;
202 };
203 auto get_method_ptr = [&method](int64 ix) -> const string* {
204 return (method.size() > 1) ? &(method(ix)) : &(method(0));
205 };
206
207 int index = call->index();
208 // This object will delete itself when done.
209 new RPCState<string>(
210 get_stub(index), &completion_queue_, *get_method_ptr(index),
211 call->request(), call->response(),
212 /*done=*/[call](const Status& s) { call->Done(s); }, call->call_opts(),
213 nullptr /*threadpool*/, fail_fast_, timeout_in_ms_, 0 /* max_retries */);
214 }
215
216 } // namespace tensorflow
217