• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"

#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/distributed_runtime/rpc/coordination/grpc_coordination_client.h"
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/distributed_runtime/worker_cache_logger.h"
#include "tensorflow/core/distributed_runtime/worker_cache_partial.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/util/env_var.h"
28 
29 namespace tensorflow {
30 
31 namespace {
32 
33 class GrpcWorkerCache : public WorkerCachePartial {
34  public:
GrpcWorkerCache(std::shared_ptr<GrpcChannelCache> channel_cache,WorkerInterface * local_worker,const string & local_target,GrpcWorkerEnv * worker_env)35   explicit GrpcWorkerCache(std::shared_ptr<GrpcChannelCache> channel_cache,
36                            WorkerInterface* local_worker,
37                            const string& local_target,
38                            GrpcWorkerEnv* worker_env)
39       : local_target_(local_target),
40         local_worker_(local_worker),
41         channel_cache_(channel_cache),
42         worker_env_(worker_env),
43         next_round_robin_assignment_(0) {}
44 
ListWorkers(std::vector<string> * workers) const45   void ListWorkers(std::vector<string>* workers) const override {
46     channel_cache_->ListWorkers(workers);
47   }
48 
ListWorkersInJob(const string & job_name,std::vector<string> * workers) const49   void ListWorkersInJob(const string& job_name,
50                         std::vector<string>* workers) const override {
51     channel_cache_->ListWorkersInJob(job_name, workers);
52   }
53 
GetOrCreateWorker(const string & target)54   WorkerInterface* GetOrCreateWorker(const string& target) override {
55     if (target == local_target_) {
56       return local_worker_;
57     } else {
58       SharedGrpcChannelPtr channel = channel_cache_->FindWorkerChannel(target);
59       if (!channel) {
60         return nullptr;
61       }
62       size_t index = AssignWorkerToThread(target);
63       return NewGrpcRemoteWorker(
64           channel, worker_env_->GetCompletionQueue(index),
65           worker_env_->GetThreadPool(), &logger_, target);
66     }
67   }
68 
ReleaseWorker(const string & target,WorkerInterface * worker)69   void ReleaseWorker(const string& target, WorkerInterface* worker) override {
70     if (target == local_target_) {
71       CHECK_EQ(worker, local_worker_)
72           << "Releasing a worker that was not returned by this WorkerCache";
73     } else {
74       WorkerCacheInterface::ReleaseWorker(target, worker);
75     }
76   }
77 
GetEagerClientCache(std::unique_ptr<eager::EagerClientCache> * eager_client_cache)78   Status GetEagerClientCache(
79       std::unique_ptr<eager::EagerClientCache>* eager_client_cache) override {
80     eager_client_cache->reset(eager::NewGrpcEagerClientCache(channel_cache_));
81     return OkStatus();
82   }
83 
GetCoordinationClientCache(std::unique_ptr<CoordinationClientCache> * coordination_client_cache)84   Status GetCoordinationClientCache(std::unique_ptr<CoordinationClientCache>*
85                                         coordination_client_cache) override {
86     coordination_client_cache->reset(
87         NewGrpcCoordinationClientCache(channel_cache_));
88     return OkStatus();
89   }
90 
GetCoordinationClientCache(std::unique_ptr<CoordinationClientCache> * coordination_client_cache)91   Status GetCoordinationClientCache(std::unique_ptr<CoordinationClientCache>*
92                                         coordination_client_cache) override {
93 #if defined(PLATFORM_GOOGLE)
94     coordination_client_cache->reset(
95         NewGrpcCoordinationClientCache(channel_cache_));
96     return Status::OK();
97 #else
98     return errors::Unimplemented(
99         "Coordination service in open source is not yet implemented.");
100 #endif
101   }
102 
SetLogging(bool v)103   void SetLogging(bool v) override { logger_.SetLogging(v); }
104 
ClearLogs()105   void ClearLogs() override { logger_.ClearLogs(); }
106 
RetrieveLogs(int64_t step_id,StepStats * ss)107   bool RetrieveLogs(int64_t step_id, StepStats* ss) override {
108     return logger_.RetrieveLogs(step_id, ss);
109   }
110 
111  private:
AssignWorkerToThread(const string & target)112   size_t AssignWorkerToThread(const string& target) {
113     // Round-robin target assignment, but keeps the same target on the same
114     // polling thread always, as this is important for gRPC performance
115     mutex_lock lock(assignment_mu_);
116     auto it = target_assignments_.find(target);
117     if (it == target_assignments_.end()) {
118       it = target_assignments_
119                .insert(std::make_pair(target,
120                                       (next_round_robin_assignment_++) %
121                                           worker_env_->CompletionQueueSize()))
122                .first;
123     }
124     return it->second;
125   }
126 
127   const string local_target_;
128   WorkerInterface* const local_worker_;  // Not owned.
129   std::shared_ptr<GrpcChannelCache> channel_cache_;
130   WorkerCacheLogger logger_;
131   GrpcWorkerEnv* worker_env_;  // Not owned
132 
133   mutex assignment_mu_;
134   std::unordered_map<std::string, size_t> target_assignments_
135       TF_GUARDED_BY(assignment_mu_);
136   size_t next_round_robin_assignment_ TF_GUARDED_BY(assignment_mu_);
137 };
138 
139 }  // namespace
140 
GrpcWorkerEnv(size_t num_completion_queues,size_t num_threads)141 GrpcWorkerEnv::GrpcWorkerEnv(size_t num_completion_queues, size_t num_threads)
142     : threadpool_(new thread::ThreadPool(
143           Env::Default(), ThreadOptions(), "GrpcWorkerEnvQueues", num_threads,
144           /*low_latency_hint=*/false, /*allocator=*/nullptr)),
145       threads_(num_completion_queues) {}
146 
~GrpcWorkerEnv()147 GrpcWorkerEnv::~GrpcWorkerEnv() { threads_.clear(); }
148 
GrpcWorkerCacheThread()149 GrpcWorkerEnv::GrpcWorkerCacheThread::GrpcWorkerCacheThread() {
150   thread_.reset(Env::Default()->StartThread(
151       ThreadOptions(), "GrpcWorkerEnvPool", [this]() {
152         void* tag;
153         bool ok;
154         while (completion_queue_.Next(&tag, &ok)) {
155           GrpcClientCQTag* callback_tag = static_cast<GrpcClientCQTag*>(tag);
156           callback_tag->OnCompleted(ok);
157         }
158       }));
159 }
160 
~GrpcWorkerCacheThread()161 GrpcWorkerEnv::GrpcWorkerCacheThread::~GrpcWorkerCacheThread() {
162   completion_queue_.Shutdown();
163   thread_.reset();
164 }
165 
CreateGrpcWorkerEnv()166 GrpcWorkerEnv* CreateGrpcWorkerEnv() {
167   int num_cpus = port::NumSchedulableCPUs();
168   int64_t num_completion_queues;
169   Status status = ReadInt64FromEnvVar("TF_GRPC_WORKER_CACHE_QUEUES", 64,
170                                       &num_completion_queues);
171   if (!status.ok()) {
172     LOG(ERROR) << "Error parsing TF_GRPC_WORKER_CACHE_QUEUES: " << status;
173   }
174   int64_t num_threads;
175   status = ReadInt64FromEnvVar("TF_GRPC_WORKER_CACHE_THREADS", num_cpus,
176                                &num_threads);
177   if (!status.ok()) {
178     LOG(ERROR) << "Error parsing TF_GRPC_WORKER_CACHE_THREADS: " << status;
179   }
180   return new GrpcWorkerEnv(num_completion_queues, num_threads);
181 }
182 
NewGrpcWorkerCache(std::shared_ptr<GrpcChannelCache> cc,GrpcWorkerEnv * worker_env)183 WorkerCacheInterface* NewGrpcWorkerCache(std::shared_ptr<GrpcChannelCache> cc,
184                                          GrpcWorkerEnv* worker_env) {
185   return new GrpcWorkerCache(cc, /*local_worker=*/nullptr, /*local_target=*/"",
186                              worker_env);
187 }
188 
NewGrpcWorkerCacheWithLocalWorker(std::shared_ptr<GrpcChannelCache> cc,GrpcWorkerEnv * worker_env,WorkerInterface * local_worker,const string & local_target)189 WorkerCacheInterface* NewGrpcWorkerCacheWithLocalWorker(
190     std::shared_ptr<GrpcChannelCache> cc, GrpcWorkerEnv* worker_env,
191     WorkerInterface* local_worker, const string& local_target) {
192   return new GrpcWorkerCache(cc, local_worker, local_target, worker_env);
193 }
194 
195 }  // namespace tensorflow
196