• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"

#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/distributed_runtime/worker_cache_logger.h"
#include "tensorflow/core/distributed_runtime/worker_cache_partial.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/util/env_var.h"
27 
28 namespace tensorflow {
29 
30 namespace {
31 
32 class GrpcWorkerCache : public WorkerCachePartial {
33  public:
GrpcWorkerCache(std::shared_ptr<GrpcChannelCache> channel_cache,WorkerInterface * local_worker,const string & local_target,GrpcWorkerEnv * worker_env)34   explicit GrpcWorkerCache(std::shared_ptr<GrpcChannelCache> channel_cache,
35                            WorkerInterface* local_worker,
36                            const string& local_target,
37                            GrpcWorkerEnv* worker_env)
38       : local_target_(local_target),
39         local_worker_(local_worker),
40         channel_cache_(channel_cache),
41         worker_env_(worker_env),
42         next_round_robin_assignment_(0) {}
43 
ListWorkers(std::vector<string> * workers) const44   void ListWorkers(std::vector<string>* workers) const override {
45     channel_cache_->ListWorkers(workers);
46   }
47 
ListWorkersInJob(const string & job_name,std::vector<string> * workers) const48   void ListWorkersInJob(const string& job_name,
49                         std::vector<string>* workers) const override {
50     channel_cache_->ListWorkersInJob(job_name, workers);
51   }
52 
GetOrCreateWorker(const string & target)53   WorkerInterface* GetOrCreateWorker(const string& target) override {
54     if (target == local_target_) {
55       return local_worker_;
56     } else {
57       SharedGrpcChannelPtr channel = channel_cache_->FindWorkerChannel(target);
58       if (!channel) {
59         return nullptr;
60       }
61       size_t index = AssignWorkerToThread(target);
62       return NewGrpcRemoteWorker(
63           channel, worker_env_->GetCompletionQueue(index),
64           worker_env_->GetThreadPool(), &logger_, target);
65     }
66   }
67 
ReleaseWorker(const string & target,WorkerInterface * worker)68   void ReleaseWorker(const string& target, WorkerInterface* worker) override {
69     if (target == local_target_) {
70       CHECK_EQ(worker, local_worker_)
71           << "Releasing a worker that was not returned by this WorkerCache";
72     } else {
73       WorkerCacheInterface::ReleaseWorker(target, worker);
74     }
75   }
76 
GetEagerClientCache(std::unique_ptr<eager::EagerClientCache> * eager_client_cache)77   Status GetEagerClientCache(
78       std::unique_ptr<eager::EagerClientCache>* eager_client_cache) override {
79     eager_client_cache->reset(eager::NewGrpcEagerClientCache(channel_cache_));
80     return Status::OK();
81   }
82 
SetLogging(bool v)83   void SetLogging(bool v) override { logger_.SetLogging(v); }
84 
ClearLogs()85   void ClearLogs() override { logger_.ClearLogs(); }
86 
RetrieveLogs(int64 step_id,StepStats * ss)87   bool RetrieveLogs(int64 step_id, StepStats* ss) override {
88     return logger_.RetrieveLogs(step_id, ss);
89   }
90 
91  private:
AssignWorkerToThread(const string & target)92   size_t AssignWorkerToThread(const string& target) {
93     // Round-robin target assignment, but keeps the same target on the same
94     // polling thread always, as this is important for gRPC performance
95     mutex_lock lock(assignment_mu_);
96     auto it = target_assignments_.find(target);
97     if (it == target_assignments_.end()) {
98       it = target_assignments_
99                .insert(std::make_pair(target,
100                                       (next_round_robin_assignment_++) %
101                                           worker_env_->CompletionQueueSize()))
102                .first;
103     }
104     return it->second;
105   }
106 
107   const string local_target_;
108   WorkerInterface* const local_worker_;  // Not owned.
109   std::shared_ptr<GrpcChannelCache> channel_cache_;
110   WorkerCacheLogger logger_;
111   GrpcWorkerEnv* worker_env_;  // Not owned
112 
113   mutex assignment_mu_;
114   std::unordered_map<std::string, size_t> target_assignments_
115       TF_GUARDED_BY(assignment_mu_);
116   size_t next_round_robin_assignment_ TF_GUARDED_BY(assignment_mu_);
117 };
118 
119 }  // namespace
120 
GrpcWorkerEnv(size_t num_completion_queues,size_t num_threads)121 GrpcWorkerEnv::GrpcWorkerEnv(size_t num_completion_queues, size_t num_threads)
122     : threadpool_(new thread::ThreadPool(
123           Env::Default(), ThreadOptions(), "GrpcWorkerEnvQueues", num_threads,
124           /*low_latency_hint=*/false, /*allocator=*/nullptr)),
125       threads_(num_completion_queues) {}
126 
~GrpcWorkerEnv()127 GrpcWorkerEnv::~GrpcWorkerEnv() { threads_.clear(); }
128 
GrpcWorkerCacheThread()129 GrpcWorkerEnv::GrpcWorkerCacheThread::GrpcWorkerCacheThread() {
130   thread_.reset(Env::Default()->StartThread(
131       ThreadOptions(), "GrpcWorkerEnvPool", [this]() {
132         void* tag;
133         bool ok;
134         while (completion_queue_.Next(&tag, &ok)) {
135           GrpcClientCQTag* callback_tag = static_cast<GrpcClientCQTag*>(tag);
136           callback_tag->OnCompleted(ok);
137         }
138       }));
139 }
140 
~GrpcWorkerCacheThread()141 GrpcWorkerEnv::GrpcWorkerCacheThread::~GrpcWorkerCacheThread() {
142   completion_queue_.Shutdown();
143   thread_.reset();
144 }
145 
CreateGrpcWorkerEnv()146 GrpcWorkerEnv* CreateGrpcWorkerEnv() {
147   int num_cpus = port::NumSchedulableCPUs();
148   int64 num_completion_queues;
149   Status status = ReadInt64FromEnvVar("TF_GRPC_WORKER_CACHE_QUEUES", 64,
150                                       &num_completion_queues);
151   if (!status.ok()) {
152     LOG(ERROR) << "Error parsing TF_GRPC_WORKER_CACHE_QUEUES: " << status;
153   }
154   int64 num_threads;
155   status = ReadInt64FromEnvVar("TF_GRPC_WORKER_CACHE_THREADS", num_cpus,
156                                &num_threads);
157   if (!status.ok()) {
158     LOG(ERROR) << "Error parsing TF_GRPC_WORKER_CACHE_THREADS: " << status;
159   }
160   return new GrpcWorkerEnv(num_completion_queues, num_threads);
161 }
162 
NewGrpcWorkerCache(std::shared_ptr<GrpcChannelCache> cc,GrpcWorkerEnv * worker_env)163 WorkerCacheInterface* NewGrpcWorkerCache(std::shared_ptr<GrpcChannelCache> cc,
164                                          GrpcWorkerEnv* worker_env) {
165   return new GrpcWorkerCache(cc, /*local_worker=*/nullptr, /*local_target=*/"",
166                              worker_env);
167 }
168 
NewGrpcWorkerCacheWithLocalWorker(std::shared_ptr<GrpcChannelCache> cc,GrpcWorkerEnv * worker_env,WorkerInterface * local_worker,const string & local_target)169 WorkerCacheInterface* NewGrpcWorkerCacheWithLocalWorker(
170     std::shared_ptr<GrpcChannelCache> cc, GrpcWorkerEnv* worker_env,
171     WorkerInterface* local_worker, const string& local_target) {
172   return new GrpcWorkerCache(cc, local_worker, local_target, worker_env);
173 }
174 
175 }  // namespace tensorflow
176