1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
17
18 #include "tensorflow/core/distributed_runtime/rpc/coordination/grpc_coordination_client.h"
19 #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
20 #include "tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h"
21 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
22 #include "tensorflow/core/distributed_runtime/worker_cache_logger.h"
23 #include "tensorflow/core/distributed_runtime/worker_cache_partial.h"
24 #include "tensorflow/core/distributed_runtime/worker_interface.h"
25 #include "tensorflow/core/platform/cpu_info.h"
26 #include "tensorflow/core/platform/mutex.h"
27 #include "tensorflow/core/util/env_var.h"
28
29 namespace tensorflow {
30
31 namespace {
32
33 class GrpcWorkerCache : public WorkerCachePartial {
34 public:
GrpcWorkerCache(std::shared_ptr<GrpcChannelCache> channel_cache,WorkerInterface * local_worker,const string & local_target,GrpcWorkerEnv * worker_env)35 explicit GrpcWorkerCache(std::shared_ptr<GrpcChannelCache> channel_cache,
36 WorkerInterface* local_worker,
37 const string& local_target,
38 GrpcWorkerEnv* worker_env)
39 : local_target_(local_target),
40 local_worker_(local_worker),
41 channel_cache_(channel_cache),
42 worker_env_(worker_env),
43 next_round_robin_assignment_(0) {}
44
ListWorkers(std::vector<string> * workers) const45 void ListWorkers(std::vector<string>* workers) const override {
46 channel_cache_->ListWorkers(workers);
47 }
48
ListWorkersInJob(const string & job_name,std::vector<string> * workers) const49 void ListWorkersInJob(const string& job_name,
50 std::vector<string>* workers) const override {
51 channel_cache_->ListWorkersInJob(job_name, workers);
52 }
53
GetOrCreateWorker(const string & target)54 WorkerInterface* GetOrCreateWorker(const string& target) override {
55 if (target == local_target_) {
56 return local_worker_;
57 } else {
58 SharedGrpcChannelPtr channel = channel_cache_->FindWorkerChannel(target);
59 if (!channel) {
60 return nullptr;
61 }
62 size_t index = AssignWorkerToThread(target);
63 return NewGrpcRemoteWorker(
64 channel, worker_env_->GetCompletionQueue(index),
65 worker_env_->GetThreadPool(), &logger_, target);
66 }
67 }
68
ReleaseWorker(const string & target,WorkerInterface * worker)69 void ReleaseWorker(const string& target, WorkerInterface* worker) override {
70 if (target == local_target_) {
71 CHECK_EQ(worker, local_worker_)
72 << "Releasing a worker that was not returned by this WorkerCache";
73 } else {
74 WorkerCacheInterface::ReleaseWorker(target, worker);
75 }
76 }
77
GetEagerClientCache(std::unique_ptr<eager::EagerClientCache> * eager_client_cache)78 Status GetEagerClientCache(
79 std::unique_ptr<eager::EagerClientCache>* eager_client_cache) override {
80 eager_client_cache->reset(eager::NewGrpcEagerClientCache(channel_cache_));
81 return OkStatus();
82 }
83
GetCoordinationClientCache(std::unique_ptr<CoordinationClientCache> * coordination_client_cache)84 Status GetCoordinationClientCache(std::unique_ptr<CoordinationClientCache>*
85 coordination_client_cache) override {
86 coordination_client_cache->reset(
87 NewGrpcCoordinationClientCache(channel_cache_));
88 return OkStatus();
89 }
90
GetCoordinationClientCache(std::unique_ptr<CoordinationClientCache> * coordination_client_cache)91 Status GetCoordinationClientCache(std::unique_ptr<CoordinationClientCache>*
92 coordination_client_cache) override {
93 #if defined(PLATFORM_GOOGLE)
94 coordination_client_cache->reset(
95 NewGrpcCoordinationClientCache(channel_cache_));
96 return Status::OK();
97 #else
98 return errors::Unimplemented(
99 "Coordination service in open source is not yet implemented.");
100 #endif
101 }
102
SetLogging(bool v)103 void SetLogging(bool v) override { logger_.SetLogging(v); }
104
ClearLogs()105 void ClearLogs() override { logger_.ClearLogs(); }
106
RetrieveLogs(int64_t step_id,StepStats * ss)107 bool RetrieveLogs(int64_t step_id, StepStats* ss) override {
108 return logger_.RetrieveLogs(step_id, ss);
109 }
110
111 private:
AssignWorkerToThread(const string & target)112 size_t AssignWorkerToThread(const string& target) {
113 // Round-robin target assignment, but keeps the same target on the same
114 // polling thread always, as this is important for gRPC performance
115 mutex_lock lock(assignment_mu_);
116 auto it = target_assignments_.find(target);
117 if (it == target_assignments_.end()) {
118 it = target_assignments_
119 .insert(std::make_pair(target,
120 (next_round_robin_assignment_++) %
121 worker_env_->CompletionQueueSize()))
122 .first;
123 }
124 return it->second;
125 }
126
127 const string local_target_;
128 WorkerInterface* const local_worker_; // Not owned.
129 std::shared_ptr<GrpcChannelCache> channel_cache_;
130 WorkerCacheLogger logger_;
131 GrpcWorkerEnv* worker_env_; // Not owned
132
133 mutex assignment_mu_;
134 std::unordered_map<std::string, size_t> target_assignments_
135 TF_GUARDED_BY(assignment_mu_);
136 size_t next_round_robin_assignment_ TF_GUARDED_BY(assignment_mu_);
137 };
138
139 } // namespace
140
GrpcWorkerEnv(size_t num_completion_queues,size_t num_threads)141 GrpcWorkerEnv::GrpcWorkerEnv(size_t num_completion_queues, size_t num_threads)
142 : threadpool_(new thread::ThreadPool(
143 Env::Default(), ThreadOptions(), "GrpcWorkerEnvQueues", num_threads,
144 /*low_latency_hint=*/false, /*allocator=*/nullptr)),
145 threads_(num_completion_queues) {}
146
~GrpcWorkerEnv()147 GrpcWorkerEnv::~GrpcWorkerEnv() { threads_.clear(); }
148
GrpcWorkerCacheThread()149 GrpcWorkerEnv::GrpcWorkerCacheThread::GrpcWorkerCacheThread() {
150 thread_.reset(Env::Default()->StartThread(
151 ThreadOptions(), "GrpcWorkerEnvPool", [this]() {
152 void* tag;
153 bool ok;
154 while (completion_queue_.Next(&tag, &ok)) {
155 GrpcClientCQTag* callback_tag = static_cast<GrpcClientCQTag*>(tag);
156 callback_tag->OnCompleted(ok);
157 }
158 }));
159 }
160
~GrpcWorkerCacheThread()161 GrpcWorkerEnv::GrpcWorkerCacheThread::~GrpcWorkerCacheThread() {
162 completion_queue_.Shutdown();
163 thread_.reset();
164 }
165
CreateGrpcWorkerEnv()166 GrpcWorkerEnv* CreateGrpcWorkerEnv() {
167 int num_cpus = port::NumSchedulableCPUs();
168 int64_t num_completion_queues;
169 Status status = ReadInt64FromEnvVar("TF_GRPC_WORKER_CACHE_QUEUES", 64,
170 &num_completion_queues);
171 if (!status.ok()) {
172 LOG(ERROR) << "Error parsing TF_GRPC_WORKER_CACHE_QUEUES: " << status;
173 }
174 int64_t num_threads;
175 status = ReadInt64FromEnvVar("TF_GRPC_WORKER_CACHE_THREADS", num_cpus,
176 &num_threads);
177 if (!status.ok()) {
178 LOG(ERROR) << "Error parsing TF_GRPC_WORKER_CACHE_THREADS: " << status;
179 }
180 return new GrpcWorkerEnv(num_completion_queues, num_threads);
181 }
182
NewGrpcWorkerCache(std::shared_ptr<GrpcChannelCache> cc,GrpcWorkerEnv * worker_env)183 WorkerCacheInterface* NewGrpcWorkerCache(std::shared_ptr<GrpcChannelCache> cc,
184 GrpcWorkerEnv* worker_env) {
185 return new GrpcWorkerCache(cc, /*local_worker=*/nullptr, /*local_target=*/"",
186 worker_env);
187 }
188
NewGrpcWorkerCacheWithLocalWorker(std::shared_ptr<GrpcChannelCache> cc,GrpcWorkerEnv * worker_env,WorkerInterface * local_worker,const string & local_target)189 WorkerCacheInterface* NewGrpcWorkerCacheWithLocalWorker(
190 std::shared_ptr<GrpcChannelCache> cc, GrpcWorkerEnv* worker_env,
191 WorkerInterface* local_worker, const string& local_target) {
192 return new GrpcWorkerCache(cc, local_worker, local_target, worker_env);
193 }
194
195 } // namespace tensorflow
196