1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
17
18 #include <limits>
19 #include <map>
20 #include <unordered_map>
21
22 #include "grpcpp/create_channel.h"
23
24 #include "tensorflow/core/lib/core/errors.h"
25 #include "tensorflow/core/lib/core/status.h"
26 #include "tensorflow/core/lib/gtl/map_util.h"
27 #include "tensorflow/core/lib/strings/numbers.h"
28 #include "tensorflow/core/lib/strings/str_util.h"
29 #include "tensorflow/core/lib/strings/strcat.h"
30 #include "tensorflow/core/platform/logging.h"
31 #include "tensorflow/core/platform/macros.h"
32 #include "tensorflow/core/platform/mutex.h"
33 #include "tensorflow/core/platform/thread_annotations.h"
34 #include "tensorflow/core/platform/types.h"
35 #include "tensorflow/core/util/device_name_utils.h"
36
37 namespace tensorflow {
38
39 namespace {
40
// Builds the worker name for task `task` of job `job`, in the form
// "/job:<job>/replica:0/task:<task>". The replica component is always 0.
string MakeAddress(const string& job, int task) {
  string address = strings::StrCat("/job:", job, "/replica:0/task:", task);
  return address;
}
44
45 // Allows the host to be a raw IP (either v4 or v6).
ValidateHostPortPair(const string & host_port)46 Status ValidateHostPortPair(const string& host_port) {
47 uint32 port;
48 auto colon_index = host_port.find_last_of(':');
49 if (!strings::safe_strtou32(host_port.substr(colon_index + 1), &port) ||
50 host_port.substr(0, colon_index).find("/") != string::npos) {
51 return errors::InvalidArgument("Could not interpret \"", host_port,
52 "\" as a host-port pair.");
53 }
54 return Status::OK();
55 }
56
57 } // namespace
58
GetChannelArguments(const RPCOptions * rpc_options)59 ::grpc::ChannelArguments GetChannelArguments(const RPCOptions* rpc_options) {
60 // TODO(mrry): Implement secure channels.
61 ::grpc::ChannelArguments args;
62 args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, std::numeric_limits<int32>::max());
63 // NOTE(mrry): Some versions of gRPC use a 20-second minimum backoff
64 // on connection failure, which makes our tests time out.
65 args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 1000);
66 if (rpc_options != nullptr) {
67 if (rpc_options->compression_algorithm() == "deflate") {
68 args.SetCompressionAlgorithm(GRPC_COMPRESS_DEFLATE);
69 args.SetInt(GRPC_COMPRESSION_CHANNEL_DEFAULT_LEVEL,
70 rpc_options->compression_level());
71 VLOG(5) << "Setting GRPC compression : algo='"
72 << rpc_options->compression_algorithm()
73 << "' level=" << rpc_options->compression_level();
74 } else if (rpc_options->compression_algorithm() == "gzip") {
75 args.SetCompressionAlgorithm(GRPC_COMPRESS_GZIP);
76 args.SetInt(GRPC_COMPRESSION_CHANNEL_DEFAULT_LEVEL,
77 rpc_options->compression_level());
78 VLOG(5) << "Setting GRPC compression : algo='"
79 << rpc_options->compression_algorithm()
80 << "' level=" << rpc_options->compression_level();
81 } else if (!rpc_options->compression_algorithm().empty()) {
82 LOG(ERROR) << "Invalid compression algorithm: "
83 << rpc_options->compression_algorithm();
84 }
85 }
86 return args;
87 }
88
NewHostPortGrpcChannel(const string & target,const RPCOptions * rpc_options,SharedGrpcChannelPtr * channel_pointer)89 Status NewHostPortGrpcChannel(const string& target,
90 const RPCOptions* rpc_options,
91 SharedGrpcChannelPtr* channel_pointer) {
92 // Minimally ensure that the target is valid
93 TF_RETURN_IF_ERROR(ValidateHostPortPair(target));
94
95 ::grpc::ChannelArguments args = GetChannelArguments(rpc_options);
96 *channel_pointer = ::grpc::CreateCustomChannel(
97 "dns:///" + target, ::grpc::InsecureChannelCredentials(), args);
98 return Status::OK();
99 }
100
ConvertToChannelCreationFunction(const std::function<Status (string,const RPCOptions *,SharedGrpcChannelPtr *)> & new_channel_func_ptr)101 ChannelCreationFunction ConvertToChannelCreationFunction(
102 const std::function<Status(string, const RPCOptions*,
103 SharedGrpcChannelPtr*)>& new_channel_func_ptr) {
104 return [new_channel_func_ptr](const string& target) -> SharedGrpcChannelPtr {
105 SharedGrpcChannelPtr channel_ptr;
106 if (new_channel_func_ptr(target, /*rpc_options=*/nullptr, &channel_ptr)
107 .ok()) {
108 return channel_ptr;
109 } else {
110 return nullptr;
111 }
112 };
113 }
114
AddHostPortsJob(const string & job_id,const std::vector<string> & host_ports)115 Status GrpcChannelSpec::AddHostPortsJob(const string& job_id,
116 const std::vector<string>& host_ports) {
117 std::map<int, string> host_ports_map;
118 for (size_t i = 0; i < host_ports.size(); ++i) {
119 host_ports_map[i] = host_ports[i];
120 }
121 return AddHostPortsJob(job_id, host_ports_map);
122 }
123
AddHostPortsJob(const string & job_id,const std::map<int,string> & host_ports)124 Status GrpcChannelSpec::AddHostPortsJob(
125 const string& job_id, const std::map<int, string>& host_ports) {
126 if (!job_ids_.insert(job_id).second) {
127 return errors::InvalidArgument(
128 "Duplicate job ID in cluster specification: ", job_id);
129 }
130 for (const auto& id_host_port : host_ports) {
131 TF_RETURN_IF_ERROR(ValidateHostPortPair(id_host_port.second));
132 }
133 host_ports_jobs_.emplace_back(job_id, host_ports);
134 return Status::OK();
135 }
136
137 namespace {
138
139 // GrpcChannelCache that caches results to FindWorkerChannel() calls.
140 class CachingGrpcChannelCache : public GrpcChannelCache {
141 public:
CachingGrpcChannelCache()142 CachingGrpcChannelCache() {}
143
~CachingGrpcChannelCache()144 ~CachingGrpcChannelCache() override {}
145
FindWorkerChannel(const string & target)146 SharedGrpcChannelPtr FindWorkerChannel(const string& target) override {
147 SharedGrpcChannelPtr ch = nullptr;
148 {
149 mutex_lock l(mu_); // could use reader lock
150 ch = gtl::FindPtrOrNull(channels_, target);
151 if (ch) {
152 return ch;
153 }
154 }
155 ch = FindChannelOnce(target);
156 if (ch) {
157 mutex_lock l(mu_);
158 channels_.insert({target, ch});
159 }
160 return ch;
161 }
162
163 protected:
164 // Find the ClientChannel for "target". Only called when no channel was
165 // found in the channels_ cache for "target". A non nullptr result will be
166 // cached in channels_.
167 virtual SharedGrpcChannelPtr FindChannelOnce(const string& target) = 0;
168
169 private:
170 // TODO(zhifengc): Eviction when the map becomes too big.
171 mutex mu_;
172 std::unordered_map<string, SharedGrpcChannelPtr> channels_ GUARDED_BY(mu_);
173 };
174
175 // A ChannelCache that is the union of multiple ChannelCaches.
176 // Takes ownership of the caches passed to the constructor.
177 class MultiGrpcChannelCache : public CachingGrpcChannelCache {
178 public:
MultiGrpcChannelCache(const std::vector<GrpcChannelCache * > & caches)179 explicit MultiGrpcChannelCache(const std::vector<GrpcChannelCache*>& caches)
180 : CachingGrpcChannelCache(), caches_(caches) {}
181
~MultiGrpcChannelCache()182 ~MultiGrpcChannelCache() override {
183 for (GrpcChannelCache* cache : caches_) {
184 delete cache;
185 }
186 }
187
ListWorkers(std::vector<string> * workers)188 void ListWorkers(std::vector<string>* workers) override {
189 for (GrpcChannelCache* cache : caches_) {
190 cache->ListWorkers(workers);
191 }
192 }
193
ListWorkersInJob(const string & job_name,std::vector<string> * workers)194 void ListWorkersInJob(const string& job_name,
195 std::vector<string>* workers) override {
196 for (GrpcChannelCache* cache : caches_) {
197 cache->ListWorkersInJob(job_name, workers);
198 }
199 }
200
TranslateTask(const string & target)201 string TranslateTask(const string& target) override {
202 mutex_lock l(mu_); // could use reader lock
203 GrpcChannelCache* cache = gtl::FindPtrOrNull(target_caches_, target);
204 if (cache == nullptr) {
205 for (GrpcChannelCache* c : caches_) {
206 string r = c->TranslateTask(target);
207 if (!r.empty()) {
208 target_caches_.insert({target, c});
209 cache = c;
210 break;
211 }
212 }
213 }
214 CHECK(cache) << "Could not find GrpcChannelCache holding channel for "
215 << target;
216 return cache->TranslateTask(target);
217 }
218
219 protected:
FindChannelOnce(const string & target)220 SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
221 for (GrpcChannelCache* cache : caches_) {
222 SharedGrpcChannelPtr ch(cache->FindWorkerChannel(target));
223 if (ch) {
224 mutex_lock l(mu_);
225 target_caches_.insert({target, cache});
226 return ch;
227 }
228 }
229 return nullptr;
230 }
231
232 private:
233 // List of channels used by this MultiGrpcChannelCache.
234 const std::vector<GrpcChannelCache*> caches_;
235
236 mutex mu_;
237 // Cache of channels keyed by the target they are handling.
238 // The same GrpcChannelCache can appear multiple times in the cache.
239 std::unordered_map<string, GrpcChannelCache*> target_caches_ GUARDED_BY(mu_);
240 };
241
242 class SparseGrpcChannelCache : public CachingGrpcChannelCache {
243 public:
SparseGrpcChannelCache(const string & job_id,const std::map<int,string> & host_ports,ChannelCreationFunction channel_func)244 SparseGrpcChannelCache(const string& job_id,
245 const std::map<int, string>& host_ports,
246 ChannelCreationFunction channel_func)
247 : job_id_(job_id),
248 host_ports_(host_ports),
249 channel_func_(std::move(channel_func)) {
250 LOG(INFO) << "Initialize GrpcChannelCache for job " << ToString();
251 }
~SparseGrpcChannelCache()252 ~SparseGrpcChannelCache() override {}
253
ListWorkers(std::vector<string> * workers)254 void ListWorkers(std::vector<string>* workers) override {
255 workers->reserve(workers->size() + host_ports_.size());
256 for (const auto& id_host_port : host_ports_) {
257 workers->emplace_back(MakeAddress(job_id_, id_host_port.first));
258 }
259 }
260
ListWorkersInJob(const string & job_name,std::vector<string> * workers)261 void ListWorkersInJob(const string& job_name,
262 std::vector<string>* workers) override {
263 if (job_name == job_id_) {
264 ListWorkers(workers);
265 }
266 }
267
TranslateTask(const string & target)268 string TranslateTask(const string& target) override {
269 DeviceNameUtils::ParsedName parsed;
270 if (!DeviceNameUtils::ParseFullName(target, &parsed)) {
271 LOG(WARNING) << "Invalid target: " << target;
272 return "";
273 }
274
275 if (!parsed.has_job || parsed.job != job_id_) {
276 return "";
277 }
278 if (!parsed.has_replica || parsed.replica != 0) {
279 LOG(WARNING) << "Replica ID must be 0 in target: " << target;
280 return "";
281 }
282 int32 task = parsed.has_task ? parsed.task : -1;
283 auto iter = host_ports_.find(task);
284 if (iter == host_ports_.end()) {
285 LOG(WARNING) << "Task " << task << " was not defined in sparse job "
286 << job_id_ << ": " << target;
287 return "";
288 }
289 return iter->second;
290 }
291
292 protected:
FindChannelOnce(const string & target)293 SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
294 const string host_port = TranslateTask(target);
295 if (host_port.empty()) {
296 return nullptr;
297 }
298 return channel_func_(host_port);
299 }
300
301 private:
ToString()302 string ToString() {
303 std::vector<string> task_strings;
304 task_strings.reserve(host_ports_.size());
305 for (const auto& id_host_port : host_ports_) {
306 task_strings.emplace_back(
307 strings::StrCat(id_host_port.first, " -> ", id_host_port.second));
308 }
309 return strings::StrCat(job_id_, " -> {", str_util::Join(task_strings, ", "),
310 "}");
311 }
312
313 const string job_id_;
314 const std::map<int, string> host_ports_;
315 const ChannelCreationFunction channel_func_;
316 TF_DISALLOW_COPY_AND_ASSIGN(SparseGrpcChannelCache);
317 };
318
319 } // namespace
320
NewGrpcChannelCache(const GrpcChannelSpec & spec,ChannelCreationFunction channel_func)321 GrpcChannelCache* NewGrpcChannelCache(const GrpcChannelSpec& spec,
322 ChannelCreationFunction channel_func) {
323 const int num_jobs = spec.host_ports_jobs().size();
324 if (!num_jobs) {
325 LOG(ERROR) << "Empty channel spec.";
326 return nullptr;
327 }
328 std::vector<GrpcChannelCache*> caches;
329 caches.reserve(num_jobs);
330 for (auto& job : spec.host_ports_jobs()) {
331 caches.push_back(
332 new SparseGrpcChannelCache(job.job_id, job.host_ports, channel_func));
333 }
334 return caches.size() == 1 ? caches[0] : new MultiGrpcChannelCache(caches);
335 }
336
337 } // end namespace tensorflow
338