/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"

#include "grpcpp/generic/generic_stub.h"
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_service.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/protobuf/eager_service.pb.h"

namespace tensorflow {
namespace eager {
namespace {
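// gRPC-backed implementation of EagerClient. Each *Async method issues the
// corresponding EagerService RPC on the shared completion queue and runs the
// caller's StatusCallback when the RPC completes.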
class GrpcEagerClient : public EagerClient {
 public:
  GrpcEagerClient(const tensorflow::SharedGrpcChannelPtr& channel,
                  ::grpc::CompletionQueue* cq)
      : stub_(channel), cq_(cq) {}
  ~GrpcEagerClient() override {}

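// Expands to an asynchronous client method that issues the
// "/tensorflow.eager.EagerService/<method>" RPC and invokes `done` with the
// resulting status.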
#define CLIENT_METHOD(method)                                             \
  void method##Async(const method##Request* request,                      \
                     method##Response* response, StatusCallback done)     \
      override {                                                          \
    new RPCState<protobuf::Message>(                                      \
        &stub_, cq_, "/tensorflow.eager.EagerService/" #method, *request, \
        response, std::move(done), nullptr, nullptr);                     \
  }

  CLIENT_METHOD(CreateContext);
  CLIENT_METHOD(Enqueue);
  CLIENT_METHOD(WaitQueueDone);
  CLIENT_METHOD(KeepAlive);
  CLIENT_METHOD(CloseContext);
  CLIENT_METHOD(RegisterFunction);
  CLIENT_METHOD(SendTensor);

#undef CLIENT_METHOD

 private:
  ::grpc::GenericStub stub_;
  ::grpc::CompletionQueue* cq_;
};

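// Creates and caches one GrpcEagerClient per target, pinning each target to
// one of a fixed pool of completion-queue polling threads.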
class GrpcEagerClientCache : public EagerClientCache {
 public:
  explicit GrpcEagerClientCache(
      std::shared_ptr<tensorflow::GrpcChannelCache> cache)
      : next_round_robin_assignment_(0), cache_(cache), threads_(4) {}

  ~GrpcEagerClientCache() override { threads_.clear(); }

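  // Returns the client for `target`, creating it from the channel cache on
  // first use.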
  EagerClient* GetClient(const string& target) override {
    auto it = clients_.find(target);
    if (it == clients_.end()) {
      tensorflow::SharedGrpcChannelPtr shared =
          cache_->FindWorkerChannel(target);
      auto worker = std::unique_ptr<EagerClient>(new GrpcEagerClient(
          shared, threads_[AssignClientToThread(target)].completion_queue()));

      it = clients_.emplace(target, std::move(worker)).first;
    }

    return it->second.get();
  }

 private:
  mutex assignment_mu_;
  std::unordered_map<std::string, size_t> target_assignments_
      GUARDED_BY(assignment_mu_);
  size_t next_round_robin_assignment_ GUARDED_BY(assignment_mu_);

  size_t AssignClientToThread(const string& target) {
    // Round-robin target assignment, but always keeps the same target on the
    // same polling thread, as this is important for gRPC performance.
    mutex_lock lock(assignment_mu_);
    auto it = target_assignments_.find(target);
    if (it == target_assignments_.end()) {
      it = target_assignments_
               .insert(std::make_pair(
                   target, (next_round_robin_assignment_++) % threads_.size()))
               .first;
    }
    return it->second;
  }

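  // Owns a completion queue and the thread that polls it for finished RPCs.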
  class GrpcEagerClientThread {
   public:
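    // Starts a thread that drains the completion queue and runs each
    // finished tag's completion callback.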
    GrpcEagerClientThread() {
      thread_.reset(Env::Default()->StartThread(
          ThreadOptions(), "eager_client_thread", [this]() {
            void* tag;
            bool ok;
            while (completion_queue_.Next(&tag, &ok)) {
              GrpcClientCQTag* callback_tag =
                  static_cast<GrpcClientCQTag*>(tag);
              callback_tag->OnCompleted(ok);
            }
          }));
    }

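    // Shutting down the queue makes Next() return false, so the polling
    // thread exits; resetting thread_ then waits for it to finish.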
    ~GrpcEagerClientThread() {
      completion_queue_.Shutdown();
      thread_.reset();
    }

    ::grpc::CompletionQueue* completion_queue() { return &completion_queue_; }

   private:
    ::grpc::CompletionQueue completion_queue_;
    std::unique_ptr<Thread> thread_;
  };  // GrpcEagerClientThread

  std::shared_ptr<tensorflow::GrpcChannelCache> cache_;
  std::unordered_map<string, std::unique_ptr<EagerClient>> clients_;
  std::vector<GrpcEagerClientThread> threads_;
};

}  // namespace

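// Factory for the gRPC-based EagerClientCache; the caller owns the returned
// pointer.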
EagerClientCache* NewGrpcEagerClientCache(
    std::shared_ptr<tensorflow::GrpcChannelCache> channel) {
  return new GrpcEagerClientCache(channel);
}

}  // namespace eager
}  // namespace tensorflow