/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"

#include "grpcpp/generic/generic_stub.h"
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/protobuf/eager_service.pb.h"
#include "tensorflow/core/util/env_var.h"

namespace tensorflow {
namespace eager {
namespace {

/*
 * Setting the environment variable "TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE"
 * to true turns on asynchronous execution of remote ops: when executing an op
 * on a remote worker, the client no longer blocks waiting for the response.
 * Take the following code as an example:
 *
 * with tf.device('worker:0'):
 *   a = tf.matmul(...)
 *   b = tf.matmul(...)
 * logging.info('Requests sent')    # Probably not executed yet
 * logging.info('b: %s', b.numpy()) # Blocks until 'b' has finished.
 *
 * Streaming RPC preserves order as well, so 'a' must be executed before 'b'
 * on 'worker:0'.
 *
 * When this feature is turned on, you should explicitly wait for some result
 * from the remote workers at the end of your Python program. Otherwise, the
 * client may shut down the remote workers without waiting for all pending ops
 * to finish.
 *
 * TODO(fishx): When exiting the client, make sure all pending ops on remote
 * workers are finished.
 */
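// For illustration only: streaming enqueue can be disabled for a single run by
// exporting the variable before launching the Python program (a minimal sketch
// assuming a POSIX-style shell; the variable defaults to true below):
//
//   export TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE=false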
bool EnableStreaming() {
  bool result;
  TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE",
                                 true, &result));
  return result;
}

// Ref-counted thread to handle callbacks for completed requests on a gRPC
// completion queue. The thread might be shared by multiple eager clients, and
// each one of them should hold a reference count to ensure that the thread
// outlives the clients.
// To ensure that every tag in the completion queue is processed, this thread
// also holds a reference to itself and always waits until the ref count is
// one before exiting.
class GrpcEagerClientThread : public core::RefCounted {
 public:
  GrpcEagerClientThread() {
    // Hold a reference to ensure every completion tag gets processed.
    Ref();
    thread_.reset(Env::Default()->StartThread(
        ThreadOptions(), "eager_client_thread", [this]() {
          void* tag;
          bool ok;
          while (completion_queue_.Next(&tag, &ok)) {
            VLOG(4) << "GrpcEagerClientThread got next tag";
            GrpcClientCQTag* callback_tag = static_cast<GrpcClientCQTag*>(tag);
            callback_tag->OnCompleted(ok);
            VLOG(4) << "GrpcEagerClientThread blocking for next tag";
            if (RefCountIsOne()) {
              break;
            }
          }
          VLOG(4) << "GrpcEagerClientThread exiting";
          completion_queue_.Shutdown();
          // `this` holds the final reference, so we cannot Unref directly
          // here. Instead, schedule a separate thread to clean it up.
          Env::Default()->SchedClosure([this]() { this->Unref(); });
        }));
  }

  ~GrpcEagerClientThread() override {}

  ::grpc::CompletionQueue* completion_queue() { return &completion_queue_; }

 private:
  ::grpc::CompletionQueue completion_queue_;
  std::unique_ptr<Thread> thread_;
};

class GrpcEagerClient : public EagerClient {
 public:
  GrpcEagerClient(const tensorflow::SharedGrpcChannelPtr& channel,
                  GrpcEagerClientThread* thread, const string& target)
      : stub_(channel), thread_(thread), target_(target) {
    // Hold a reference to make sure the corresponding EagerClientThread
    // outlives the client.
    thread_->Ref();
    cq_ = thread->completion_queue();
  }
  ~GrpcEagerClient() override { thread_->Unref(); }

  bool allow_multiple_pending_requests() const override {
    return EnableStreaming();
  }

#define CLIENT_METHOD(method)                                             \
  void method##Async(const method##Request* request,                      \
                     method##Response* response, StatusCallback done)     \
      override {                                                          \
    StatusCallback done_wrapped = callback_wrapper(std::move(done));      \
    new RPCState<protobuf::Message>(                                      \
        &stub_, cq_, "/tensorflow.eager.EagerService/" #method, *request, \
        response, std::move(done_wrapped), /*call_opts=*/nullptr,         \
        /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true,    \
        &target_);                                                        \
  }

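  // For reference, each CLIENT_METHOD(Foo) invocation below expands to an
  // override of the form (shown with a placeholder name for illustration):
  //
  //   void FooAsync(const FooRequest* request, FooResponse* response,
  //                 StatusCallback done) override;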
  CLIENT_METHOD(CreateContext);
  CLIENT_METHOD(UpdateContext);
  CLIENT_METHOD(WaitQueueDone);
  CLIENT_METHOD(KeepAlive);

#undef CLIENT_METHOD

#define CLIENT_CANCELABLE_METHOD(method)                                      \
  void method##Async(CallOptions* call_opts, const method##Request* request,  \
                     method##Response* response, StatusCallback done)         \
      override {                                                              \
    StatusCallback done_wrapped = callback_wrapper(std::move(done));          \
    new RPCState<protobuf::Message>(                                          \
        &stub_, cq_, "/tensorflow.eager.EagerService/" #method, *request,     \
        response, std::move(done_wrapped), call_opts, /*threadpool=*/nullptr, \
        /*max_retries=*/0, /*fail_fast=*/true, &target_);                     \
  }

  CLIENT_CANCELABLE_METHOD(Enqueue);
  CLIENT_CANCELABLE_METHOD(RunComponentFunction);

#undef CLIENT_CANCELABLE_METHOD

  void CloseContextAsync(const CloseContextRequest* request,
                         CloseContextResponse* response,
                         StatusCallback done) override {
    StatusCallback done_wrapped = callback_wrapper(std::move(done));
    new RPCState<protobuf::Message>(
        &stub_, cq_, "/tensorflow.eager.EagerService/CloseContext", *request,
        response, std::move(done_wrapped), /*call_opts=*/nullptr,
        /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true,
        &target_);

    VLOG(1) << "Sending RPC to close remote eager context "
            << request->DebugString();

    mutex_lock l(mu_);
    const auto& it = enqueue_dispatchers_.find(request->context_id());
    if (it != enqueue_dispatchers_.end()) {
      it->second.CancelCall();
      enqueue_dispatchers_.erase(it);
    } else if (EnableStreaming()) {
      LOG(ERROR) << "Remote EagerContext with id " << request->context_id()
                 << " does not seem to exist.";
    }
  }

  void StreamingEnqueueAsync(CallOptions* call_opts,
                             const EnqueueRequest* request,
                             EnqueueResponse* response,
                             StatusCallback done) override {
    StatusCallback done_wrapped = callback_wrapper(std::move(done));
    if (EnableStreaming()) {
      mutex_lock l(mu_);
      auto it = enqueue_dispatchers_.find(request->context_id());
      if (it == enqueue_dispatchers_.end()) {
        auto it_and_bool = enqueue_dispatchers_.emplace(
            std::piecewise_construct,
            std::forward_as_tuple(request->context_id()),
            std::forward_as_tuple(
                &stub_, cq_,
                "/tensorflow.eager.EagerService/StreamingEnqueue"));
        it = it_and_bool.first;
      }
      // TODO(haoyuzhang): Consider supporting cancellation for streaming RPC?
      it->second.SendNextRequest(*request, response, std::move(done_wrapped));
    } else {
      Notification n;
      Status status;
      EnqueueAsync(call_opts, request, response,
                   [&n, &status](const Status& s) {
                     status.Update(s);
                     n.Notify();
                   });
      n.WaitForNotification();
      done_wrapped(status);
    }
  }

 private:
  ::grpc::GenericStub stub_;
  const GrpcEagerClientThread* thread_;
  const string target_;

  ::grpc::CompletionQueue* cq_;

  mutable mutex mu_;

  std::unordered_map<uint64, StreamingRPCDispatcher<EnqueueResponse>>
      enqueue_dispatchers_ TF_GUARDED_BY(mu_);

  // Wraps a completion callback so that the client holds a reference to
  // itself for as long as the RPC is pending.
  StatusCallback callback_wrapper(StatusCallback done) {
    Ref();
    return [this, done = std::move(done)](const Status& status) {
      done(status);
      this->Unref();
    };
  }
};

class GrpcEagerClientCache : public EagerClientCache {
 public:
  explicit GrpcEagerClientCache(
      std::shared_ptr<tensorflow::GrpcChannelCache> cache)
      : next_round_robin_assignment_(0), cache_(cache), threads_(4) {
    for (int i = 0, end = threads_.size(); i < end; i++) {
      threads_[i].reset(new GrpcEagerClientThread());
    }
  }

  ~GrpcEagerClientCache() override { threads_.clear(); }

  Status GetClient(const string& target,
                   core::RefCountPtr<EagerClient>* client) override {
    mutex_lock l(clients_mu_);
    auto it = clients_.find(target);
    if (it == clients_.end()) {
      tensorflow::SharedGrpcChannelPtr shared =
          cache_->FindWorkerChannel(target);
      if (shared == nullptr) {
        return errors::InvalidArgument("Client for target ", target,
                                       " not found.");
      }
      int assigned_index = AssignClientToThread(target);
      GrpcEagerClientThread* thread = threads_[assigned_index].get();
      core::RefCountPtr<EagerClient> worker(
          new GrpcEagerClient(shared, thread, target));
      it = clients_.emplace(target, std::move(worker)).first;
    }

    it->second->Ref();
    client->reset(it->second.get());
    return Status::OK();
  }

 private:
  mutex assignment_mu_;
  std::unordered_map<std::string, size_t> target_assignments_
      TF_GUARDED_BY(assignment_mu_);
  size_t next_round_robin_assignment_ TF_GUARDED_BY(assignment_mu_);

  size_t AssignClientToThread(const string& target) {
    // Round-robin target assignment, but always keep the same target on the
    // same polling thread, as this is important for gRPC performance.
    mutex_lock lock(assignment_mu_);
    auto it = target_assignments_.find(target);
    if (it == target_assignments_.end()) {
      it = target_assignments_
               .insert(std::make_pair(
                   target, (next_round_robin_assignment_++) % threads_.size()))
               .first;
    }
    return it->second;
  }

  std::shared_ptr<tensorflow::GrpcChannelCache> cache_;
  mutable mutex clients_mu_;
  std::unordered_map<string, core::RefCountPtr<EagerClient>> clients_
      TF_GUARDED_BY(clients_mu_);
  std::vector<core::RefCountPtr<GrpcEagerClientThread>> threads_;
};

}  // namespace

EagerClientCache* NewGrpcEagerClientCache(
    std::shared_ptr<tensorflow::GrpcChannelCache> channel) {
  return new GrpcEagerClientCache(channel);
}

}  // namespace eager
}  // namespace tensorflow