• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"

#include <utility>

#include "tensorflow/core/distributed_runtime/rpc/coordination/grpc_coordination_client.h"
#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/distributed_runtime/worker_cache_logger.h"
#include "tensorflow/core/distributed_runtime/worker_cache_partial.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/util/env_var.h"
28 
29 namespace tensorflow {
30 
31 namespace {
32 
33 class GrpcWorkerCache : public WorkerCachePartial {
34  public:
GrpcWorkerCache(std::shared_ptr<GrpcChannelCache> channel_cache,WorkerInterface * local_worker,const string & local_target,GrpcWorkerEnv * worker_env)35   explicit GrpcWorkerCache(std::shared_ptr<GrpcChannelCache> channel_cache,
36                            WorkerInterface* local_worker,
37                            const string& local_target,
38                            GrpcWorkerEnv* worker_env)
39       : local_target_(local_target),
40         local_worker_(local_worker),
41         channel_cache_(channel_cache),
42         worker_env_(worker_env),
43         next_round_robin_assignment_(0) {}
44 
ListWorkers(std::vector<string> * workers) const45   void ListWorkers(std::vector<string>* workers) const override {
46     channel_cache_->ListWorkers(workers);
47   }
48 
ListWorkersInJob(const string & job_name,std::vector<string> * workers) const49   void ListWorkersInJob(const string& job_name,
50                         std::vector<string>* workers) const override {
51     channel_cache_->ListWorkersInJob(job_name, workers);
52   }
53 
GetOrCreateWorker(const string & target)54   WorkerInterface* GetOrCreateWorker(const string& target) override {
55     if (target == local_target_) {
56       return local_worker_;
57     } else {
58       SharedGrpcChannelPtr channel = channel_cache_->FindWorkerChannel(target);
59       if (!channel) {
60         return nullptr;
61       }
62       size_t index = AssignWorkerToThread(target);
63       return NewGrpcRemoteWorker(
64           channel, worker_env_->GetCompletionQueue(index),
65           worker_env_->GetThreadPool(), &logger_, target);
66     }
67   }
68 
ReleaseWorker(const string & target,WorkerInterface * worker)69   void ReleaseWorker(const string& target, WorkerInterface* worker) override {
70     if (target == local_target_) {
71       CHECK_EQ(worker, local_worker_)
72           << "Releasing a worker that was not returned by this WorkerCache";
73     } else {
74       WorkerCacheInterface::ReleaseWorker(target, worker);
75     }
76   }
77 
GetEagerClientCache(std::unique_ptr<eager::EagerClientCache> * eager_client_cache)78   Status GetEagerClientCache(
79       std::unique_ptr<eager::EagerClientCache>* eager_client_cache) override {
80     eager_client_cache->reset(eager::NewGrpcEagerClientCache(channel_cache_));
81     return Status::OK();
82   }
83 
GetCoordinationClientCache(std::unique_ptr<CoordinationClientCache> * coordination_client_cache)84   Status GetCoordinationClientCache(std::unique_ptr<CoordinationClientCache>*
85                                         coordination_client_cache) override {
86 #if defined(PLATFORM_GOOGLE)
87     coordination_client_cache->reset(
88         NewGrpcCoordinationClientCache(channel_cache_));
89     return Status::OK();
90 #else
91     return errors::Unimplemented(
92         "Coordination service in open source is not yet implemented.");
93 #endif
94   }
95 
SetLogging(bool v)96   void SetLogging(bool v) override { logger_.SetLogging(v); }
97 
ClearLogs()98   void ClearLogs() override { logger_.ClearLogs(); }
99 
RetrieveLogs(int64_t step_id,StepStats * ss)100   bool RetrieveLogs(int64_t step_id, StepStats* ss) override {
101     return logger_.RetrieveLogs(step_id, ss);
102   }
103 
104  private:
AssignWorkerToThread(const string & target)105   size_t AssignWorkerToThread(const string& target) {
106     // Round-robin target assignment, but keeps the same target on the same
107     // polling thread always, as this is important for gRPC performance
108     mutex_lock lock(assignment_mu_);
109     auto it = target_assignments_.find(target);
110     if (it == target_assignments_.end()) {
111       it = target_assignments_
112                .insert(std::make_pair(target,
113                                       (next_round_robin_assignment_++) %
114                                           worker_env_->CompletionQueueSize()))
115                .first;
116     }
117     return it->second;
118   }
119 
120   const string local_target_;
121   WorkerInterface* const local_worker_;  // Not owned.
122   std::shared_ptr<GrpcChannelCache> channel_cache_;
123   WorkerCacheLogger logger_;
124   GrpcWorkerEnv* worker_env_;  // Not owned
125 
126   mutex assignment_mu_;
127   std::unordered_map<std::string, size_t> target_assignments_
128       TF_GUARDED_BY(assignment_mu_);
129   size_t next_round_robin_assignment_ TF_GUARDED_BY(assignment_mu_);
130 };
131 
132 }  // namespace
133 
GrpcWorkerEnv(size_t num_completion_queues,size_t num_threads)134 GrpcWorkerEnv::GrpcWorkerEnv(size_t num_completion_queues, size_t num_threads)
135     : threadpool_(new thread::ThreadPool(
136           Env::Default(), ThreadOptions(), "GrpcWorkerEnvQueues", num_threads,
137           /*low_latency_hint=*/false, /*allocator=*/nullptr)),
138       threads_(num_completion_queues) {}
139 
~GrpcWorkerEnv()140 GrpcWorkerEnv::~GrpcWorkerEnv() { threads_.clear(); }
141 
GrpcWorkerCacheThread()142 GrpcWorkerEnv::GrpcWorkerCacheThread::GrpcWorkerCacheThread() {
143   thread_.reset(Env::Default()->StartThread(
144       ThreadOptions(), "GrpcWorkerEnvPool", [this]() {
145         void* tag;
146         bool ok;
147         while (completion_queue_.Next(&tag, &ok)) {
148           GrpcClientCQTag* callback_tag = static_cast<GrpcClientCQTag*>(tag);
149           callback_tag->OnCompleted(ok);
150         }
151       }));
152 }
153 
~GrpcWorkerCacheThread()154 GrpcWorkerEnv::GrpcWorkerCacheThread::~GrpcWorkerCacheThread() {
155   completion_queue_.Shutdown();
156   thread_.reset();
157 }
158 
CreateGrpcWorkerEnv()159 GrpcWorkerEnv* CreateGrpcWorkerEnv() {
160   int num_cpus = port::NumSchedulableCPUs();
161   int64_t num_completion_queues;
162   Status status = ReadInt64FromEnvVar("TF_GRPC_WORKER_CACHE_QUEUES", 64,
163                                       &num_completion_queues);
164   if (!status.ok()) {
165     LOG(ERROR) << "Error parsing TF_GRPC_WORKER_CACHE_QUEUES: " << status;
166   }
167   int64_t num_threads;
168   status = ReadInt64FromEnvVar("TF_GRPC_WORKER_CACHE_THREADS", num_cpus,
169                                &num_threads);
170   if (!status.ok()) {
171     LOG(ERROR) << "Error parsing TF_GRPC_WORKER_CACHE_THREADS: " << status;
172   }
173   return new GrpcWorkerEnv(num_completion_queues, num_threads);
174 }
175 
NewGrpcWorkerCache(std::shared_ptr<GrpcChannelCache> cc,GrpcWorkerEnv * worker_env)176 WorkerCacheInterface* NewGrpcWorkerCache(std::shared_ptr<GrpcChannelCache> cc,
177                                          GrpcWorkerEnv* worker_env) {
178   return new GrpcWorkerCache(cc, /*local_worker=*/nullptr, /*local_target=*/"",
179                              worker_env);
180 }
181 
NewGrpcWorkerCacheWithLocalWorker(std::shared_ptr<GrpcChannelCache> cc,GrpcWorkerEnv * worker_env,WorkerInterface * local_worker,const string & local_target)182 WorkerCacheInterface* NewGrpcWorkerCacheWithLocalWorker(
183     std::shared_ptr<GrpcChannelCache> cc, GrpcWorkerEnv* worker_env,
184     WorkerInterface* local_worker, const string& local_target) {
185   return new GrpcWorkerCache(cc, local_worker, local_target, worker_env);
186 }
187 
188 }  // namespace tensorflow
189