/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"

#include "grpcpp/generic/generic_stub.h"
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/protobuf/eager_service.pb.h"
#include "tensorflow/core/util/env_var.h"

namespace tensorflow {
namespace eager {
namespace {

/*
 * Setting the environment variable "TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE"
 * to true turns on asynchronous execution of remote ops. It means that when
 * executing an op on a remote worker, the client will no longer block waiting
 * for the response. Take the following code as an example:
 *
 * with tf.device('worker:0'):
 *   a = tf.matmul(...)
 *   b = tf.matmul(...)
 * logging.info('Requests sent')    # Probably not executed yet
 * logging.info('b: %s', b.numpy()) # Blocks until 'b' is finished.
 *
 * Streaming RPC preserves order as well, so 'a' must be executed before 'b'
 * on 'worker:0'.
 *
 * When turning on this feature, you should explicitly wait for some result
 * from remote workers at the end of your Python program. Otherwise, the
 * client may shut down remote workers without waiting for all pending ops.
 *
 * TODO(fishx): When exiting the client, make sure all pending ops on remote
 * workers are finished.
 *
 * TODO(b/139210648): Move this comment to eager/execute.py when this feature
 * is on by default.
 */
bool EnableStreaming() {
  bool result;
  TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE",
                                 true, &result));
  return result;
}
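
// For illustration only: the flag is read from the process environment, so a
// hypothetical way to force the blocking (non-streaming) behavior would be to
// launch the client with the variable set to false, e.g.
//   TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE=false python my_program.py
// (`my_program.py` is just a placeholder, not part of this file.)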

// Ref-counted thread to handle callbacks for completed requests in a gRPC
// completion queue. The thread might be shared by multiple eager clients, and
// each one of them should hold a reference count to ensure that the thread
// outlives the clients.
// To ensure that every tag in the completion queue is processed, this thread
// also holds a reference to itself and always waits until the ref count is
// one before exiting.
class GrpcEagerClientThread : public core::RefCounted {
 public:
  GrpcEagerClientThread() {
    // Hold a reference to ensure every completion tag gets processed.
    Ref();
    thread_.reset(Env::Default()->StartThread(
        ThreadOptions(), "eager_client_thread", [this]() {
          void* tag;
          bool ok;
          while (completion_queue_.Next(&tag, &ok)) {
            VLOG(4) << "GrpcEagerClientThread got next tag";
            GrpcClientCQTag* callback_tag = static_cast<GrpcClientCQTag*>(tag);
            callback_tag->OnCompleted(ok);
            VLOG(4) << "GrpcEagerClientThread blocking for next tag";
            if (RefCountIsOne()) {
              break;
            }
          }
          VLOG(4) << "GrpcEagerClientThread exiting";
          completion_queue_.Shutdown();
          // `this` holds the final reference, so we cannot Unref directly
          // here. Instead, schedule a separate closure to clean it up.
          Env::Default()->SchedClosure([this]() { this->Unref(); });
        }));
  }

  ~GrpcEagerClientThread() override {}

  ::grpc::CompletionQueue* completion_queue() { return &completion_queue_; }

 private:
  ::grpc::CompletionQueue completion_queue_;
  std::unique_ptr<Thread> thread_;
};
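
// A minimal lifecycle sketch (illustrative, not part of this file): a cache
// creates one polling thread and each client takes an extra reference on it,
// so the thread keeps running until the last client releases it.
//
//   GrpcEagerClientThread* thread = new GrpcEagerClientThread();
//   // ... clients call thread->Ref() and post RPCs on
//   // thread->completion_queue() ...
//   thread->Unref();  // Thread exits once its own ref is the last one left.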

class GrpcEagerClient : public EagerClient {
 public:
  GrpcEagerClient(const tensorflow::SharedGrpcChannelPtr& channel,
                  GrpcEagerClientThread* thread)
      : stub_(channel), thread_(thread) {
    // Hold a reference to make sure the corresponding EagerClientThread
    // outlives the client.
    thread_->Ref();
    cq_ = thread->completion_queue();
  }
  ~GrpcEagerClient() override { thread_->Unref(); }

#define CLIENT_METHOD(method)                                             \
  void method##Async(const method##Request* request,                      \
                     method##Response* response, StatusCallback done)     \
      override {                                                          \
    StatusCallback done_wrapped = callback_wrapper(std::move(done));      \
    new RPCState<protobuf::Message>(                                      \
        &stub_, cq_, "/tensorflow.eager.EagerService/" #method, *request, \
        response, std::move(done_wrapped), /*call_opts=*/nullptr,         \
        /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true);   \
  }

  CLIENT_METHOD(CreateContext);
  CLIENT_METHOD(UpdateContext);
  CLIENT_METHOD(Enqueue);
  CLIENT_METHOD(WaitQueueDone);
  CLIENT_METHOD(KeepAlive);

#undef CLIENT_METHOD
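
  // For reference, CLIENT_METHOD(Enqueue) above expands to roughly:
  //
  //   void EnqueueAsync(const EnqueueRequest* request,
  //                     EnqueueResponse* response,
  //                     StatusCallback done) override {
  //     StatusCallback done_wrapped = callback_wrapper(std::move(done));
  //     new RPCState<protobuf::Message>(
  //         &stub_, cq_, "/tensorflow.eager.EagerService/Enqueue", *request,
  //         response, std::move(done_wrapped), /*call_opts=*/nullptr,
  //         /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true);
  //   }
  //
  // i.e. each generated method fires an asynchronous unary RPC and lets the
  // shared completion-queue thread invoke `done_wrapped` when it finishes.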

  void CloseContextAsync(const CloseContextRequest* request,
                         CloseContextResponse* response,
                         StatusCallback done) override {
    StatusCallback done_wrapped = callback_wrapper(std::move(done));
    new RPCState<protobuf::Message>(
        &stub_, cq_, "/tensorflow.eager.EagerService/CloseContext", *request,
        response, std::move(done_wrapped), /*call_opts=*/nullptr,
        /*threadpool=*/nullptr);

    VLOG(1) << "Sending RPC to close remote eager context "
            << request->DebugString();

    mutex_lock l(mu_);
    const auto& it = enqueue_dispatchers_.find(request->context_id());
    if (it != enqueue_dispatchers_.end()) {
      it->second.CancelCall();
      enqueue_dispatchers_.erase(it);
    } else if (EnableStreaming()) {
      LOG(ERROR) << "Remote EagerContext with id " << request->context_id()
                 << " does not seem to exist.";
    }
  }
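
  // StreamingEnqueueAsync has two paths: with streaming enabled, requests for
  // the same context id are multiplexed onto one StreamingEnqueue call so
  // they are delivered in order; with streaming disabled, it degenerates to a
  // blocking EnqueueAsync round trip.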
  void StreamingEnqueueAsync(const EnqueueRequest* request,
                             EnqueueResponse* response,
                             StatusCallback done) override {
    StatusCallback done_wrapped = callback_wrapper(std::move(done));
    if (EnableStreaming()) {
      tf_shared_lock l(mu_);
      auto it = enqueue_dispatchers_.find(request->context_id());
      if (it == enqueue_dispatchers_.end()) {
        auto it_and_bool = enqueue_dispatchers_.emplace(
            std::piecewise_construct,
            std::forward_as_tuple(request->context_id()),
            std::forward_as_tuple(
                &stub_, cq_,
                "/tensorflow.eager.EagerService/StreamingEnqueue"));
        it = it_and_bool.first;
      }
      it->second.SendNextRequest(*request, response, std::move(done_wrapped));
    } else {
      Notification n;
      Status status;
      EnqueueAsync(request, response, [&n, &status](const Status& s) {
        status.Update(s);
        n.Notify();
      });
      n.WaitForNotification();
      done_wrapped(status);
    }
  }

 private:
  ::grpc::GenericStub stub_;
  const GrpcEagerClientThread* thread_;

  ::grpc::CompletionQueue* cq_;

  mutable mutex mu_;

  std::unordered_map<uint64, StreamingRPCDispatcher<EnqueueResponse>>
      enqueue_dispatchers_ GUARDED_BY(mu_);
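
  // Wraps `done` so the client holds a reference to itself until the wrapped
  // callback has run, keeping this object alive across the async RPC.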
  StatusCallback callback_wrapper(StatusCallback done) {
    Ref();
    return [this, done = std::move(done)](const Status& status) {
      done(status);
      this->Unref();
    };
  }
};

class GrpcEagerClientCache : public EagerClientCache {
 public:
  explicit GrpcEagerClientCache(
      std::shared_ptr<tensorflow::GrpcChannelCache> cache)
      : next_round_robin_assignment_(0), cache_(cache), threads_(4) {
    for (int i = 0; i < threads_.size(); i++) {
      threads_[i].reset(new GrpcEagerClientThread());
    }
  }

  ~GrpcEagerClientCache() override { threads_.clear(); }
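
  // Returns a cached client for `target`, creating one (and pinning it to a
  // polling thread) the first time the target is seen. The caller receives an
  // extra reference via `client`.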
  Status GetClient(const string& target,
                   core::RefCountPtr<EagerClient>* client) override {
    auto it = clients_.find(target);
    if (it == clients_.end()) {
      tensorflow::SharedGrpcChannelPtr shared =
          cache_->FindWorkerChannel(target);
      if (shared == nullptr) {
        return errors::InvalidArgument("Client for target ", target,
                                       " not found.");
      }
      int assigned_index = AssignClientToThread(target);
      GrpcEagerClientThread* thread = threads_[assigned_index].get();
      core::RefCountPtr<EagerClient> worker(
          new GrpcEagerClient(shared, thread));
      it = clients_.emplace(target, std::move(worker)).first;
    }

    it->second->Ref();
    client->reset(it->second.get());
    return Status::OK();
  }

 private:
  mutex assignment_mu_;
  std::unordered_map<std::string, size_t> target_assignments_
      GUARDED_BY(assignment_mu_);
  size_t next_round_robin_assignment_ GUARDED_BY(assignment_mu_);

  size_t AssignClientToThread(const string& target) {
    // Round-robin target assignment, but always keep the same target on the
    // same polling thread, as this is important for gRPC performance.
    mutex_lock lock(assignment_mu_);
    auto it = target_assignments_.find(target);
    if (it == target_assignments_.end()) {
      it = target_assignments_
               .insert(std::make_pair(
                   target, (next_round_robin_assignment_++) % threads_.size()))
               .first;
    }
    return it->second;
  }
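
  // For example (illustrative only): with the 4 threads created above, the
  // first four distinct targets are assigned thread indices 0, 1, 2, 3, the
  // fifth target wraps around to 0, and repeated lookups of an existing
  // target always return its original index.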

  std::shared_ptr<tensorflow::GrpcChannelCache> cache_;
  std::unordered_map<string, core::RefCountPtr<EagerClient>> clients_;
  std::vector<core::RefCountPtr<GrpcEagerClientThread>> threads_;
};

}  // namespace

EagerClientCache* NewGrpcEagerClientCache(
    std::shared_ptr<tensorflow::GrpcChannelCache> channel) {
  return new GrpcEagerClientCache(channel);
}
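
// A minimal usage sketch (illustrative; `channel_cache` and the target string
// below are assumptions, not part of this file):
//
//   std::unique_ptr<eager::EagerClientCache> clients(
//       eager::NewGrpcEagerClientCache(channel_cache));
//   core::RefCountPtr<eager::EagerClient> client;
//   Status s = clients->GetClient("/job:worker/replica:0/task:0", &client);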

}  // namespace eager
}  // namespace tensorflow