1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
17
18 #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
19 #include "tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h"
20 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
21 #include "tensorflow/core/distributed_runtime/worker_cache_logger.h"
22 #include "tensorflow/core/distributed_runtime/worker_cache_partial.h"
23 #include "tensorflow/core/distributed_runtime/worker_interface.h"
24 #include "tensorflow/core/platform/cpu_info.h"
25 #include "tensorflow/core/platform/mutex.h"
26 #include "tensorflow/core/util/env_var.h"
27
28 namespace tensorflow {
29
30 namespace {
31
32 class GrpcWorkerCache : public WorkerCachePartial {
33 public:
GrpcWorkerCache(std::shared_ptr<GrpcChannelCache> channel_cache,WorkerInterface * local_worker,const string & local_target,GrpcWorkerEnv * worker_env)34 explicit GrpcWorkerCache(std::shared_ptr<GrpcChannelCache> channel_cache,
35 WorkerInterface* local_worker,
36 const string& local_target,
37 GrpcWorkerEnv* worker_env)
38 : local_target_(local_target),
39 local_worker_(local_worker),
40 channel_cache_(channel_cache),
41 worker_env_(worker_env),
42 next_round_robin_assignment_(0) {}
43
ListWorkers(std::vector<string> * workers) const44 void ListWorkers(std::vector<string>* workers) const override {
45 channel_cache_->ListWorkers(workers);
46 }
47
ListWorkersInJob(const string & job_name,std::vector<string> * workers) const48 void ListWorkersInJob(const string& job_name,
49 std::vector<string>* workers) const override {
50 channel_cache_->ListWorkersInJob(job_name, workers);
51 }
52
GetOrCreateWorker(const string & target)53 WorkerInterface* GetOrCreateWorker(const string& target) override {
54 if (target == local_target_) {
55 return local_worker_;
56 } else {
57 SharedGrpcChannelPtr channel = channel_cache_->FindWorkerChannel(target);
58 if (!channel) {
59 return nullptr;
60 }
61 size_t index = AssignWorkerToThread(target);
62 return NewGrpcRemoteWorker(
63 channel, worker_env_->GetCompletionQueue(index),
64 worker_env_->GetThreadPool(), &logger_, target);
65 }
66 }
67
ReleaseWorker(const string & target,WorkerInterface * worker)68 void ReleaseWorker(const string& target, WorkerInterface* worker) override {
69 if (target == local_target_) {
70 CHECK_EQ(worker, local_worker_)
71 << "Releasing a worker that was not returned by this WorkerCache";
72 } else {
73 WorkerCacheInterface::ReleaseWorker(target, worker);
74 }
75 }
76
GetEagerClientCache(std::unique_ptr<eager::EagerClientCache> * eager_client_cache)77 Status GetEagerClientCache(
78 std::unique_ptr<eager::EagerClientCache>* eager_client_cache) override {
79 eager_client_cache->reset(eager::NewGrpcEagerClientCache(channel_cache_));
80 return Status::OK();
81 }
82
SetLogging(bool v)83 void SetLogging(bool v) override { logger_.SetLogging(v); }
84
ClearLogs()85 void ClearLogs() override { logger_.ClearLogs(); }
86
RetrieveLogs(int64 step_id,StepStats * ss)87 bool RetrieveLogs(int64 step_id, StepStats* ss) override {
88 return logger_.RetrieveLogs(step_id, ss);
89 }
90
91 private:
AssignWorkerToThread(const string & target)92 size_t AssignWorkerToThread(const string& target) {
93 // Round-robin target assignment, but keeps the same target on the same
94 // polling thread always, as this is important for gRPC performance
95 mutex_lock lock(assignment_mu_);
96 auto it = target_assignments_.find(target);
97 if (it == target_assignments_.end()) {
98 it = target_assignments_
99 .insert(std::make_pair(target,
100 (next_round_robin_assignment_++) %
101 worker_env_->CompletionQueueSize()))
102 .first;
103 }
104 return it->second;
105 }
106
107 const string local_target_;
108 WorkerInterface* const local_worker_; // Not owned.
109 std::shared_ptr<GrpcChannelCache> channel_cache_;
110 WorkerCacheLogger logger_;
111 GrpcWorkerEnv* worker_env_; // Not owned
112
113 mutex assignment_mu_;
114 std::unordered_map<std::string, size_t> target_assignments_
115 TF_GUARDED_BY(assignment_mu_);
116 size_t next_round_robin_assignment_ TF_GUARDED_BY(assignment_mu_);
117 };
118
119 } // namespace
120
GrpcWorkerEnv(size_t num_completion_queues,size_t num_threads)121 GrpcWorkerEnv::GrpcWorkerEnv(size_t num_completion_queues, size_t num_threads)
122 : threadpool_(new thread::ThreadPool(
123 Env::Default(), ThreadOptions(), "GrpcWorkerEnvQueues", num_threads,
124 /*low_latency_hint=*/false, /*allocator=*/nullptr)),
125 threads_(num_completion_queues) {}
126
~GrpcWorkerEnv()127 GrpcWorkerEnv::~GrpcWorkerEnv() { threads_.clear(); }
128
GrpcWorkerCacheThread()129 GrpcWorkerEnv::GrpcWorkerCacheThread::GrpcWorkerCacheThread() {
130 thread_.reset(Env::Default()->StartThread(
131 ThreadOptions(), "GrpcWorkerEnvPool", [this]() {
132 void* tag;
133 bool ok;
134 while (completion_queue_.Next(&tag, &ok)) {
135 GrpcClientCQTag* callback_tag = static_cast<GrpcClientCQTag*>(tag);
136 callback_tag->OnCompleted(ok);
137 }
138 }));
139 }
140
~GrpcWorkerCacheThread()141 GrpcWorkerEnv::GrpcWorkerCacheThread::~GrpcWorkerCacheThread() {
142 completion_queue_.Shutdown();
143 thread_.reset();
144 }
145
CreateGrpcWorkerEnv()146 GrpcWorkerEnv* CreateGrpcWorkerEnv() {
147 int num_cpus = port::NumSchedulableCPUs();
148 int64 num_completion_queues;
149 Status status = ReadInt64FromEnvVar("TF_GRPC_WORKER_CACHE_QUEUES", 64,
150 &num_completion_queues);
151 if (!status.ok()) {
152 LOG(ERROR) << "Error parsing TF_GRPC_WORKER_CACHE_QUEUES: " << status;
153 }
154 int64 num_threads;
155 status = ReadInt64FromEnvVar("TF_GRPC_WORKER_CACHE_THREADS", num_cpus,
156 &num_threads);
157 if (!status.ok()) {
158 LOG(ERROR) << "Error parsing TF_GRPC_WORKER_CACHE_THREADS: " << status;
159 }
160 return new GrpcWorkerEnv(num_completion_queues, num_threads);
161 }
162
NewGrpcWorkerCache(std::shared_ptr<GrpcChannelCache> cc,GrpcWorkerEnv * worker_env)163 WorkerCacheInterface* NewGrpcWorkerCache(std::shared_ptr<GrpcChannelCache> cc,
164 GrpcWorkerEnv* worker_env) {
165 return new GrpcWorkerCache(cc, /*local_worker=*/nullptr, /*local_target=*/"",
166 worker_env);
167 }
168
NewGrpcWorkerCacheWithLocalWorker(std::shared_ptr<GrpcChannelCache> cc,GrpcWorkerEnv * worker_env,WorkerInterface * local_worker,const string & local_target)169 WorkerCacheInterface* NewGrpcWorkerCacheWithLocalWorker(
170 std::shared_ptr<GrpcChannelCache> cc, GrpcWorkerEnv* worker_env,
171 WorkerInterface* local_worker, const string& local_target) {
172 return new GrpcWorkerCache(cc, local_worker, local_target, worker_env);
173 }
174
175 } // namespace tensorflow
176