1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
17
18 #include <cstdlib>
19 #include <limits>
20 #include <map>
21 #include <unordered_map>
22
23 #include "grpcpp/create_channel.h"
24 #include "absl/strings/escaping.h"
25 #include "absl/strings/str_split.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/lib/core/status.h"
28 #include "tensorflow/core/lib/gtl/map_util.h"
29 #include "tensorflow/core/lib/strings/numbers.h"
30 #include "tensorflow/core/lib/strings/str_util.h"
31 #include "tensorflow/core/lib/strings/strcat.h"
32 #include "tensorflow/core/platform/logging.h"
33 #include "tensorflow/core/platform/macros.h"
34 #include "tensorflow/core/platform/mutex.h"
35 #include "tensorflow/core/platform/thread_annotations.h"
36 #include "tensorflow/core/platform/types.h"
37 #include "tensorflow/core/util/device_name_utils.h"
38
39 namespace tensorflow {
40
41 namespace {
42
// Formats the worker name for task `task` of job `job` as
// "/job:<job>/replica:0/task:<task>". The replica component is always 0.
string MakeAddress(const string& job, int task) {
  string address =
      strings::StrCat("/job:", job, "/replica:0/task:", task);
  return address;
}
46
47 // Allows the host to be a raw IP (either v4 or v6).
ValidateHostPortPair(const string & host_port)48 Status ValidateHostPortPair(const string& host_port) {
49 string bns_prefix = "/bns/";
50 if (host_port.substr(0, bns_prefix.length()) == bns_prefix) {
51 return Status::OK();
52 }
53 uint32 port;
54 auto colon_index = host_port.find_last_of(':');
55 if (!strings::safe_strtou32(host_port.substr(colon_index + 1), &port) ||
56 host_port.substr(0, colon_index).find('/') != string::npos) {
57 return errors::InvalidArgument("Could not interpret \"", host_port,
58 "\" as a host-port pair.");
59 }
60 return Status::OK();
61 }
62
CreateDefaultChannelArguments()63 ::grpc::ChannelArguments* CreateDefaultChannelArguments() {
64 ::grpc::ChannelArguments* args = new ::grpc::ChannelArguments();
65 const char* env = std::getenv("TF_GRPC_DEFAULT_OPTIONS");
66 if (env != nullptr) {
67 for (auto& grpc_option : absl::StrSplit(env, ',')) {
68 std::vector<string> name_value = absl::StrSplit(grpc_option, '=');
69 if (name_value.size() != 2) {
70 LOG(ERROR) << "Invalid GRPC options format: " << grpc_option;
71 continue;
72 }
73 VLOG(3) << "Setting GRPC default for '" << name_value[0] << "' to '"
74 << name_value[1] << "'";
75 if (name_value[1].size() >= 2 && name_value[1][0] == '"') {
76 string ue_value = name_value[1].substr(1, name_value[1].size() - 2);
77 string value;
78 string error;
79 if (!absl::CUnescape(ue_value, &value, &error)) {
80 LOG(ERROR) << "Failed to parse escaped string for " << grpc_option
81 << ": " << error;
82 } else {
83 args->SetString(name_value[0], value);
84 }
85 } else {
86 int64 value;
87 if (strings::safe_strto64(name_value[1], &value)) {
88 args->SetInt(name_value[0], value);
89 } else {
90 LOG(ERROR) << "Invalid integer value: " << grpc_option;
91 }
92 }
93 }
94 }
95 return args;
96 }
97
GetDefaultChannelArguments()98 const ::grpc::ChannelArguments* GetDefaultChannelArguments() {
99 static const ::grpc::ChannelArguments* args = CreateDefaultChannelArguments();
100 return args;
101 }
102
103 } // namespace
104
GetChannelArguments(const RPCOptions * rpc_options)105 ::grpc::ChannelArguments GetChannelArguments(const RPCOptions* rpc_options) {
106 // TODO(mrry): Implement secure channels.
107 ::grpc::ChannelArguments args = *GetDefaultChannelArguments();
108 args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, std::numeric_limits<int32>::max());
109 // NOTE(mrry): Some versions of gRPC use a 20-second minimum backoff
110 // on connection failure, which makes our tests time out.
111 args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 1000);
112 if (rpc_options != nullptr) {
113 if (rpc_options->compression_algorithm() == "deflate") {
114 args.SetCompressionAlgorithm(GRPC_COMPRESS_DEFLATE);
115 args.SetInt(GRPC_COMPRESSION_CHANNEL_DEFAULT_LEVEL,
116 rpc_options->compression_level());
117 VLOG(5) << "Setting GRPC compression : algo='"
118 << rpc_options->compression_algorithm()
119 << "' level=" << rpc_options->compression_level();
120 } else if (rpc_options->compression_algorithm() == "gzip") {
121 args.SetCompressionAlgorithm(GRPC_COMPRESS_GZIP);
122 args.SetInt(GRPC_COMPRESSION_CHANNEL_DEFAULT_LEVEL,
123 rpc_options->compression_level());
124 VLOG(5) << "Setting GRPC compression : algo='"
125 << rpc_options->compression_algorithm()
126 << "' level=" << rpc_options->compression_level();
127 } else if (!rpc_options->compression_algorithm().empty()) {
128 LOG(ERROR) << "Invalid compression algorithm: "
129 << rpc_options->compression_algorithm();
130 }
131 if (rpc_options->disable_session_connection_sharing()) {
132 VLOG(5) << "Disabling TCP connection sharing";
133 args.SetInt(GRPC_ARG_USE_LOCAL_SUBCHANNEL_POOL, true);
134 }
135 }
136 return args;
137 }
138
NewHostPortGrpcChannel(const string & target,const RPCOptions * rpc_options,SharedGrpcChannelPtr * channel_pointer)139 Status NewHostPortGrpcChannel(const string& target,
140 const RPCOptions* rpc_options,
141 SharedGrpcChannelPtr* channel_pointer) {
142 // Minimally ensure that the target is valid
143 TF_RETURN_IF_ERROR(ValidateHostPortPair(target));
144
145 ::grpc::ChannelArguments args = GetChannelArguments(rpc_options);
146 *channel_pointer = ::grpc::CreateCustomChannel(
147 "dns:///" + target, ::grpc::InsecureChannelCredentials(), args);
148 return Status::OK();
149 }
150
ConvertToChannelCreationFunction(const std::function<Status (string,const RPCOptions *,SharedGrpcChannelPtr *)> & new_channel_func_ptr)151 ChannelCreationFunction ConvertToChannelCreationFunction(
152 const std::function<Status(string, const RPCOptions*,
153 SharedGrpcChannelPtr*)>& new_channel_func_ptr) {
154 return [new_channel_func_ptr](const string& target) -> SharedGrpcChannelPtr {
155 SharedGrpcChannelPtr channel_ptr;
156 if (new_channel_func_ptr(target, /*rpc_options=*/nullptr, &channel_ptr)
157 .ok()) {
158 return channel_ptr;
159 } else {
160 return nullptr;
161 }
162 };
163 }
164
AddHostPortsJob(const string & job_id,const std::vector<string> & host_ports)165 Status GrpcChannelSpec::AddHostPortsJob(const string& job_id,
166 const std::vector<string>& host_ports) {
167 std::map<int, string> host_ports_map;
168 for (size_t i = 0; i < host_ports.size(); ++i) {
169 host_ports_map[i] = host_ports[i];
170 }
171 return AddHostPortsJob(job_id, host_ports_map);
172 }
173
AddHostPortsJob(const string & job_id,const std::map<int,string> & host_ports)174 Status GrpcChannelSpec::AddHostPortsJob(
175 const string& job_id, const std::map<int, string>& host_ports) {
176 if (!job_ids_.insert(job_id).second) {
177 return errors::InvalidArgument(
178 "Duplicate job ID in cluster specification: ", job_id);
179 }
180 for (const auto& id_host_port : host_ports) {
181 TF_RETURN_IF_ERROR(ValidateHostPortPair(id_host_port.second));
182 }
183 host_ports_jobs_.emplace_back(job_id, host_ports);
184 return Status::OK();
185 }
186
187 namespace {
188
189 // GrpcChannelCache that caches results to FindWorkerChannel() calls.
190 class CachingGrpcChannelCache : public GrpcChannelCache {
191 public:
CachingGrpcChannelCache()192 CachingGrpcChannelCache() {}
193
~CachingGrpcChannelCache()194 ~CachingGrpcChannelCache() override {}
195
FindWorkerChannel(const string & target)196 SharedGrpcChannelPtr FindWorkerChannel(const string& target) override {
197 SharedGrpcChannelPtr ch = nullptr;
198 {
199 mutex_lock l(mu_); // could use reader lock
200 ch = gtl::FindPtrOrNull(channels_, target);
201 if (ch) {
202 return ch;
203 }
204 }
205 ch = FindChannelOnce(target);
206 if (ch) {
207 mutex_lock l(mu_);
208 channels_.insert({target, ch});
209 }
210 return ch;
211 }
212
213 protected:
214 // Find the ClientChannel for "target". Only called when no channel was
215 // found in the channels_ cache for "target". A non nullptr result will be
216 // cached in channels_.
217 virtual SharedGrpcChannelPtr FindChannelOnce(const string& target) = 0;
218
219 private:
220 // TODO(zhifengc): Eviction when the map becomes too big.
221 mutex mu_;
222 std::unordered_map<string, SharedGrpcChannelPtr> channels_ TF_GUARDED_BY(mu_);
223 };
224
225 // A ChannelCache that is the union of multiple ChannelCaches.
226 // Takes ownership of the caches passed to the constructor.
227 class MultiGrpcChannelCache : public CachingGrpcChannelCache {
228 public:
MultiGrpcChannelCache(const std::vector<GrpcChannelCache * > & caches)229 explicit MultiGrpcChannelCache(const std::vector<GrpcChannelCache*>& caches)
230 : CachingGrpcChannelCache(), caches_(caches) {}
231
~MultiGrpcChannelCache()232 ~MultiGrpcChannelCache() override {
233 for (GrpcChannelCache* cache : caches_) {
234 delete cache;
235 }
236 }
237
ListWorkers(std::vector<string> * workers)238 void ListWorkers(std::vector<string>* workers) override {
239 for (GrpcChannelCache* cache : caches_) {
240 cache->ListWorkers(workers);
241 }
242 }
243
ListWorkersInJob(const string & job_name,std::vector<string> * workers)244 void ListWorkersInJob(const string& job_name,
245 std::vector<string>* workers) override {
246 for (GrpcChannelCache* cache : caches_) {
247 cache->ListWorkersInJob(job_name, workers);
248 }
249 }
250
TranslateTask(const string & target)251 string TranslateTask(const string& target) override {
252 mutex_lock l(mu_); // could use reader lock
253 GrpcChannelCache* cache = gtl::FindPtrOrNull(target_caches_, target);
254 if (cache == nullptr) {
255 for (GrpcChannelCache* c : caches_) {
256 string r = c->TranslateTask(target);
257 if (!r.empty()) {
258 target_caches_.insert({target, c});
259 cache = c;
260 break;
261 }
262 }
263 }
264 CHECK(cache) << "Could not find GrpcChannelCache holding channel for "
265 << target;
266 return cache->TranslateTask(target);
267 }
268
269 protected:
FindChannelOnce(const string & target)270 SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
271 for (GrpcChannelCache* cache : caches_) {
272 SharedGrpcChannelPtr ch(cache->FindWorkerChannel(target));
273 if (ch) {
274 mutex_lock l(mu_);
275 target_caches_.insert({target, cache});
276 return ch;
277 }
278 }
279 return nullptr;
280 }
281
282 private:
283 // List of channels used by this MultiGrpcChannelCache.
284 const std::vector<GrpcChannelCache*> caches_;
285
286 mutex mu_;
287 // Cache of channels keyed by the target they are handling.
288 // The same GrpcChannelCache can appear multiple times in the cache.
289 std::unordered_map<string, GrpcChannelCache*> target_caches_
290 TF_GUARDED_BY(mu_);
291 };
292
293 class SparseGrpcChannelCache : public CachingGrpcChannelCache {
294 public:
SparseGrpcChannelCache(const string & job_id,const std::map<int,string> & host_ports,ChannelCreationFunction channel_func)295 SparseGrpcChannelCache(const string& job_id,
296 const std::map<int, string>& host_ports,
297 ChannelCreationFunction channel_func)
298 : job_id_(job_id),
299 host_ports_(host_ports),
300 channel_func_(std::move(channel_func)) {
301 LOG(INFO) << "Initialize GrpcChannelCache for job " << ToString();
302 }
~SparseGrpcChannelCache()303 ~SparseGrpcChannelCache() override {}
304
ListWorkers(std::vector<string> * workers)305 void ListWorkers(std::vector<string>* workers) override {
306 workers->reserve(workers->size() + host_ports_.size());
307 for (const auto& id_host_port : host_ports_) {
308 workers->emplace_back(MakeAddress(job_id_, id_host_port.first));
309 }
310 }
311
ListWorkersInJob(const string & job_name,std::vector<string> * workers)312 void ListWorkersInJob(const string& job_name,
313 std::vector<string>* workers) override {
314 if (job_name == job_id_) {
315 ListWorkers(workers);
316 }
317 }
318
TranslateTask(const string & target)319 string TranslateTask(const string& target) override {
320 DeviceNameUtils::ParsedName parsed;
321 if (!DeviceNameUtils::ParseFullName(target, &parsed)) {
322 LOG(WARNING) << "Invalid target: " << target;
323 return "";
324 }
325
326 if (!parsed.has_job || parsed.job != job_id_) {
327 return "";
328 }
329 if (!parsed.has_replica || parsed.replica != 0) {
330 LOG(WARNING) << "Replica ID must be 0 in target: " << target;
331 return "";
332 }
333 int32 task = parsed.has_task ? parsed.task : -1;
334 auto iter = host_ports_.find(task);
335 if (iter == host_ports_.end()) {
336 LOG(WARNING) << "Task " << task << " was not defined in sparse job "
337 << job_id_ << ": " << target;
338 return "";
339 }
340 return iter->second;
341 }
342
343 protected:
FindChannelOnce(const string & target)344 SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
345 const string host_port = TranslateTask(target);
346 if (host_port.empty()) {
347 return nullptr;
348 }
349 return channel_func_(host_port);
350 }
351
352 private:
ToString()353 string ToString() {
354 std::vector<string> task_strings;
355 task_strings.reserve(host_ports_.size());
356 for (const auto& id_host_port : host_ports_) {
357 task_strings.emplace_back(
358 strings::StrCat(id_host_port.first, " -> ", id_host_port.second));
359 }
360 return strings::StrCat(job_id_, " -> {", absl::StrJoin(task_strings, ", "),
361 "}");
362 }
363
364 const string job_id_;
365 const std::map<int, string> host_ports_;
366 const ChannelCreationFunction channel_func_;
367 TF_DISALLOW_COPY_AND_ASSIGN(SparseGrpcChannelCache);
368 };
369
370 } // namespace
371
NewGrpcChannelCache(const GrpcChannelSpec & spec,ChannelCreationFunction channel_func)372 GrpcChannelCache* NewGrpcChannelCache(const GrpcChannelSpec& spec,
373 ChannelCreationFunction channel_func) {
374 const int num_jobs = spec.host_ports_jobs().size();
375 if (!num_jobs) {
376 LOG(ERROR) << "Empty channel spec.";
377 return nullptr;
378 }
379 std::vector<GrpcChannelCache*> caches;
380 caches.reserve(num_jobs);
381 for (auto& job : spec.host_ports_jobs()) {
382 caches.push_back(
383 new SparseGrpcChannelCache(job.job_id, job.host_ports, channel_func));
384 }
385 return caches.size() == 1 ? caches[0] : new MultiGrpcChannelCache(caches);
386 }
387
388 } // end namespace tensorflow
389