1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
17
18 #include "grpcpp/generic/generic_stub.h"
19 #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h"
20 #include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
21 #include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
22 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
23 #include "tensorflow/core/lib/core/refcount.h"
24 #include "tensorflow/core/lib/core/status.h"
25 #include "tensorflow/core/platform/env.h"
26 #include "tensorflow/core/protobuf/eager_service.pb.h"
27 #include "tensorflow/core/util/env_var.h"
28
29 namespace tensorflow {
30 namespace eager {
31 namespace {
32
33 /*
 * Environment variable "TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE" controls
 * asynchronous execution of remote ops; EnableStreaming() below reads it with
 * a default of true. When it is true, the client no longer blocks waiting for
 * the response when executing an op on a remote worker. Using the following
 * code as an example:
38 *
39 * with tf.device('worker:0'):
40 * a = tf.matmul(...)
41 * b = tf.matmul(...)
 * logging.info('Requests sent') # Probably not executed yet
43 * logging.info('b: %s', b.numpy()) # Block until 'b' finished.
44 *
45 * Streaming RPC will preserve order as well. So 'a' must be executed before
46 * 'b' on 'worker:0'.
47 *
 * When this feature is turned on, you should explicitly wait for some result
 * from remote workers at the end of your Python program. Otherwise, the client
 * may shut down remote workers without waiting for all pending ops.
51 *
52 * TODO(fishx): When exiting client, make sure all pending ops on remote workers
53 * are finished.
54 *
55 * TODO(b/139210648): Move this comment to eager/execute.py when this feature is
56 * on by default.
57 */
EnableStreaming()58 bool EnableStreaming() {
59 bool result;
60 TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_EAGER_CLIENT_STREAMING_ENQUEUE",
61 true, &result));
62 return result;
63 }
64
65 // Ref-counted thread to handle callbacks for completed requests a GRPC
66 // completion queue. The thread might be shared by multiple eager clients, and
67 // each one of them should hold a reference count to ensure that the thread
68 // outlives the clients.
69 // To ensure that every tag in completion queue is processed, this thread also
70 // holds a reference to itself and always wait until ref count is one to exit.
71 class GrpcEagerClientThread : public core::RefCounted {
72 public:
GrpcEagerClientThread()73 GrpcEagerClientThread() {
74 // Hold a reference to ensure every completion tag gets processed.
75 Ref();
76 thread_.reset(Env::Default()->StartThread(
77 ThreadOptions(), "eager_client_thread", [this]() {
78 void* tag;
79 bool ok;
80 while (completion_queue_.Next(&tag, &ok)) {
81 VLOG(4) << "GrpcEagerClientThread got next tag";
82 GrpcClientCQTag* callback_tag = static_cast<GrpcClientCQTag*>(tag);
83 callback_tag->OnCompleted(ok);
84 VLOG(4) << "GrpcEagerClientThread blocking for next tag";
85 if (RefCountIsOne()) {
86 break;
87 }
88 }
89 VLOG(4) << "GrpcEagerClientThread exiting";
90 completion_queue_.Shutdown();
91 // `this` holds the final reference so cannot directly Unref here.
92 // Instead, schedule a separate thread to clean it up.
93 Env::Default()->SchedClosure([this]() { this->Unref(); });
94 }));
95 }
96
~GrpcEagerClientThread()97 ~GrpcEagerClientThread() override {}
98
completion_queue()99 ::grpc::CompletionQueue* completion_queue() { return &completion_queue_; }
100
101 private:
102 ::grpc::CompletionQueue completion_queue_;
103 std::unique_ptr<Thread> thread_;
104 };
105
106 class GrpcEagerClient : public EagerClient {
107 public:
GrpcEagerClient(const tensorflow::SharedGrpcChannelPtr & channel,GrpcEagerClientThread * thread)108 GrpcEagerClient(const tensorflow::SharedGrpcChannelPtr& channel,
109 GrpcEagerClientThread* thread)
110 : stub_(channel), thread_(thread) {
111 // Hold a reference to make sure the corresponding EagerClientThread
112 // outlives the client.
113 thread_->Ref();
114 cq_ = thread->completion_queue();
115 }
~GrpcEagerClient()116 ~GrpcEagerClient() override { thread_->Unref(); }
117
118 #define CLIENT_METHOD(method) \
119 void method##Async(const method##Request* request, \
120 method##Response* response, StatusCallback done) \
121 override { \
122 StatusCallback done_wrapped = callback_wrapper(std::move(done)); \
123 new RPCState<protobuf::Message>( \
124 &stub_, cq_, "/tensorflow.eager.EagerService/" #method, *request, \
125 response, std::move(done_wrapped), /*call_opts=*/nullptr, \
126 /*threadpool=*/nullptr, /*max_retries=*/0, /*fail_fast=*/true); \
127 }
128
129 CLIENT_METHOD(CreateContext);
130 CLIENT_METHOD(UpdateContext);
131 CLIENT_METHOD(Enqueue);
132 CLIENT_METHOD(WaitQueueDone);
133 CLIENT_METHOD(KeepAlive);
134
135 #undef CLIENT_METHOD
136
CloseContextAsync(const CloseContextRequest * request,CloseContextResponse * response,StatusCallback done)137 void CloseContextAsync(const CloseContextRequest* request,
138 CloseContextResponse* response,
139 StatusCallback done) override {
140 StatusCallback done_wrapped = callback_wrapper(std::move(done));
141 new RPCState<protobuf::Message>(
142 &stub_, cq_, "/tensorflow.eager.EagerService/CloseContext", *request,
143 response, std::move(done_wrapped), /*call_opts=*/nullptr,
144 /*threadpool=*/nullptr);
145
146 VLOG(1) << "Sending RPC to close remote eager context "
147 << request->DebugString();
148
149 mutex_lock l(mu_);
150 const auto& it = enqueue_dispatchers_.find(request->context_id());
151 if (it != enqueue_dispatchers_.end()) {
152 it->second.CancelCall();
153 enqueue_dispatchers_.erase(it);
154 } else if (EnableStreaming()) {
155 LOG(ERROR) << "Remote EagerContext with id " << request->context_id()
156 << " does not seem to exist.";
157 }
158 }
159
StreamingEnqueueAsync(const EnqueueRequest * request,EnqueueResponse * response,StatusCallback done)160 void StreamingEnqueueAsync(const EnqueueRequest* request,
161 EnqueueResponse* response,
162 StatusCallback done) override {
163 StatusCallback done_wrapped = callback_wrapper(std::move(done));
164 if (EnableStreaming()) {
165 tf_shared_lock l(mu_);
166 auto it = enqueue_dispatchers_.find(request->context_id());
167 if (it == enqueue_dispatchers_.end()) {
168 auto it_and_bool = enqueue_dispatchers_.emplace(
169 std::piecewise_construct,
170 std::forward_as_tuple(request->context_id()),
171 std::forward_as_tuple(
172 &stub_, cq_,
173 "/tensorflow.eager.EagerService/StreamingEnqueue"));
174 it = it_and_bool.first;
175 }
176 it->second.SendNextRequest(*request, response, std::move(done_wrapped));
177 } else {
178 Notification n;
179 Status status;
180 EnqueueAsync(request, response, [&n, &status](const Status& s) {
181 status.Update(s);
182 n.Notify();
183 });
184 n.WaitForNotification();
185 done_wrapped(status);
186 }
187 }
188
189 private:
190 ::grpc::GenericStub stub_;
191 const GrpcEagerClientThread* thread_;
192
193 ::grpc::CompletionQueue* cq_;
194
195 mutable mutex mu_;
196
197 std::unordered_map<uint64, StreamingRPCDispatcher<EnqueueResponse>>
198 enqueue_dispatchers_ GUARDED_BY(mu_);
199
callback_wrapper(StatusCallback done)200 StatusCallback callback_wrapper(StatusCallback done) {
201 Ref();
202 return [this, done = std::move(done)](const Status& status) {
203 done(status);
204 this->Unref();
205 };
206 }
207 };
208
209 class GrpcEagerClientCache : public EagerClientCache {
210 public:
GrpcEagerClientCache(std::shared_ptr<tensorflow::GrpcChannelCache> cache)211 explicit GrpcEagerClientCache(
212 std::shared_ptr<tensorflow::GrpcChannelCache> cache)
213 : next_round_robin_assignment_(0), cache_(cache), threads_(4) {
214 for (int i = 0; i < threads_.size(); i++) {
215 threads_[i].reset(new GrpcEagerClientThread());
216 }
217 }
218
~GrpcEagerClientCache()219 ~GrpcEagerClientCache() override { threads_.clear(); }
220
GetClient(const string & target,core::RefCountPtr<EagerClient> * client)221 Status GetClient(const string& target,
222 core::RefCountPtr<EagerClient>* client) override {
223 auto it = clients_.find(target);
224 if (it == clients_.end()) {
225 tensorflow::SharedGrpcChannelPtr shared =
226 cache_->FindWorkerChannel(target);
227 if (shared == nullptr) {
228 return errors::InvalidArgument("Client for target ", target,
229 " not found.");
230 }
231 int assigned_index = AssignClientToThread(target);
232 GrpcEagerClientThread* thread = threads_[assigned_index].get();
233 core::RefCountPtr<EagerClient> worker(
234 new GrpcEagerClient(shared, thread));
235 it = clients_.emplace(target, std::move(worker)).first;
236 }
237
238 it->second->Ref();
239 client->reset(it->second.get());
240 return Status::OK();
241 }
242
243 private:
244 mutex assignment_mu_;
245 std::unordered_map<std::string, size_t> target_assignments_
246 GUARDED_BY(assignment_mu_);
247 size_t next_round_robin_assignment_ GUARDED_BY(assignment_mu_);
248
AssignClientToThread(const string & target)249 size_t AssignClientToThread(const string& target) {
250 // Round-robin target assignment, but keeps the same target on the same
251 // polling thread always, as this is important for gRPC performance
252 mutex_lock lock(assignment_mu_);
253 auto it = target_assignments_.find(target);
254 if (it == target_assignments_.end()) {
255 it = target_assignments_
256 .insert(std::make_pair(
257 target, (next_round_robin_assignment_++) % threads_.size()))
258 .first;
259 }
260 return it->second;
261 }
262
263 std::shared_ptr<tensorflow::GrpcChannelCache> cache_;
264 std::unordered_map<string, core::RefCountPtr<EagerClient>> clients_;
265 std::vector<core::RefCountPtr<GrpcEagerClientThread>> threads_;
266 };
267
268 } // namespace
269
NewGrpcEagerClientCache(std::shared_ptr<tensorflow::GrpcChannelCache> channel)270 EagerClientCache* NewGrpcEagerClientCache(
271 std::shared_ptr<tensorflow::GrpcChannelCache> channel) {
272 return new GrpcEagerClientCache(channel);
273 }
274
275 } // namespace eager
276 } // namespace tensorflow
277