/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"

#include "tensorflow/core/distributed_runtime/rpc/coordination/grpc_coordination_client.h"
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/distributed_runtime/worker_cache_logger.h"
#include "tensorflow/core/distributed_runtime/worker_cache_partial.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/util/env_var.h"

namespace tensorflow {

namespace {

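// A WorkerCachePartial that resolves worker names through a GrpcChannelCache.
// The worker registered as `local_target` is served in-process by
// `local_worker`; every other target gets a gRPC remote worker whose
// completion queue is chosen by AssignWorkerToThread().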
class GrpcWorkerCache : public WorkerCachePartial {
 public:
  explicit GrpcWorkerCache(std::shared_ptr<GrpcChannelCache> channel_cache,
                           WorkerInterface* local_worker,
                           const string& local_target,
                           GrpcWorkerEnv* worker_env)
      : local_target_(local_target),
        local_worker_(local_worker),
        channel_cache_(channel_cache),
        worker_env_(worker_env),
        next_round_robin_assignment_(0) {}

  void ListWorkers(std::vector<string>* workers) const override {
    channel_cache_->ListWorkers(workers);
  }

  void ListWorkersInJob(const string& job_name,
                        std::vector<string>* workers) const override {
    channel_cache_->ListWorkersInJob(job_name, workers);
  }

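  // Returns the in-process `local_worker_` for `local_target_`; otherwise
  // looks up the target's channel and wraps it in a new GrpcRemoteWorker
  // pinned to one of the environment's completion queues. Returns nullptr
  // if the channel cache has no channel for the target.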
  WorkerInterface* GetOrCreateWorker(const string& target) override {
    if (target == local_target_) {
      return local_worker_;
    } else {
      SharedGrpcChannelPtr channel = channel_cache_->FindWorkerChannel(target);
      if (!channel) {
        return nullptr;
      }
      size_t index = AssignWorkerToThread(target);
      return NewGrpcRemoteWorker(
          channel, worker_env_->GetCompletionQueue(index),
          worker_env_->GetThreadPool(), &logger_, target);
    }
  }

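  // The local worker is owned elsewhere and is never deleted here; remote
  // workers are released through the base-class implementation.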
  void ReleaseWorker(const string& target, WorkerInterface* worker) override {
    if (target == local_target_) {
      CHECK_EQ(worker, local_worker_)
          << "Releasing a worker that was not returned by this WorkerCache";
    } else {
      WorkerCacheInterface::ReleaseWorker(target, worker);
    }
  }

  Status GetEagerClientCache(
      std::unique_ptr<eager::EagerClientCache>* eager_client_cache) override {
    eager_client_cache->reset(eager::NewGrpcEagerClientCache(channel_cache_));
    return Status::OK();
  }

  Status GetCoordinationClientCache(std::unique_ptr<CoordinationClientCache>*
                                        coordination_client_cache) override {
#if defined(PLATFORM_GOOGLE)
    coordination_client_cache->reset(
        NewGrpcCoordinationClientCache(channel_cache_));
    return Status::OK();
#else
    return errors::Unimplemented(
        "Coordination service in open source is not yet implemented.");
#endif
  }

  void SetLogging(bool v) override { logger_.SetLogging(v); }

  void ClearLogs() override { logger_.ClearLogs(); }

  bool RetrieveLogs(int64_t step_id, StepStats* ss) override {
    return logger_.RetrieveLogs(step_id, ss);
  }

 private:
  size_t AssignWorkerToThread(const string& target) {
    // Round-robin target assignment, but always keep the same target on the
    // same polling thread, as this is important for gRPC performance.
    mutex_lock lock(assignment_mu_);
    auto it = target_assignments_.find(target);
    if (it == target_assignments_.end()) {
      it = target_assignments_
               .insert(std::make_pair(target,
                                      (next_round_robin_assignment_++) %
                                          worker_env_->CompletionQueueSize()))
               .first;
    }
    return it->second;
  }

  const string local_target_;
  WorkerInterface* const local_worker_;  // Not owned.
  std::shared_ptr<GrpcChannelCache> channel_cache_;
  WorkerCacheLogger logger_;
  GrpcWorkerEnv* worker_env_;  // Not owned.

  mutex assignment_mu_;
  std::unordered_map<std::string, size_t> target_assignments_
      TF_GUARDED_BY(assignment_mu_);
  size_t next_round_robin_assignment_ TF_GUARDED_BY(assignment_mu_);
};

}  // namespace

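// GrpcWorkerEnv owns a shared thread pool of `num_threads` threads and
// `num_completion_queues` polling threads, each driving its own gRPC
// completion queue.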
GrpcWorkerEnv::GrpcWorkerEnv(size_t num_completion_queues, size_t num_threads)
    : threadpool_(new thread::ThreadPool(
          Env::Default(), ThreadOptions(), "GrpcWorkerEnvQueues", num_threads,
          /*low_latency_hint=*/false, /*allocator=*/nullptr)),
      threads_(num_completion_queues) {}

GrpcWorkerEnv::~GrpcWorkerEnv() { threads_.clear(); }

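// Each GrpcWorkerCacheThread runs a dedicated polling loop: it pops tags off
// its completion queue and invokes GrpcClientCQTag::OnCompleted until the
// queue is shut down.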
GrpcWorkerEnv::GrpcWorkerCacheThread::GrpcWorkerCacheThread() {
  thread_.reset(Env::Default()->StartThread(
      ThreadOptions(), "GrpcWorkerEnvPool", [this]() {
        void* tag;
        bool ok;
        while (completion_queue_.Next(&tag, &ok)) {
          GrpcClientCQTag* callback_tag = static_cast<GrpcClientCQTag*>(tag);
          callback_tag->OnCompleted(ok);
        }
      }));
}

GrpcWorkerEnv::GrpcWorkerCacheThread::~GrpcWorkerCacheThread() {
  completion_queue_.Shutdown();
  thread_.reset();
}

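// Builds a GrpcWorkerEnv sized from the environment: the number of completion
// queues comes from TF_GRPC_WORKER_CACHE_QUEUES (default 64) and the number
// of threads from TF_GRPC_WORKER_CACHE_THREADS (default: the number of
// schedulable CPUs). Parse errors are logged rather than being fatal.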
GrpcWorkerEnv* CreateGrpcWorkerEnv() {
  int num_cpus = port::NumSchedulableCPUs();
  int64_t num_completion_queues;
  Status status = ReadInt64FromEnvVar("TF_GRPC_WORKER_CACHE_QUEUES", 64,
                                      &num_completion_queues);
  if (!status.ok()) {
    LOG(ERROR) << "Error parsing TF_GRPC_WORKER_CACHE_QUEUES: " << status;
  }
  int64_t num_threads;
  status = ReadInt64FromEnvVar("TF_GRPC_WORKER_CACHE_THREADS", num_cpus,
                               &num_threads);
  if (!status.ok()) {
    LOG(ERROR) << "Error parsing TF_GRPC_WORKER_CACHE_THREADS: " << status;
  }
  return new GrpcWorkerEnv(num_completion_queues, num_threads);
}

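// Returns a worker cache with no local worker: every target, including the
// caller's own, is reached through a gRPC stub from the channel cache.
//
// A minimal usage sketch (assuming `channel_cache` was already built
// elsewhere, e.g. from a GrpcChannelSpec describing the cluster):
//
//   GrpcWorkerEnv* worker_env = CreateGrpcWorkerEnv();
//   WorkerCacheInterface* cache =
//       NewGrpcWorkerCache(channel_cache, worker_env);
//   WorkerInterface* worker =
//       cache->GetOrCreateWorker("/job:worker/replica:0/task:1");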
WorkerCacheInterface* NewGrpcWorkerCache(std::shared_ptr<GrpcChannelCache> cc,
                                         GrpcWorkerEnv* worker_env) {
  return new GrpcWorkerCache(cc, /*local_worker=*/nullptr, /*local_target=*/"",
                             worker_env);
}

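// Like NewGrpcWorkerCache, but requests for `local_target` are served
// directly by `local_worker` instead of going through a gRPC stub.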
WorkerCacheInterface* NewGrpcWorkerCacheWithLocalWorker(
    std::shared_ptr<GrpcChannelCache> cc, GrpcWorkerEnv* worker_env,
    WorkerInterface* local_worker, const string& local_target) {
  return new GrpcWorkerCache(cc, local_worker, local_target, worker_env);
}

}  // namespace tensorflow