1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
17
18 #include <cstdlib>
19 #include <limits>
20 #include <map>
21 #include <unordered_map>
22
23 #include "grpcpp/create_channel.h"
24 #include "absl/strings/escaping.h"
25 #include "absl/strings/str_split.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/lib/core/status.h"
28 #include "tensorflow/core/lib/gtl/map_util.h"
29 #include "tensorflow/core/lib/strings/numbers.h"
30 #include "tensorflow/core/lib/strings/str_util.h"
31 #include "tensorflow/core/lib/strings/strcat.h"
32 #include "tensorflow/core/platform/logging.h"
33 #include "tensorflow/core/platform/macros.h"
34 #include "tensorflow/core/platform/mutex.h"
35 #include "tensorflow/core/platform/thread_annotations.h"
36 #include "tensorflow/core/platform/types.h"
37 #include "tensorflow/core/util/device_name_utils.h"
38
39 namespace tensorflow {
40
41 namespace {
42
// Builds the worker name "/job:<job>/replica:0/task:<task>". The replica
// component is always 0.
string MakeAddress(const string& job, int task) {
  string address = "/job:";
  strings::StrAppend(&address, job, "/replica:0/task:", task);
  return address;
}
46
47 // Allows the host to be a raw IP (either v4 or v6).
ValidateHostPortPair(const string & host_port)48 Status ValidateHostPortPair(const string& host_port) {
49 string bns_prefix = "/bns/";
50 if (host_port.substr(0, bns_prefix.length()) == bns_prefix) {
51 return Status::OK();
52 }
53 uint32 port;
54 auto colon_index = host_port.find_last_of(':');
55 if (!strings::safe_strtou32(host_port.substr(colon_index + 1), &port) ||
56 host_port.substr(0, colon_index).find("/") != string::npos) {
57 return errors::InvalidArgument("Could not interpret \"", host_port,
58 "\" as a host-port pair.");
59 }
60 return Status::OK();
61 }
62
CreateDefaultChannelArguments()63 ::grpc::ChannelArguments* CreateDefaultChannelArguments() {
64 ::grpc::ChannelArguments* args = new ::grpc::ChannelArguments();
65 const char* env = std::getenv("TF_GRPC_DEFAULT_OPTIONS");
66 if (env != nullptr) {
67 for (auto& grpc_option : absl::StrSplit(env, ',')) {
68 std::vector<string> name_value = absl::StrSplit(grpc_option, '=');
69 if (name_value.size() != 2) {
70 LOG(ERROR) << "Invalid GRPC options format: " << grpc_option;
71 continue;
72 }
73 VLOG(3) << "Setting GRPC default for '" << name_value[0] << "' to '"
74 << name_value[1] << "'";
75 if (name_value[1].size() >= 2 && name_value[1][0] == '"') {
76 string ue_value = name_value[1].substr(1, name_value[1].size() - 2);
77 string value;
78 string error;
79 if (!absl::CUnescape(ue_value, &value, &error)) {
80 LOG(ERROR) << "Failed to parse escaped string for " << grpc_option
81 << ": " << error;
82 } else {
83 args->SetString(name_value[0], value);
84 }
85 } else {
86 int64 value;
87 if (strings::safe_strto64(name_value[1], &value)) {
88 args->SetInt(name_value[0], value);
89 } else {
90 LOG(ERROR) << "Invalid integer value: " << grpc_option;
91 }
92 }
93 }
94 }
95 return args;
96 }
97
GetDefaultChannelArguments()98 const ::grpc::ChannelArguments* GetDefaultChannelArguments() {
99 static const ::grpc::ChannelArguments* args = CreateDefaultChannelArguments();
100 return args;
101 }
102
103 } // namespace
104
GetChannelArguments(const RPCOptions * rpc_options)105 ::grpc::ChannelArguments GetChannelArguments(const RPCOptions* rpc_options) {
106 // TODO(mrry): Implement secure channels.
107 ::grpc::ChannelArguments args = *GetDefaultChannelArguments();
108 args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, std::numeric_limits<int32>::max());
109 // NOTE(mrry): Some versions of gRPC use a 20-second minimum backoff
110 // on connection failure, which makes our tests time out.
111 args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 1000);
112 if (rpc_options != nullptr) {
113 if (rpc_options->compression_algorithm() == "deflate") {
114 args.SetCompressionAlgorithm(GRPC_COMPRESS_DEFLATE);
115 args.SetInt(GRPC_COMPRESSION_CHANNEL_DEFAULT_LEVEL,
116 rpc_options->compression_level());
117 VLOG(5) << "Setting GRPC compression : algo='"
118 << rpc_options->compression_algorithm()
119 << "' level=" << rpc_options->compression_level();
120 } else if (rpc_options->compression_algorithm() == "gzip") {
121 args.SetCompressionAlgorithm(GRPC_COMPRESS_GZIP);
122 args.SetInt(GRPC_COMPRESSION_CHANNEL_DEFAULT_LEVEL,
123 rpc_options->compression_level());
124 VLOG(5) << "Setting GRPC compression : algo='"
125 << rpc_options->compression_algorithm()
126 << "' level=" << rpc_options->compression_level();
127 } else if (!rpc_options->compression_algorithm().empty()) {
128 LOG(ERROR) << "Invalid compression algorithm: "
129 << rpc_options->compression_algorithm();
130 }
131 if (rpc_options->disable_session_connection_sharing()) {
132 VLOG(5) << "Disabling TCP connection sharing";
133 args.SetInt(GRPC_ARG_USE_LOCAL_SUBCHANNEL_POOL, true);
134 }
135 }
136 return args;
137 }
138
NewHostPortGrpcChannel(const string & target,const RPCOptions * rpc_options,SharedGrpcChannelPtr * channel_pointer)139 Status NewHostPortGrpcChannel(const string& target,
140 const RPCOptions* rpc_options,
141 SharedGrpcChannelPtr* channel_pointer) {
142 // Minimally ensure that the target is valid
143 TF_RETURN_IF_ERROR(ValidateHostPortPair(target));
144
145 ::grpc::ChannelArguments args = GetChannelArguments(rpc_options);
146 *channel_pointer = ::grpc::CreateCustomChannel(
147 "dns:///" + target, ::grpc::InsecureChannelCredentials(), args);
148 return Status::OK();
149 }
150
ConvertToChannelCreationFunction(const std::function<Status (string,const RPCOptions *,SharedGrpcChannelPtr *)> & new_channel_func_ptr)151 ChannelCreationFunction ConvertToChannelCreationFunction(
152 const std::function<Status(string, const RPCOptions*,
153 SharedGrpcChannelPtr*)>& new_channel_func_ptr) {
154 return [new_channel_func_ptr](const string& target) -> SharedGrpcChannelPtr {
155 SharedGrpcChannelPtr channel_ptr;
156 if (new_channel_func_ptr(target, /*rpc_options=*/nullptr, &channel_ptr)
157 .ok()) {
158 return channel_ptr;
159 } else {
160 return nullptr;
161 }
162 };
163 }
164
AddHostPortsJob(const string & job_id,const std::vector<string> & host_ports)165 Status GrpcChannelSpec::AddHostPortsJob(const string& job_id,
166 const std::vector<string>& host_ports) {
167 std::map<int, string> host_ports_map;
168 for (size_t i = 0; i < host_ports.size(); ++i) {
169 host_ports_map[i] = host_ports[i];
170 }
171 return AddHostPortsJob(job_id, host_ports_map);
172 }
173
AddHostPortsJob(const string & job_id,const std::map<int,string> & host_ports)174 Status GrpcChannelSpec::AddHostPortsJob(
175 const string& job_id, const std::map<int, string>& host_ports) {
176 if (!job_ids_.insert(job_id).second) {
177 return errors::InvalidArgument(
178 "Duplicate job ID in cluster specification: ", job_id);
179 }
180 for (const auto& id_host_port : host_ports) {
181 TF_RETURN_IF_ERROR(ValidateHostPortPair(id_host_port.second));
182 }
183 host_ports_jobs_.emplace_back(job_id, host_ports);
184 return Status::OK();
185 }
186
187 namespace {
188
189 // GrpcChannelCache that caches results to FindWorkerChannel() calls.
190 class CachingGrpcChannelCache : public GrpcChannelCache {
191 public:
CachingGrpcChannelCache()192 CachingGrpcChannelCache() {}
193
~CachingGrpcChannelCache()194 ~CachingGrpcChannelCache() override {}
195
FindWorkerChannel(const string & target)196 SharedGrpcChannelPtr FindWorkerChannel(const string& target) override {
197 SharedGrpcChannelPtr ch = nullptr;
198 {
199 mutex_lock l(mu_); // could use reader lock
200 ch = gtl::FindPtrOrNull(channels_, target);
201 if (ch) {
202 return ch;
203 }
204 }
205 ch = FindChannelOnce(target);
206 if (ch) {
207 mutex_lock l(mu_);
208 channels_.insert({target, ch});
209 }
210 return ch;
211 }
212
213 protected:
214 // Find the ClientChannel for "target". Only called when no channel was
215 // found in the channels_ cache for "target". A non nullptr result will be
216 // cached in channels_.
217 virtual SharedGrpcChannelPtr FindChannelOnce(const string& target) = 0;
218
219 private:
220 // TODO(zhifengc): Eviction when the map becomes too big.
221 mutex mu_;
222 std::unordered_map<string, SharedGrpcChannelPtr> channels_ GUARDED_BY(mu_);
223 };
224
225 // A ChannelCache that is the union of multiple ChannelCaches.
226 // Takes ownership of the caches passed to the constructor.
227 class MultiGrpcChannelCache : public CachingGrpcChannelCache {
228 public:
MultiGrpcChannelCache(const std::vector<GrpcChannelCache * > & caches)229 explicit MultiGrpcChannelCache(const std::vector<GrpcChannelCache*>& caches)
230 : CachingGrpcChannelCache(), caches_(caches) {}
231
~MultiGrpcChannelCache()232 ~MultiGrpcChannelCache() override {
233 for (GrpcChannelCache* cache : caches_) {
234 delete cache;
235 }
236 }
237
ListWorkers(std::vector<string> * workers)238 void ListWorkers(std::vector<string>* workers) override {
239 for (GrpcChannelCache* cache : caches_) {
240 cache->ListWorkers(workers);
241 }
242 }
243
ListWorkersInJob(const string & job_name,std::vector<string> * workers)244 void ListWorkersInJob(const string& job_name,
245 std::vector<string>* workers) override {
246 for (GrpcChannelCache* cache : caches_) {
247 cache->ListWorkersInJob(job_name, workers);
248 }
249 }
250
TranslateTask(const string & target)251 string TranslateTask(const string& target) override {
252 mutex_lock l(mu_); // could use reader lock
253 GrpcChannelCache* cache = gtl::FindPtrOrNull(target_caches_, target);
254 if (cache == nullptr) {
255 for (GrpcChannelCache* c : caches_) {
256 string r = c->TranslateTask(target);
257 if (!r.empty()) {
258 target_caches_.insert({target, c});
259 cache = c;
260 break;
261 }
262 }
263 }
264 CHECK(cache) << "Could not find GrpcChannelCache holding channel for "
265 << target;
266 return cache->TranslateTask(target);
267 }
268
269 protected:
FindChannelOnce(const string & target)270 SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
271 for (GrpcChannelCache* cache : caches_) {
272 SharedGrpcChannelPtr ch(cache->FindWorkerChannel(target));
273 if (ch) {
274 mutex_lock l(mu_);
275 target_caches_.insert({target, cache});
276 return ch;
277 }
278 }
279 return nullptr;
280 }
281
282 private:
283 // List of channels used by this MultiGrpcChannelCache.
284 const std::vector<GrpcChannelCache*> caches_;
285
286 mutex mu_;
287 // Cache of channels keyed by the target they are handling.
288 // The same GrpcChannelCache can appear multiple times in the cache.
289 std::unordered_map<string, GrpcChannelCache*> target_caches_ GUARDED_BY(mu_);
290 };
291
292 class SparseGrpcChannelCache : public CachingGrpcChannelCache {
293 public:
SparseGrpcChannelCache(const string & job_id,const std::map<int,string> & host_ports,ChannelCreationFunction channel_func)294 SparseGrpcChannelCache(const string& job_id,
295 const std::map<int, string>& host_ports,
296 ChannelCreationFunction channel_func)
297 : job_id_(job_id),
298 host_ports_(host_ports),
299 channel_func_(std::move(channel_func)) {
300 LOG(INFO) << "Initialize GrpcChannelCache for job " << ToString();
301 }
~SparseGrpcChannelCache()302 ~SparseGrpcChannelCache() override {}
303
ListWorkers(std::vector<string> * workers)304 void ListWorkers(std::vector<string>* workers) override {
305 workers->reserve(workers->size() + host_ports_.size());
306 for (const auto& id_host_port : host_ports_) {
307 workers->emplace_back(MakeAddress(job_id_, id_host_port.first));
308 }
309 }
310
ListWorkersInJob(const string & job_name,std::vector<string> * workers)311 void ListWorkersInJob(const string& job_name,
312 std::vector<string>* workers) override {
313 if (job_name == job_id_) {
314 ListWorkers(workers);
315 }
316 }
317
TranslateTask(const string & target)318 string TranslateTask(const string& target) override {
319 DeviceNameUtils::ParsedName parsed;
320 if (!DeviceNameUtils::ParseFullName(target, &parsed)) {
321 LOG(WARNING) << "Invalid target: " << target;
322 return "";
323 }
324
325 if (!parsed.has_job || parsed.job != job_id_) {
326 return "";
327 }
328 if (!parsed.has_replica || parsed.replica != 0) {
329 LOG(WARNING) << "Replica ID must be 0 in target: " << target;
330 return "";
331 }
332 int32 task = parsed.has_task ? parsed.task : -1;
333 auto iter = host_ports_.find(task);
334 if (iter == host_ports_.end()) {
335 LOG(WARNING) << "Task " << task << " was not defined in sparse job "
336 << job_id_ << ": " << target;
337 return "";
338 }
339 return iter->second;
340 }
341
342 protected:
FindChannelOnce(const string & target)343 SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
344 const string host_port = TranslateTask(target);
345 if (host_port.empty()) {
346 return nullptr;
347 }
348 return channel_func_(host_port);
349 }
350
351 private:
ToString()352 string ToString() {
353 std::vector<string> task_strings;
354 task_strings.reserve(host_ports_.size());
355 for (const auto& id_host_port : host_ports_) {
356 task_strings.emplace_back(
357 strings::StrCat(id_host_port.first, " -> ", id_host_port.second));
358 }
359 return strings::StrCat(job_id_, " -> {", absl::StrJoin(task_strings, ", "),
360 "}");
361 }
362
363 const string job_id_;
364 const std::map<int, string> host_ports_;
365 const ChannelCreationFunction channel_func_;
366 TF_DISALLOW_COPY_AND_ASSIGN(SparseGrpcChannelCache);
367 };
368
369 } // namespace
370
NewGrpcChannelCache(const GrpcChannelSpec & spec,ChannelCreationFunction channel_func)371 GrpcChannelCache* NewGrpcChannelCache(const GrpcChannelSpec& spec,
372 ChannelCreationFunction channel_func) {
373 const int num_jobs = spec.host_ports_jobs().size();
374 if (!num_jobs) {
375 LOG(ERROR) << "Empty channel spec.";
376 return nullptr;
377 }
378 std::vector<GrpcChannelCache*> caches;
379 caches.reserve(num_jobs);
380 for (auto& job : spec.host_ports_jobs()) {
381 caches.push_back(
382 new SparseGrpcChannelCache(job.job_id, job.host_ports, channel_func));
383 }
384 return caches.size() == 1 ? caches[0] : new MultiGrpcChannelCache(caches);
385 }
386
387 } // end namespace tensorflow
388