1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
17
18 #include "grpcpp/generic/generic_stub.h"
19 #include "tensorflow/core/distributed_runtime/call_options.h"
20 #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h"
21 #include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
22 #include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
23 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
24 #include "tensorflow/core/lib/core/refcount.h"
25 #include "tensorflow/core/lib/core/status.h"
26 #include "tensorflow/core/platform/env.h"
27 #include "tensorflow/core/protobuf/eager_service.pb.h"
28 #include "tensorflow/core/util/env_var.h"
29
30 namespace tensorflow {
31 namespace eager {
32 namespace {
33
/*
 * Setting environment variable "TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE" to
 * true will turn on asynchronous execution of remote ops. This means that when
 * executing an op on a remote worker, the client no longer blocks waiting
 * for the response. Using the following code as an example:
 *
 *   with tf.device('worker:0'):
 *     a = tf.matmul(...)
 *     b = tf.matmul(...)
 *   logging.info('Requests sent')    # Probably not executed yet
 *   logging.info('b: %s', b.numpy()) # Block until 'b' finished.
 *
 * Streaming RPC preserves order as well, so 'a' must be executed before
 * 'b' on 'worker:0'.
 *
 * When turning on this feature, you should explicitly wait for some result
 * from remote workers at the end of your Python program. Otherwise, the client
 * may shut down remote workers without waiting for all pending ops.
 *
 * TODO(fishx): When exiting client, make sure all pending ops on remote workers
 * are finished.
 *
 * TODO(b/139210648): Move this comment to eager/execute.py when this feature is
 * on by default.
 */
EnableStreaming()59 bool EnableStreaming() {
60 bool result;
61 TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE",
62 true, &result));
63 return result;
64 }
65
66 // Ref-counted thread to handle callbacks for completed requests a GRPC
67 // completion queue. The thread might be shared by multiple eager clients, and
68 // each one of them should hold a reference count to ensure that the thread
69 // outlives the clients.
70 // To ensure that every tag in completion queue is processed, this thread also
71 // holds a reference to itself and always wait until ref count is one to exit.
72 class GrpcEagerClientThread : public core::RefCounted {
73 public:
GrpcEagerClientThread()74 GrpcEagerClientThread() {
75 // Hold a reference to ensure every completion tag gets processed.
76 Ref();
77 thread_.reset(Env::Default()->StartThread(
78 ThreadOptions(), "eager_client_thread", [this]() {
79 void* tag;
80 bool ok;
81 while (completion_queue_.Next(&tag, &ok)) {
82 VLOG(4) << "GrpcEagerClientThread got next tag";
83 GrpcClientCQTag* callback_tag = static_cast<GrpcClientCQTag*>(tag);
84 callback_tag->OnCompleted(ok);
85 VLOG(4) << "GrpcEagerClientThread blocking for next tag";
86 if (RefCountIsOne()) {
87 break;
88 }
89 }
90 VLOG(4) << "GrpcEagerClientThread exiting";
91 completion_queue_.Shutdown();
92 // `this` holds the final reference so cannot directly Unref here.
93 // Instead, schedule a separate thread to clean it up.
94 Env::Default()->SchedClosure([this]() { this->Unref(); });
95 }));
96 }
97
~GrpcEagerClientThread()98 ~GrpcEagerClientThread() override {}
99
completion_queue()100 ::grpc::CompletionQueue* completion_queue() { return &completion_queue_; }
101
102 private:
103 ::grpc::CompletionQueue completion_queue_;
104 std::unique_ptr<Thread> thread_;
105 };
106
107 class GrpcEagerClient : public EagerClient {
108 public:
GrpcEagerClient(const tensorflow::SharedGrpcChannelPtr & channel,GrpcEagerClientThread * thread,const string & target)109 GrpcEagerClient(const tensorflow::SharedGrpcChannelPtr& channel,
110 GrpcEagerClientThread* thread, const string& target)
111 : stub_(channel), thread_(thread), target_(target) {
112 // Hold a reference to make sure the corresponding EagerClientThread
113 // outlives the client.
114 thread_->Ref();
115 cq_ = thread->completion_queue();
116 }
~GrpcEagerClient()117 ~GrpcEagerClient() override { thread_->Unref(); }
118
allow_multiple_pending_requests() const119 bool allow_multiple_pending_requests() const override {
120 return EnableStreaming();
121 }
122
123 #define CLIENT_METHOD(method) \
124 void method##Async(const method##Request* request, \
125 method##Response* response, StatusCallback done) \
126 override { \
127 StatusCallback done_wrapped = callback_wrapper(std::move(done)); \
128 new RPCState<protobuf::Message>( \
129 &stub_, cq_, "/tensorflow.eager.EagerService/" #method, *request, \
130 response, std::move(done_wrapped), /*call_opts=*/nullptr, \
131 /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true, \
132 &target_); \
133 }
134
135 CLIENT_METHOD(CreateContext);
136 CLIENT_METHOD(UpdateContext);
137 CLIENT_METHOD(WaitQueueDone);
138 CLIENT_METHOD(KeepAlive);
139
140 #undef CLIENT_METHOD
141
142 #define CLIENT_CANCELABLE_METHOD(method) \
143 void method##Async(CallOptions* call_opts, const method##Request* request, \
144 method##Response* response, StatusCallback done) \
145 override { \
146 StatusCallback done_wrapped = callback_wrapper(std::move(done)); \
147 new RPCState<protobuf::Message>( \
148 &stub_, cq_, "/tensorflow.eager.EagerService/" #method, *request, \
149 response, std::move(done_wrapped), call_opts, /*threadpool=*/nullptr, \
150 /*max_retries=*/0, /*fail_fast=*/true, &target_); \
151 }
152
153 CLIENT_CANCELABLE_METHOD(Enqueue);
154 CLIENT_CANCELABLE_METHOD(RunComponentFunction);
155
156 #undef CLIENT_CANCELABLE_METHOD
157
CloseContextAsync(const CloseContextRequest * request,CloseContextResponse * response,StatusCallback done)158 void CloseContextAsync(const CloseContextRequest* request,
159 CloseContextResponse* response,
160 StatusCallback done) override {
161 StatusCallback done_wrapped = callback_wrapper(std::move(done));
162 new RPCState<protobuf::Message>(
163 &stub_, cq_, "/tensorflow.eager.EagerService/CloseContext", *request,
164 response, std::move(done_wrapped), /*call_opts=*/nullptr,
165 /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true,
166 &target_);
167
168 VLOG(1) << "Sending RPC to close remote eager context "
169 << request->DebugString();
170
171 mutex_lock l(mu_);
172 const auto& it = enqueue_dispatchers_.find(request->context_id());
173 if (it != enqueue_dispatchers_.end()) {
174 it->second.CancelCall();
175 enqueue_dispatchers_.erase(it);
176 } else if (EnableStreaming()) {
177 LOG(ERROR) << "Remote EagerContext with id " << request->context_id()
178 << " does not seem to exist.";
179 }
180 }
181
StreamingEnqueueAsync(CallOptions * call_opts,const EnqueueRequest * request,EnqueueResponse * response,StatusCallback done)182 void StreamingEnqueueAsync(CallOptions* call_opts,
183 const EnqueueRequest* request,
184 EnqueueResponse* response,
185 StatusCallback done) override {
186 StatusCallback done_wrapped = callback_wrapper(std::move(done));
187 if (EnableStreaming()) {
188 mutex_lock l(mu_);
189 auto it = enqueue_dispatchers_.find(request->context_id());
190 if (it == enqueue_dispatchers_.end()) {
191 auto it_and_bool = enqueue_dispatchers_.emplace(
192 std::piecewise_construct,
193 std::forward_as_tuple(request->context_id()),
194 std::forward_as_tuple(
195 &stub_, cq_,
196 "/tensorflow.eager.EagerService/StreamingEnqueue"));
197 it = it_and_bool.first;
198 }
199 // TODO(haoyuzhang): Consider supporting cancellation for streaming RPC?
200 it->second.SendNextRequest(*request, response, std::move(done_wrapped));
201 } else {
202 Notification n;
203 Status status;
204 EnqueueAsync(call_opts, request, response,
205 [&n, &status](const Status& s) {
206 status.Update(s);
207 n.Notify();
208 });
209 n.WaitForNotification();
210 done_wrapped(status);
211 }
212 }
213
214 private:
215 ::grpc::GenericStub stub_;
216 const GrpcEagerClientThread* thread_;
217 const string target_;
218
219 ::grpc::CompletionQueue* cq_;
220
221 mutable mutex mu_;
222
223 std::unordered_map<uint64, StreamingRPCDispatcher<EnqueueResponse>>
224 enqueue_dispatchers_ TF_GUARDED_BY(mu_);
225
callback_wrapper(StatusCallback done)226 StatusCallback callback_wrapper(StatusCallback done) {
227 Ref();
228 return [this, done = std::move(done)](const Status& status) {
229 done(status);
230 this->Unref();
231 };
232 }
233 };
234
235 class GrpcEagerClientCache : public EagerClientCache {
236 public:
GrpcEagerClientCache(std::shared_ptr<tensorflow::GrpcChannelCache> cache)237 explicit GrpcEagerClientCache(
238 std::shared_ptr<tensorflow::GrpcChannelCache> cache)
239 : next_round_robin_assignment_(0), cache_(cache), threads_(4) {
240 for (int i = 0, end = threads_.size(); i < end; i++) {
241 threads_[i].reset(new GrpcEagerClientThread());
242 }
243 }
244
~GrpcEagerClientCache()245 ~GrpcEagerClientCache() override { threads_.clear(); }
246
GetClient(const string & target,core::RefCountPtr<EagerClient> * client)247 Status GetClient(const string& target,
248 core::RefCountPtr<EagerClient>* client) override {
249 mutex_lock l(clients_mu_);
250 auto it = clients_.find(target);
251 if (it == clients_.end()) {
252 tensorflow::SharedGrpcChannelPtr shared =
253 cache_->FindWorkerChannel(target);
254 if (shared == nullptr) {
255 return errors::InvalidArgument("Client for target ", target,
256 " not found.");
257 }
258 int assigned_index = AssignClientToThread(target);
259 GrpcEagerClientThread* thread = threads_[assigned_index].get();
260 core::RefCountPtr<EagerClient> worker(
261 new GrpcEagerClient(shared, thread, target));
262 it = clients_.emplace(target, std::move(worker)).first;
263 }
264
265 it->second->Ref();
266 client->reset(it->second.get());
267 return Status::OK();
268 }
269
270 private:
271 mutex assignment_mu_;
272 std::unordered_map<std::string, size_t> target_assignments_
273 TF_GUARDED_BY(assignment_mu_);
274 size_t next_round_robin_assignment_ TF_GUARDED_BY(assignment_mu_);
275
AssignClientToThread(const string & target)276 size_t AssignClientToThread(const string& target) {
277 // Round-robin target assignment, but keeps the same target on the same
278 // polling thread always, as this is important for gRPC performance
279 mutex_lock lock(assignment_mu_);
280 auto it = target_assignments_.find(target);
281 if (it == target_assignments_.end()) {
282 it = target_assignments_
283 .insert(std::make_pair(
284 target, (next_round_robin_assignment_++) % threads_.size()))
285 .first;
286 }
287 return it->second;
288 }
289
290 std::shared_ptr<tensorflow::GrpcChannelCache> cache_;
291 mutable mutex clients_mu_;
292 std::unordered_map<string, core::RefCountPtr<EagerClient>> clients_
293 TF_GUARDED_BY(clients_mu_);
294 std::vector<core::RefCountPtr<GrpcEagerClientThread>> threads_;
295 };
296
297 } // namespace
298
NewGrpcEagerClientCache(std::shared_ptr<tensorflow::GrpcChannelCache> channel)299 EagerClientCache* NewGrpcEagerClientCache(
300 std::shared_ptr<tensorflow::GrpcChannelCache> channel) {
301 return new GrpcEagerClientCache(channel);
302 }
303
304 } // namespace eager
305 } // namespace tensorflow
306