/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"

#include "grpcpp/generic/generic_stub.h"
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/protobuf/eager_service.pb.h"
#include "tensorflow/core/util/env_var.h"

namespace tensorflow {
namespace eager {
namespace {

/*
 * Setting the environment variable "TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE"
 * to true (the default) turns on asynchronous execution of remote ops. This
 * means that when executing an op on a remote worker, the client no longer
 * blocks waiting for the response. Take the following code as an example:
 *
 * with tf.device('worker:0'):
 *   a = tf.matmul(...)
 *   b = tf.matmul(...)
 * logging.info('Requests sent')     # Probably not executed yet
 * logging.info('b: %s', b.numpy())  # Blocks until 'b' is finished.
 *
 * The streaming RPC also preserves ordering, so 'a' is guaranteed to be
 * executed before 'b' on 'worker:0'.
 *
 * When this feature is turned on, you should explicitly wait for some result
 * from the remote workers at the end of your Python program. Otherwise, the
 * client may shut down the remote workers without waiting for all pending ops
 * to finish.
 *
 * TODO(fishx): When exiting the client, make sure all pending ops on remote
 * workers are finished.
 */
bool EnableStreaming() {
  bool result;
  TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE",
                                 true, &result));
  return result;
}

// Ref-counted thread to handle callbacks for completed requests on a gRPC
// completion queue. The thread might be shared by multiple eager clients, and
// each one of them should hold a reference count to ensure that the thread
// outlives the clients.
// To ensure that every tag in the completion queue is processed, this thread
// also holds a reference to itself and always waits until the reference count
// is one before exiting.
class GrpcEagerClientThread : public core::RefCounted {
 public:
  GrpcEagerClientThread() {
    // Hold a reference to ensure every completion tag gets processed.
    Ref();
    thread_.reset(Env::Default()->StartThread(
        ThreadOptions(), "eager_client_thread", [this]() {
          void* tag;
          bool ok;
          while (completion_queue_.Next(&tag, &ok)) {
            VLOG(4) << "GrpcEagerClientThread got next tag";
            GrpcClientCQTag* callback_tag = static_cast<GrpcClientCQTag*>(tag);
            callback_tag->OnCompleted(ok);
            VLOG(4) << "GrpcEagerClientThread blocking for next tag";
            if (RefCountIsOne()) {
              break;
            }
          }
          VLOG(4) << "GrpcEagerClientThread exiting";
          completion_queue_.Shutdown();
          // `this` holds the final reference, so we cannot Unref directly
          // here. Instead, schedule a separate thread to clean it up.
          Env::Default()->SchedClosure([this]() { this->Unref(); });
        }));
  }

  ~GrpcEagerClientThread() override {}

  ::grpc::CompletionQueue* completion_queue() { return &completion_queue_; }

 private:
  ::grpc::CompletionQueue completion_queue_;
  std::unique_ptr<Thread> thread_;
};

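// EagerClient implementation that issues eager service RPCs over a shared
// gRPC channel, using the completion queue owned by a GrpcEagerClientThread.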
class GrpcEagerClient : public EagerClient {
 public:
  GrpcEagerClient(const tensorflow::SharedGrpcChannelPtr& channel,
                  GrpcEagerClientThread* thread, const string& target)
      : stub_(channel), thread_(thread), target_(target) {
    // Hold a reference to make sure the corresponding EagerClientThread
    // outlives the client.
    thread_->Ref();
    cq_ = thread->completion_queue();
  }
  ~GrpcEagerClient() override { thread_->Unref(); }

  bool allow_multiple_pending_requests() const override {
    return EnableStreaming();
  }

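// Each CLIENT_METHOD expansion defines an asynchronous client method that
// issues a unary RPC on the shared completion queue and invokes the wrapped
// callback when the call completes.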
#define CLIENT_METHOD(method)                                             \
  void method##Async(const method##Request* request,                      \
                     method##Response* response, StatusCallback done)     \
      override {                                                          \
    StatusCallback done_wrapped = callback_wrapper(std::move(done));      \
    new RPCState<protobuf::Message>(                                      \
        &stub_, cq_, "/tensorflow.eager.EagerService/" #method, *request, \
        response, std::move(done_wrapped), /*call_opts=*/nullptr,         \
        /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true,    \
        &target_);                                                        \
  }

  CLIENT_METHOD(CreateContext);
  CLIENT_METHOD(UpdateContext);
  CLIENT_METHOD(WaitQueueDone);
  CLIENT_METHOD(KeepAlive);

#undef CLIENT_METHOD

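// Same as CLIENT_METHOD, but the generated method also takes a CallOptions*
// so callers can cancel the outstanding RPC.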
#define CLIENT_CANCELABLE_METHOD(method)                                      \
  void method##Async(CallOptions* call_opts, const method##Request* request,  \
                     method##Response* response, StatusCallback done)         \
      override {                                                              \
    StatusCallback done_wrapped = callback_wrapper(std::move(done));          \
    new RPCState<protobuf::Message>(                                          \
        &stub_, cq_, "/tensorflow.eager.EagerService/" #method, *request,     \
        response, std::move(done_wrapped), call_opts, /*threadpool=*/nullptr, \
        /*max_retries=*/0, /*fail_fast=*/true, &target_);                     \
  }

  CLIENT_CANCELABLE_METHOD(Enqueue);
  CLIENT_CANCELABLE_METHOD(RunComponentFunction);

#undef CLIENT_CANCELABLE_METHOD

  void CloseContextAsync(const CloseContextRequest* request,
                         CloseContextResponse* response,
                         StatusCallback done) override {
    StatusCallback done_wrapped = callback_wrapper(std::move(done));
    new RPCState<protobuf::Message>(
        &stub_, cq_, "/tensorflow.eager.EagerService/CloseContext", *request,
        response, std::move(done_wrapped), /*call_opts=*/nullptr,
        /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true,
        &target_);

    VLOG(1) << "Sending RPC to close remote eager context "
            << request->DebugString();

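    // Closing a context also tears down any streaming enqueue dispatcher that
    // was created for it, so no further requests are sent on that stream.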
    mutex_lock l(mu_);
    const auto& it = enqueue_dispatchers_.find(request->context_id());
    if (it != enqueue_dispatchers_.end()) {
      it->second.CancelCall();
      enqueue_dispatchers_.erase(it);
    } else if (EnableStreaming()) {
      LOG(ERROR) << "Remote EagerContext with id " << request->context_id()
                 << " does not seem to exist.";
    }
  }

  void StreamingEnqueueAsync(CallOptions* call_opts,
                             const EnqueueRequest* request,
                             EnqueueResponse* response,
                             StatusCallback done) override {
    StatusCallback done_wrapped = callback_wrapper(std::move(done));
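    // With streaming enabled, all enqueue requests for the same context id are
    // sent over a single streaming RPC dispatcher, which preserves the order
    // in which requests are issued.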
    if (EnableStreaming()) {
      mutex_lock l(mu_);
      auto it = enqueue_dispatchers_.find(request->context_id());
      if (it == enqueue_dispatchers_.end()) {
        auto it_and_bool = enqueue_dispatchers_.emplace(
            std::piecewise_construct,
            std::forward_as_tuple(request->context_id()),
            std::forward_as_tuple(
                &stub_, cq_,
                "/tensorflow.eager.EagerService/StreamingEnqueue"));
        it = it_and_bool.first;
      }
      // TODO(haoyuzhang): Consider supporting cancellation for streaming RPC?
      it->second.SendNextRequest(*request, response, std::move(done_wrapped));
    } else {
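      // Without streaming, fall back to a regular EnqueueAsync call and block
      // until it completes before invoking the wrapped callback.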
      Notification n;
      Status status;
      EnqueueAsync(call_opts, request, response,
                   [&n, &status](const Status& s) {
                     status.Update(s);
                     n.Notify();
                   });
      n.WaitForNotification();
      done_wrapped(status);
    }
  }

 private:
  ::grpc::GenericStub stub_;
  const GrpcEagerClientThread* thread_;
  const string target_;

  ::grpc::CompletionQueue* cq_;

  mutable mutex mu_;

  std::unordered_map<uint64, StreamingRPCDispatcher<EnqueueResponse>>
      enqueue_dispatchers_ TF_GUARDED_BY(mu_);

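  // Wraps `done` so that the client holds a reference to itself for the
  // lifetime of the pending RPC; the reference is released after the callback
  // runs.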
  StatusCallback callback_wrapper(StatusCallback done) {
    Ref();
    return [this, done = std::move(done)](const Status& status) {
      done(status);
      this->Unref();
    };
  }
};

class GrpcEagerClientCache : public EagerClientCache {
 public:
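  // Clients are spread across a small fixed pool of completion-queue polling
  // threads (see AssignClientToThread); each thread may be shared by many
  // clients.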
  explicit GrpcEagerClientCache(
      std::shared_ptr<tensorflow::GrpcChannelCache> cache)
      : next_round_robin_assignment_(0), cache_(cache), threads_(4) {
    for (int i = 0, end = threads_.size(); i < end; i++) {
      threads_[i].reset(new GrpcEagerClientThread());
    }
  }

  ~GrpcEagerClientCache() override { threads_.clear(); }

  Status GetClient(const string& target,
                   core::RefCountPtr<EagerClient>* client) override {
    mutex_lock l(clients_mu_);
    auto it = clients_.find(target);
    if (it == clients_.end()) {
      tensorflow::SharedGrpcChannelPtr shared =
          cache_->FindWorkerChannel(target);
      if (shared == nullptr) {
        return errors::InvalidArgument("Client for target ", target,
                                       " not found.");
      }
      int assigned_index = AssignClientToThread(target);
      GrpcEagerClientThread* thread = threads_[assigned_index].get();
      core::RefCountPtr<EagerClient> worker(
          new GrpcEagerClient(shared, thread, target));
      it = clients_.emplace(target, std::move(worker)).first;
    }

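    // Hand the caller its own reference; the cache keeps the one stored in
    // `clients_`.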
    it->second->Ref();
    client->reset(it->second.get());
    return Status::OK();
  }

 private:
  mutex assignment_mu_;
  std::unordered_map<std::string, size_t> target_assignments_
      TF_GUARDED_BY(assignment_mu_);
  size_t next_round_robin_assignment_ TF_GUARDED_BY(assignment_mu_);

  size_t AssignClientToThread(const string& target) {
    // Round-robin target assignment, but always keep the same target on the
    // same polling thread, as this is important for gRPC performance.
    mutex_lock lock(assignment_mu_);
    auto it = target_assignments_.find(target);
    if (it == target_assignments_.end()) {
      it = target_assignments_
               .insert(std::make_pair(
                   target, (next_round_robin_assignment_++) % threads_.size()))
               .first;
    }
    return it->second;
  }

  std::shared_ptr<tensorflow::GrpcChannelCache> cache_;
  mutable mutex clients_mu_;
  std::unordered_map<string, core::RefCountPtr<EagerClient>> clients_
      TF_GUARDED_BY(clients_mu_);
  std::vector<core::RefCountPtr<GrpcEagerClientThread>> threads_;
};

}  // namespace

EagerClientCache* NewGrpcEagerClientCache(
    std::shared_ptr<tensorflow::GrpcChannelCache> channel) {
  return new GrpcEagerClientCache(channel);
}

}  // namespace eager
}  // namespace tensorflow