1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"

#include <unordered_map>
#include <utility>

#include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/distributed_runtime/worker_cache_logger.h"
#include "tensorflow/core/distributed_runtime/worker_cache_partial.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/platform/mutex.h"
27
28 namespace tensorflow {
29
30 namespace {
31
// Sizing used when GrpcWorkerCache has to build its own GrpcWorkerEnv
// (i.e. the caller supplied none):
//   kGrpcWorkerCacheThreadCount -> number of completion queues, each polled by
//                                  its own dedicated thread;
//   kNumCallbackThreads         -> number of threads in the env's thread pool.
// These names already have internal linkage (anonymous namespace), so no
// `static` is needed.
// TODO(ncteisen): consider adding a config var or flag for this
constexpr size_t kGrpcWorkerCacheThreadCount = 8;
constexpr size_t kNumCallbackThreads = 10;
35
36 class GrpcWorkerCache : public WorkerCachePartial {
37 public:
GrpcWorkerCache(std::shared_ptr<GrpcChannelCache> channel_cache,WorkerInterface * local_worker,const string & local_target,GrpcWorkerEnv * worker_env)38 explicit GrpcWorkerCache(std::shared_ptr<GrpcChannelCache> channel_cache,
39 WorkerInterface* local_worker,
40 const string& local_target,
41 GrpcWorkerEnv* worker_env)
42 : local_target_(local_target),
43 local_worker_(local_worker),
44 channel_cache_(channel_cache),
45 worker_env_(worker_env),
46 next_round_robin_assignment_(0) {
47 if (worker_env_ == nullptr) {
48 worker_env_ptr_ = absl::make_unique<GrpcWorkerEnv>(
49 kGrpcWorkerCacheThreadCount, kNumCallbackThreads);
50 worker_env_ = worker_env_ptr_.get();
51 }
52 }
53
ListWorkers(std::vector<string> * workers) const54 void ListWorkers(std::vector<string>* workers) const override {
55 channel_cache_->ListWorkers(workers);
56 }
57
ListWorkersInJob(const string & job_name,std::vector<string> * workers) const58 void ListWorkersInJob(const string& job_name,
59 std::vector<string>* workers) const override {
60 channel_cache_->ListWorkersInJob(job_name, workers);
61 }
62
GetOrCreateWorker(const string & target)63 WorkerInterface* GetOrCreateWorker(const string& target) override {
64 if (target == local_target_) {
65 return local_worker_;
66 } else {
67 SharedGrpcChannelPtr channel = channel_cache_->FindWorkerChannel(target);
68 if (!channel) {
69 return nullptr;
70 }
71 size_t index = AssignWorkerToThread(target);
72 return NewGrpcRemoteWorker(channel,
73 worker_env_->GetCompletionQueue(index),
74 worker_env_->GetThreadPool(), &logger_);
75 }
76 }
77
ReleaseWorker(const string & target,WorkerInterface * worker)78 void ReleaseWorker(const string& target, WorkerInterface* worker) override {
79 if (target == local_target_) {
80 CHECK_EQ(worker, local_worker_)
81 << "Releasing a worker that was not returned by this WorkerCache";
82 } else {
83 WorkerCacheInterface::ReleaseWorker(target, worker);
84 }
85 }
86
GetEagerClientCache(std::unique_ptr<eager::EagerClientCache> * eager_client_cache)87 Status GetEagerClientCache(
88 std::unique_ptr<eager::EagerClientCache>* eager_client_cache) override {
89 eager_client_cache->reset(eager::NewGrpcEagerClientCache(channel_cache_));
90 return Status::OK();
91 }
92
SetLogging(bool v)93 void SetLogging(bool v) override { logger_.SetLogging(v); }
94
ClearLogs()95 void ClearLogs() override { logger_.ClearLogs(); }
96
RetrieveLogs(int64 step_id,StepStats * ss)97 bool RetrieveLogs(int64 step_id, StepStats* ss) override {
98 return logger_.RetrieveLogs(step_id, ss);
99 }
100
101 private:
AssignWorkerToThread(const string & target)102 size_t AssignWorkerToThread(const string& target) {
103 // Round-robin target assignment, but keeps the same target on the same
104 // polling thread always, as this is important for gRPC performance
105 mutex_lock lock(assignment_mu_);
106 auto it = target_assignments_.find(target);
107 if (it == target_assignments_.end()) {
108 it = target_assignments_
109 .insert(std::make_pair(target,
110 (next_round_robin_assignment_++) %
111 worker_env_->CompletionQueueSize()))
112 .first;
113 }
114 return it->second;
115 }
116
117 const string local_target_;
118 WorkerInterface* const local_worker_; // Not owned.
119 std::shared_ptr<GrpcChannelCache> channel_cache_;
120 WorkerCacheLogger logger_;
121 GrpcWorkerEnv* worker_env_; // Not owned, if worker_env_ptr_ is nullptr.
122 std::unique_ptr<GrpcWorkerEnv> worker_env_ptr_;
123
124 mutex assignment_mu_;
125 std::unordered_map<std::string, size_t> target_assignments_
126 GUARDED_BY(assignment_mu_);
127 size_t next_round_robin_assignment_ GUARDED_BY(assignment_mu_);
128 };
129
130 } // namespace
131
GrpcWorkerEnv(size_t num_completion_queues,size_t num_threads)132 GrpcWorkerEnv::GrpcWorkerEnv(size_t num_completion_queues, size_t num_threads)
133 : threadpool_(new thread::ThreadPool(
134 Env::Default(), ThreadOptions(), "GrpcWorkerEnvQueues", num_threads,
135 /*low_latency_hint=*/false, /*allocator=*/nullptr)),
136 threads_(num_completion_queues) {}
137
~GrpcWorkerEnv()138 GrpcWorkerEnv::~GrpcWorkerEnv() { threads_.clear(); }
139
GrpcWorkerCacheThread()140 GrpcWorkerEnv::GrpcWorkerCacheThread::GrpcWorkerCacheThread() {
141 thread_.reset(Env::Default()->StartThread(
142 ThreadOptions(), "GrpcWorkerEnvPool", [this]() {
143 void* tag;
144 bool ok;
145 while (completion_queue_.Next(&tag, &ok)) {
146 GrpcClientCQTag* callback_tag = static_cast<GrpcClientCQTag*>(tag);
147 callback_tag->OnCompleted(ok);
148 }
149 }));
150 }
151
~GrpcWorkerCacheThread()152 GrpcWorkerEnv::GrpcWorkerCacheThread::~GrpcWorkerCacheThread() {
153 completion_queue_.Shutdown();
154 thread_.reset();
155 }
156
NewGrpcWorkerCache(std::shared_ptr<GrpcChannelCache> cc)157 WorkerCacheInterface* NewGrpcWorkerCache(std::shared_ptr<GrpcChannelCache> cc) {
158 return new GrpcWorkerCache(cc, nullptr, "", nullptr);
159 }
160
NewGrpcWorkerCacheWithLocalWorker(std::shared_ptr<GrpcChannelCache> cc,WorkerInterface * local_worker,const string & local_target,GrpcWorkerEnv * worker_env)161 WorkerCacheInterface* NewGrpcWorkerCacheWithLocalWorker(
162 std::shared_ptr<GrpcChannelCache> cc, WorkerInterface* local_worker,
163 const string& local_target, GrpcWorkerEnv* worker_env) {
164 return new GrpcWorkerCache(cc, local_worker, local_target, worker_env);
165 }
166
167 } // namespace tensorflow
168