/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"

#include "grpcpp/generic/generic_stub.h"
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/protobuf/eager_service.pb.h"
#include "tensorflow/core/util/env_var.h"

namespace tensorflow {
namespace eager {
namespace {

/*
 * Setting the environment variable "TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE"
 * to true turns on asynchronous execution of remote ops: when executing an op
 * on a remote worker, the client no longer blocks waiting for the response.
 * Using the following code as an example:
 *
 * with tf.device('worker:0'):
 *   a = tf.matmul(...)
 *   b = tf.matmul(...)
 * logging.info('Requests sent')    # Probably not executed yet
 * logging.info('b: %s', b.numpy()) # Blocks until 'b' has finished.
 *
 * Streaming RPC preserves order as well, so 'a' must be executed before 'b'
 * on 'worker:0'.
 *
 * When turning on this feature, you should explicitly wait for some result
 * from remote workers at the end of your Python program. Otherwise, the
 * client may shut down remote workers without waiting for all pending ops.
 *
 * TODO(fishx): When exiting the client, make sure all pending ops on remote
 * workers are finished.
 *
 * TODO(b/139210648): Move this comment to eager/execute.py when this feature
 * is on by default.
 */
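// Returns whether streaming enqueue is enabled. The default is true, so
// streaming stays on unless the environment variable above is explicitly set
// to a false value; TF_CHECK_OK turns a malformed value into a fatal check
// failure.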
bool EnableStreaming() {
  bool result;
  TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE",
                                 true, &result));
  return result;
}

// Ref-counted thread to handle callbacks for completed requests on a gRPC
// completion queue. The thread might be shared by multiple eager clients, and
// each one of them should hold a reference count to ensure that the thread
// outlives the clients.
// To ensure that every tag in the completion queue is processed, this thread
// also holds a reference to itself and only exits once the ref count drops to
// one.
class GrpcEagerClientThread : public core::RefCounted {
 public:
  GrpcEagerClientThread() {
    // Hold a reference to ensure every completion tag gets processed.
    Ref();
    thread_.reset(Env::Default()->StartThread(
        ThreadOptions(), "eager_client_thread", [this]() {
          void* tag;
          bool ok;
          while (completion_queue_.Next(&tag, &ok)) {
            VLOG(4) << "GrpcEagerClientThread got next tag";
            GrpcClientCQTag* callback_tag = static_cast<GrpcClientCQTag*>(tag);
            callback_tag->OnCompleted(ok);
            VLOG(4) << "GrpcEagerClientThread blocking for next tag";
            if (RefCountIsOne()) {
              break;
            }
          }
          VLOG(4) << "GrpcEagerClientThread exiting";
          completion_queue_.Shutdown();
          // `this` holds the final reference so cannot directly Unref here.
          // Instead, schedule a separate thread to clean it up.
          Env::Default()->SchedClosure([this]() { this->Unref(); });
        }));
  }

  ~GrpcEagerClientThread() override {}

  ::grpc::CompletionQueue* completion_queue() { return &completion_queue_; }

 private:
  ::grpc::CompletionQueue completion_queue_;
  std::unique_ptr<Thread> thread_;
};

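// gRPC client for the eager service on a single remote target. All RPCs are
// issued through a generic stub on the completion queue owned by a shared
// GrpcEagerClientThread, which this client keeps alive via its ref count.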
class GrpcEagerClient : public EagerClient {
 public:
  GrpcEagerClient(const tensorflow::SharedGrpcChannelPtr& channel,
                  GrpcEagerClientThread* thread, const string& target)
      : stub_(channel), thread_(thread), target_(target) {
    // Hold a reference to make sure the corresponding GrpcEagerClientThread
    // outlives the client.
    thread_->Ref();
    cq_ = thread->completion_queue();
  }
  ~GrpcEagerClient() override { thread_->Unref(); }

  bool allow_multiple_pending_requests() const override {
    return EnableStreaming();
  }

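// Generates the simple unary async client methods below. Each call allocates
// an RPCState that issues the request on the shared completion queue and
// manages its own lifetime (hence the bare `new`).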
#define CLIENT_METHOD(method)                                             \
  void method##Async(const method##Request* request,                      \
                     method##Response* response, StatusCallback done)     \
      override {                                                          \
    StatusCallback done_wrapped = callback_wrapper(std::move(done));      \
    new RPCState<protobuf::Message>(                                      \
        &stub_, cq_, "/tensorflow.eager.EagerService/" #method, *request, \
        response, std::move(done_wrapped), /*call_opts=*/nullptr,         \
        /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true,    \
        &target_);                                                        \
  }

  CLIENT_METHOD(CreateContext);
  CLIENT_METHOD(UpdateContext);
  CLIENT_METHOD(WaitQueueDone);
  CLIENT_METHOD(KeepAlive);

#undef CLIENT_METHOD

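// Same as CLIENT_METHOD, but also threads a CallOptions* through to the
// RPCState so callers can cancel the request.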
#define CLIENT_CANCELABLE_METHOD(method)                                      \
  void method##Async(CallOptions* call_opts, const method##Request* request,  \
                     method##Response* response, StatusCallback done)         \
      override {                                                              \
    StatusCallback done_wrapped = callback_wrapper(std::move(done));          \
    new RPCState<protobuf::Message>(                                          \
        &stub_, cq_, "/tensorflow.eager.EagerService/" #method, *request,     \
        response, std::move(done_wrapped), call_opts, /*threadpool=*/nullptr, \
        /*max_retries=*/0, /*fail_fast=*/true, &target_);                     \
  }

  CLIENT_CANCELABLE_METHOD(Enqueue);
  CLIENT_CANCELABLE_METHOD(RunComponentFunction);

#undef CLIENT_CANCELABLE_METHOD

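  // Sends CloseContext and, when streaming enqueue is in use, also cancels
  // and drops the streaming dispatcher associated with the context being
  // closed.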
  void CloseContextAsync(const CloseContextRequest* request,
                         CloseContextResponse* response,
                         StatusCallback done) override {
    StatusCallback done_wrapped = callback_wrapper(std::move(done));
    new RPCState<protobuf::Message>(
        &stub_, cq_, "/tensorflow.eager.EagerService/CloseContext", *request,
        response, std::move(done_wrapped), /*call_opts=*/nullptr,
        /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true,
        &target_);

    VLOG(1) << "Sending RPC to close remote eager context "
            << request->DebugString();

    mutex_lock l(mu_);
    const auto& it = enqueue_dispatchers_.find(request->context_id());
    if (it != enqueue_dispatchers_.end()) {
      it->second.CancelCall();
      enqueue_dispatchers_.erase(it);
    } else if (EnableStreaming()) {
      LOG(ERROR) << "Remote EagerContext with id " << request->context_id()
                 << " does not seem to exist.";
    }
  }

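  // With streaming enabled, requests for a given context id are multiplexed
  // over a single per-context StreamingRPCDispatcher, which preserves their
  // order. Otherwise this falls back to a regular EnqueueAsync and blocks
  // until that call completes.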
  void StreamingEnqueueAsync(CallOptions* call_opts,
                             const EnqueueRequest* request,
                             EnqueueResponse* response,
                             StatusCallback done) override {
    StatusCallback done_wrapped = callback_wrapper(std::move(done));
    if (EnableStreaming()) {
      mutex_lock l(mu_);
      auto it = enqueue_dispatchers_.find(request->context_id());
      if (it == enqueue_dispatchers_.end()) {
        auto it_and_bool = enqueue_dispatchers_.emplace(
            std::piecewise_construct,
            std::forward_as_tuple(request->context_id()),
            std::forward_as_tuple(
                &stub_, cq_,
                "/tensorflow.eager.EagerService/StreamingEnqueue"));
        it = it_and_bool.first;
      }
      // TODO(haoyuzhang): Consider supporting cancellation for streaming RPC?
      it->second.SendNextRequest(*request, response, std::move(done_wrapped));
    } else {
      Notification n;
      Status status;
      EnqueueAsync(call_opts, request, response,
                   [&n, &status](const Status& s) {
                     status.Update(s);
                     n.Notify();
                   });
      n.WaitForNotification();
      done_wrapped(status);
    }
  }

 private:
  ::grpc::GenericStub stub_;
  const GrpcEagerClientThread* thread_;
  const string target_;

  ::grpc::CompletionQueue* cq_;

  mutable mutex mu_;

  std::unordered_map<uint64, StreamingRPCDispatcher<EnqueueResponse>>
      enqueue_dispatchers_ TF_GUARDED_BY(mu_);

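  // Wraps `done` so that the client holds an extra reference for the lifetime
  // of the RPC; the reference is released after the wrapped callback runs,
  // keeping the client alive until every pending callback has fired.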
  StatusCallback callback_wrapper(StatusCallback done) {
    Ref();
    return [this, done = std::move(done)](const Status& status) {
      done(status);
      this->Unref();
    };
  }
};

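// Cache of eager clients, keyed by target. It owns a small fixed pool of
// completion-queue polling threads and pins each target to one of them.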
class GrpcEagerClientCache : public EagerClientCache {
 public:
  explicit GrpcEagerClientCache(
      std::shared_ptr<tensorflow::GrpcChannelCache> cache)
      : next_round_robin_assignment_(0), cache_(cache), threads_(4) {
    for (int i = 0, end = threads_.size(); i < end; i++) {
      threads_[i].reset(new GrpcEagerClientThread());
    }
  }

  ~GrpcEagerClientCache() override { threads_.clear(); }

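  // Returns a ref-counted client for `target`, creating it on first use. The
  // caller receives its own reference via `client`.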
  Status GetClient(const string& target,
                   core::RefCountPtr<EagerClient>* client) override {
    mutex_lock l(clients_mu_);
    auto it = clients_.find(target);
    if (it == clients_.end()) {
      tensorflow::SharedGrpcChannelPtr shared =
          cache_->FindWorkerChannel(target);
      if (shared == nullptr) {
        return errors::InvalidArgument("Client for target ", target,
                                       " not found.");
      }
      int assigned_index = AssignClientToThread(target);
      GrpcEagerClientThread* thread = threads_[assigned_index].get();
      core::RefCountPtr<EagerClient> worker(
          new GrpcEagerClient(shared, thread, target));
      it = clients_.emplace(target, std::move(worker)).first;
    }

    it->second->Ref();
    client->reset(it->second.get());
    return Status::OK();
  }

 private:
  mutex assignment_mu_;
  std::unordered_map<std::string, size_t> target_assignments_
      TF_GUARDED_BY(assignment_mu_);
  size_t next_round_robin_assignment_ TF_GUARDED_BY(assignment_mu_);

  size_t AssignClientToThread(const string& target) {
    // Round-robin target assignment, but always keep the same target on the
    // same polling thread, as this is important for gRPC performance.
    mutex_lock lock(assignment_mu_);
    auto it = target_assignments_.find(target);
    if (it == target_assignments_.end()) {
      it = target_assignments_
               .insert(std::make_pair(
                   target, (next_round_robin_assignment_++) % threads_.size()))
               .first;
    }
    return it->second;
  }

  std::shared_ptr<tensorflow::GrpcChannelCache> cache_;
  mutable mutex clients_mu_;
  std::unordered_map<string, core::RefCountPtr<EagerClient>> clients_
      TF_GUARDED_BY(clients_mu_);
  std::vector<core::RefCountPtr<GrpcEagerClientThread>> threads_;
};

}  // namespace

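// Factory for the gRPC-backed eager client cache; the caller takes ownership
// of the returned pointer.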
EagerClientCache* NewGrpcEagerClientCache(
    std::shared_ptr<tensorflow::GrpcChannelCache> channel) {
  return new GrpcEagerClientCache(channel);
}

}  // namespace eager
}  // namespace tensorflow