• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"

#include <unordered_map>
#include <utility>

#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/distributed_runtime/worker_cache_logger.h"
#include "tensorflow/core/distributed_runtime/worker_cache_partial.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/platform/mutex.h"
27 
28 namespace tensorflow {
29 
30 namespace {
31 
// Sizing of the GrpcWorkerEnv that GrpcWorkerCache builds for itself when the
// caller does not supply one:
//   kGrpcWorkerCacheThreadCount -> number of completion queues, each of which
//                                  gets its own polling thread.
//   kNumCallbackThreads         -> number of threads in the shared thread pool
//                                  handed to remote workers.
// TODO(ncteisen): consider adding a config var or flag for this
constexpr size_t kGrpcWorkerCacheThreadCount = 8;
constexpr size_t kNumCallbackThreads = 10;
35 
36 class GrpcWorkerCache : public WorkerCachePartial {
37  public:
GrpcWorkerCache(std::shared_ptr<GrpcChannelCache> channel_cache,WorkerInterface * local_worker,const string & local_target,GrpcWorkerEnv * worker_env)38   explicit GrpcWorkerCache(std::shared_ptr<GrpcChannelCache> channel_cache,
39                            WorkerInterface* local_worker,
40                            const string& local_target,
41                            GrpcWorkerEnv* worker_env)
42       : local_target_(local_target),
43         local_worker_(local_worker),
44         channel_cache_(channel_cache),
45         worker_env_(worker_env),
46         next_round_robin_assignment_(0) {
47     if (worker_env_ == nullptr) {
48       worker_env_ptr_ = absl::make_unique<GrpcWorkerEnv>(
49           kGrpcWorkerCacheThreadCount, kNumCallbackThreads);
50       worker_env_ = worker_env_ptr_.get();
51     }
52   }
53 
ListWorkers(std::vector<string> * workers) const54   void ListWorkers(std::vector<string>* workers) const override {
55     channel_cache_->ListWorkers(workers);
56   }
57 
ListWorkersInJob(const string & job_name,std::vector<string> * workers) const58   void ListWorkersInJob(const string& job_name,
59                         std::vector<string>* workers) const override {
60     channel_cache_->ListWorkersInJob(job_name, workers);
61   }
62 
GetOrCreateWorker(const string & target)63   WorkerInterface* GetOrCreateWorker(const string& target) override {
64     if (target == local_target_) {
65       return local_worker_;
66     } else {
67       SharedGrpcChannelPtr channel = channel_cache_->FindWorkerChannel(target);
68       if (!channel) {
69         return nullptr;
70       }
71       size_t index = AssignWorkerToThread(target);
72       return NewGrpcRemoteWorker(channel,
73                                  worker_env_->GetCompletionQueue(index),
74                                  worker_env_->GetThreadPool(), &logger_);
75     }
76   }
77 
ReleaseWorker(const string & target,WorkerInterface * worker)78   void ReleaseWorker(const string& target, WorkerInterface* worker) override {
79     if (target == local_target_) {
80       CHECK_EQ(worker, local_worker_)
81           << "Releasing a worker that was not returned by this WorkerCache";
82     } else {
83       WorkerCacheInterface::ReleaseWorker(target, worker);
84     }
85   }
86 
GetEagerClientCache(std::unique_ptr<eager::EagerClientCache> * eager_client_cache)87   Status GetEagerClientCache(
88       std::unique_ptr<eager::EagerClientCache>* eager_client_cache) override {
89     eager_client_cache->reset(eager::NewGrpcEagerClientCache(channel_cache_));
90     return Status::OK();
91   }
92 
SetLogging(bool v)93   void SetLogging(bool v) override { logger_.SetLogging(v); }
94 
ClearLogs()95   void ClearLogs() override { logger_.ClearLogs(); }
96 
RetrieveLogs(int64 step_id,StepStats * ss)97   bool RetrieveLogs(int64 step_id, StepStats* ss) override {
98     return logger_.RetrieveLogs(step_id, ss);
99   }
100 
101  private:
AssignWorkerToThread(const string & target)102   size_t AssignWorkerToThread(const string& target) {
103     // Round-robin target assignment, but keeps the same target on the same
104     // polling thread always, as this is important for gRPC performance
105     mutex_lock lock(assignment_mu_);
106     auto it = target_assignments_.find(target);
107     if (it == target_assignments_.end()) {
108       it = target_assignments_
109                .insert(std::make_pair(target,
110                                       (next_round_robin_assignment_++) %
111                                           worker_env_->CompletionQueueSize()))
112                .first;
113     }
114     return it->second;
115   }
116 
117   const string local_target_;
118   WorkerInterface* const local_worker_;  // Not owned.
119   std::shared_ptr<GrpcChannelCache> channel_cache_;
120   WorkerCacheLogger logger_;
121   GrpcWorkerEnv* worker_env_;  // Not owned, if worker_env_ptr_ is nullptr.
122   std::unique_ptr<GrpcWorkerEnv> worker_env_ptr_;
123 
124   mutex assignment_mu_;
125   std::unordered_map<std::string, size_t> target_assignments_
126       GUARDED_BY(assignment_mu_);
127   size_t next_round_robin_assignment_ GUARDED_BY(assignment_mu_);
128 };
129 
130 }  // namespace
131 
GrpcWorkerEnv(size_t num_completion_queues,size_t num_threads)132 GrpcWorkerEnv::GrpcWorkerEnv(size_t num_completion_queues, size_t num_threads)
133     : threadpool_(new thread::ThreadPool(
134           Env::Default(), ThreadOptions(), "GrpcWorkerEnvQueues", num_threads,
135           /*low_latency_hint=*/false, /*allocator=*/nullptr)),
136       threads_(num_completion_queues) {}
137 
~GrpcWorkerEnv()138 GrpcWorkerEnv::~GrpcWorkerEnv() { threads_.clear(); }
139 
GrpcWorkerCacheThread()140 GrpcWorkerEnv::GrpcWorkerCacheThread::GrpcWorkerCacheThread() {
141   thread_.reset(Env::Default()->StartThread(
142       ThreadOptions(), "GrpcWorkerEnvPool", [this]() {
143         void* tag;
144         bool ok;
145         while (completion_queue_.Next(&tag, &ok)) {
146           GrpcClientCQTag* callback_tag = static_cast<GrpcClientCQTag*>(tag);
147           callback_tag->OnCompleted(ok);
148         }
149       }));
150 }
151 
~GrpcWorkerCacheThread()152 GrpcWorkerEnv::GrpcWorkerCacheThread::~GrpcWorkerCacheThread() {
153   completion_queue_.Shutdown();
154   thread_.reset();
155 }
156 
NewGrpcWorkerCache(std::shared_ptr<GrpcChannelCache> cc)157 WorkerCacheInterface* NewGrpcWorkerCache(std::shared_ptr<GrpcChannelCache> cc) {
158   return new GrpcWorkerCache(cc, nullptr, "", nullptr);
159 }
160 
NewGrpcWorkerCacheWithLocalWorker(std::shared_ptr<GrpcChannelCache> cc,WorkerInterface * local_worker,const string & local_target,GrpcWorkerEnv * worker_env)161 WorkerCacheInterface* NewGrpcWorkerCacheWithLocalWorker(
162     std::shared_ptr<GrpcChannelCache> cc, WorkerInterface* local_worker,
163     const string& local_target, GrpcWorkerEnv* worker_env) {
164   return new GrpcWorkerCache(cc, local_worker, local_target, worker_env);
165 }
166 
167 }  // namespace tensorflow
168