• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
17 
18 #include <cstdlib>
19 #include <limits>
20 #include <map>
21 #include <unordered_map>
22 
23 #include "grpcpp/create_channel.h"
24 #include "absl/strings/escaping.h"
25 #include "absl/strings/str_split.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/lib/core/status.h"
28 #include "tensorflow/core/lib/gtl/map_util.h"
29 #include "tensorflow/core/lib/strings/numbers.h"
30 #include "tensorflow/core/lib/strings/str_util.h"
31 #include "tensorflow/core/lib/strings/strcat.h"
32 #include "tensorflow/core/platform/logging.h"
33 #include "tensorflow/core/platform/macros.h"
34 #include "tensorflow/core/platform/mutex.h"
35 #include "tensorflow/core/platform/thread_annotations.h"
36 #include "tensorflow/core/platform/types.h"
37 #include "tensorflow/core/util/device_name_utils.h"
38 
39 namespace tensorflow {
40 
41 namespace {
42 
// Returns the fully-qualified task name "/job:<job>/replica:0/task:<task>".
// The replica component is always 0.
string MakeAddress(const string& job, int task) {
  string address = "/job:";
  strings::StrAppend(&address, job, "/replica:0/task:", task);
  return address;
}
46 
47 // Allows the host to be a raw IP (either v4 or v6).
ValidateHostPortPair(const string & host_port)48 Status ValidateHostPortPair(const string& host_port) {
49   string bns_prefix = "/bns/";
50   if (host_port.substr(0, bns_prefix.length()) == bns_prefix) {
51     return Status::OK();
52   }
53   uint32 port;
54   auto colon_index = host_port.find_last_of(':');
55   if (!strings::safe_strtou32(host_port.substr(colon_index + 1), &port) ||
56       host_port.substr(0, colon_index).find("/") != string::npos) {
57     return errors::InvalidArgument("Could not interpret \"", host_port,
58                                    "\" as a host-port pair.");
59   }
60   return Status::OK();
61 }
62 
CreateDefaultChannelArguments()63 ::grpc::ChannelArguments* CreateDefaultChannelArguments() {
64   ::grpc::ChannelArguments* args = new ::grpc::ChannelArguments();
65   const char* env = std::getenv("TF_GRPC_DEFAULT_OPTIONS");
66   if (env != nullptr) {
67     for (auto& grpc_option : absl::StrSplit(env, ',')) {
68       std::vector<string> name_value = absl::StrSplit(grpc_option, '=');
69       if (name_value.size() != 2) {
70         LOG(ERROR) << "Invalid GRPC options format: " << grpc_option;
71         continue;
72       }
73       VLOG(3) << "Setting GRPC default for '" << name_value[0] << "' to '"
74               << name_value[1] << "'";
75       if (name_value[1].size() >= 2 && name_value[1][0] == '"') {
76         string ue_value = name_value[1].substr(1, name_value[1].size() - 2);
77         string value;
78         string error;
79         if (!absl::CUnescape(ue_value, &value, &error)) {
80           LOG(ERROR) << "Failed to parse escaped string for " << grpc_option
81                      << ": " << error;
82         } else {
83           args->SetString(name_value[0], value);
84         }
85       } else {
86         int64 value;
87         if (strings::safe_strto64(name_value[1], &value)) {
88           args->SetInt(name_value[0], value);
89         } else {
90           LOG(ERROR) << "Invalid integer value: " << grpc_option;
91         }
92       }
93     }
94   }
95   return args;
96 }
97 
GetDefaultChannelArguments()98 const ::grpc::ChannelArguments* GetDefaultChannelArguments() {
99   static const ::grpc::ChannelArguments* args = CreateDefaultChannelArguments();
100   return args;
101 }
102 
103 }  // namespace
104 
GetChannelArguments(const RPCOptions * rpc_options)105 ::grpc::ChannelArguments GetChannelArguments(const RPCOptions* rpc_options) {
106   // TODO(mrry): Implement secure channels.
107   ::grpc::ChannelArguments args = *GetDefaultChannelArguments();
108   args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, std::numeric_limits<int32>::max());
109   // NOTE(mrry): Some versions of gRPC use a 20-second minimum backoff
110   // on connection failure, which makes our tests time out.
111   args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 1000);
112   if (rpc_options != nullptr) {
113     if (rpc_options->compression_algorithm() == "deflate") {
114       args.SetCompressionAlgorithm(GRPC_COMPRESS_DEFLATE);
115       args.SetInt(GRPC_COMPRESSION_CHANNEL_DEFAULT_LEVEL,
116                   rpc_options->compression_level());
117       VLOG(5) << "Setting GRPC compression : algo='"
118               << rpc_options->compression_algorithm()
119               << "' level=" << rpc_options->compression_level();
120     } else if (rpc_options->compression_algorithm() == "gzip") {
121       args.SetCompressionAlgorithm(GRPC_COMPRESS_GZIP);
122       args.SetInt(GRPC_COMPRESSION_CHANNEL_DEFAULT_LEVEL,
123                   rpc_options->compression_level());
124       VLOG(5) << "Setting GRPC compression : algo='"
125               << rpc_options->compression_algorithm()
126               << "' level=" << rpc_options->compression_level();
127     } else if (!rpc_options->compression_algorithm().empty()) {
128       LOG(ERROR) << "Invalid compression algorithm: "
129                  << rpc_options->compression_algorithm();
130     }
131     if (rpc_options->disable_session_connection_sharing()) {
132       VLOG(5) << "Disabling TCP connection sharing";
133       args.SetInt(GRPC_ARG_USE_LOCAL_SUBCHANNEL_POOL, true);
134     }
135   }
136   return args;
137 }
138 
NewHostPortGrpcChannel(const string & target,const RPCOptions * rpc_options,SharedGrpcChannelPtr * channel_pointer)139 Status NewHostPortGrpcChannel(const string& target,
140                               const RPCOptions* rpc_options,
141                               SharedGrpcChannelPtr* channel_pointer) {
142   // Minimally ensure that the target is valid
143   TF_RETURN_IF_ERROR(ValidateHostPortPair(target));
144 
145   ::grpc::ChannelArguments args = GetChannelArguments(rpc_options);
146   *channel_pointer = ::grpc::CreateCustomChannel(
147       "dns:///" + target, ::grpc::InsecureChannelCredentials(), args);
148   return Status::OK();
149 }
150 
ConvertToChannelCreationFunction(const std::function<Status (string,const RPCOptions *,SharedGrpcChannelPtr *)> & new_channel_func_ptr)151 ChannelCreationFunction ConvertToChannelCreationFunction(
152     const std::function<Status(string, const RPCOptions*,
153                                SharedGrpcChannelPtr*)>& new_channel_func_ptr) {
154   return [new_channel_func_ptr](const string& target) -> SharedGrpcChannelPtr {
155     SharedGrpcChannelPtr channel_ptr;
156     if (new_channel_func_ptr(target, /*rpc_options=*/nullptr, &channel_ptr)
157             .ok()) {
158       return channel_ptr;
159     } else {
160       return nullptr;
161     }
162   };
163 }
164 
AddHostPortsJob(const string & job_id,const std::vector<string> & host_ports)165 Status GrpcChannelSpec::AddHostPortsJob(const string& job_id,
166                                         const std::vector<string>& host_ports) {
167   std::map<int, string> host_ports_map;
168   for (size_t i = 0; i < host_ports.size(); ++i) {
169     host_ports_map[i] = host_ports[i];
170   }
171   return AddHostPortsJob(job_id, host_ports_map);
172 }
173 
AddHostPortsJob(const string & job_id,const std::map<int,string> & host_ports)174 Status GrpcChannelSpec::AddHostPortsJob(
175     const string& job_id, const std::map<int, string>& host_ports) {
176   if (!job_ids_.insert(job_id).second) {
177     return errors::InvalidArgument(
178         "Duplicate job ID in cluster specification: ", job_id);
179   }
180   for (const auto& id_host_port : host_ports) {
181     TF_RETURN_IF_ERROR(ValidateHostPortPair(id_host_port.second));
182   }
183   host_ports_jobs_.emplace_back(job_id, host_ports);
184   return Status::OK();
185 }
186 
187 namespace {
188 
189 // GrpcChannelCache that caches results to FindWorkerChannel() calls.
190 class CachingGrpcChannelCache : public GrpcChannelCache {
191  public:
CachingGrpcChannelCache()192   CachingGrpcChannelCache() {}
193 
~CachingGrpcChannelCache()194   ~CachingGrpcChannelCache() override {}
195 
FindWorkerChannel(const string & target)196   SharedGrpcChannelPtr FindWorkerChannel(const string& target) override {
197     SharedGrpcChannelPtr ch = nullptr;
198     {
199       mutex_lock l(mu_);  // could use reader lock
200       ch = gtl::FindPtrOrNull(channels_, target);
201       if (ch) {
202         return ch;
203       }
204     }
205     ch = FindChannelOnce(target);
206     if (ch) {
207       mutex_lock l(mu_);
208       channels_.insert({target, ch});
209     }
210     return ch;
211   }
212 
213  protected:
214   // Find the ClientChannel for "target".  Only called when no channel was
215   // found in the channels_ cache for "target".  A non nullptr result will be
216   // cached in channels_.
217   virtual SharedGrpcChannelPtr FindChannelOnce(const string& target) = 0;
218 
219  private:
220   // TODO(zhifengc): Eviction when the map becomes too big.
221   mutex mu_;
222   std::unordered_map<string, SharedGrpcChannelPtr> channels_ GUARDED_BY(mu_);
223 };
224 
225 // A ChannelCache that is the union of multiple ChannelCaches.
226 // Takes ownership of the caches passed to the constructor.
227 class MultiGrpcChannelCache : public CachingGrpcChannelCache {
228  public:
MultiGrpcChannelCache(const std::vector<GrpcChannelCache * > & caches)229   explicit MultiGrpcChannelCache(const std::vector<GrpcChannelCache*>& caches)
230       : CachingGrpcChannelCache(), caches_(caches) {}
231 
~MultiGrpcChannelCache()232   ~MultiGrpcChannelCache() override {
233     for (GrpcChannelCache* cache : caches_) {
234       delete cache;
235     }
236   }
237 
ListWorkers(std::vector<string> * workers)238   void ListWorkers(std::vector<string>* workers) override {
239     for (GrpcChannelCache* cache : caches_) {
240       cache->ListWorkers(workers);
241     }
242   }
243 
ListWorkersInJob(const string & job_name,std::vector<string> * workers)244   void ListWorkersInJob(const string& job_name,
245                         std::vector<string>* workers) override {
246     for (GrpcChannelCache* cache : caches_) {
247       cache->ListWorkersInJob(job_name, workers);
248     }
249   }
250 
TranslateTask(const string & target)251   string TranslateTask(const string& target) override {
252     mutex_lock l(mu_);  // could use reader lock
253     GrpcChannelCache* cache = gtl::FindPtrOrNull(target_caches_, target);
254     if (cache == nullptr) {
255       for (GrpcChannelCache* c : caches_) {
256         string r = c->TranslateTask(target);
257         if (!r.empty()) {
258           target_caches_.insert({target, c});
259           cache = c;
260           break;
261         }
262       }
263     }
264     CHECK(cache) << "Could not find GrpcChannelCache holding channel for "
265                  << target;
266     return cache->TranslateTask(target);
267   }
268 
269  protected:
FindChannelOnce(const string & target)270   SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
271     for (GrpcChannelCache* cache : caches_) {
272       SharedGrpcChannelPtr ch(cache->FindWorkerChannel(target));
273       if (ch) {
274         mutex_lock l(mu_);
275         target_caches_.insert({target, cache});
276         return ch;
277       }
278     }
279     return nullptr;
280   }
281 
282  private:
283   // List of channels used by this MultiGrpcChannelCache.
284   const std::vector<GrpcChannelCache*> caches_;
285 
286   mutex mu_;
287   // Cache of channels keyed by the target they are handling.
288   // The same GrpcChannelCache can appear multiple times in the cache.
289   std::unordered_map<string, GrpcChannelCache*> target_caches_ GUARDED_BY(mu_);
290 };
291 
292 class SparseGrpcChannelCache : public CachingGrpcChannelCache {
293  public:
SparseGrpcChannelCache(const string & job_id,const std::map<int,string> & host_ports,ChannelCreationFunction channel_func)294   SparseGrpcChannelCache(const string& job_id,
295                          const std::map<int, string>& host_ports,
296                          ChannelCreationFunction channel_func)
297       : job_id_(job_id),
298         host_ports_(host_ports),
299         channel_func_(std::move(channel_func)) {
300     LOG(INFO) << "Initialize GrpcChannelCache for job " << ToString();
301   }
~SparseGrpcChannelCache()302   ~SparseGrpcChannelCache() override {}
303 
ListWorkers(std::vector<string> * workers)304   void ListWorkers(std::vector<string>* workers) override {
305     workers->reserve(workers->size() + host_ports_.size());
306     for (const auto& id_host_port : host_ports_) {
307       workers->emplace_back(MakeAddress(job_id_, id_host_port.first));
308     }
309   }
310 
ListWorkersInJob(const string & job_name,std::vector<string> * workers)311   void ListWorkersInJob(const string& job_name,
312                         std::vector<string>* workers) override {
313     if (job_name == job_id_) {
314       ListWorkers(workers);
315     }
316   }
317 
TranslateTask(const string & target)318   string TranslateTask(const string& target) override {
319     DeviceNameUtils::ParsedName parsed;
320     if (!DeviceNameUtils::ParseFullName(target, &parsed)) {
321       LOG(WARNING) << "Invalid target: " << target;
322       return "";
323     }
324 
325     if (!parsed.has_job || parsed.job != job_id_) {
326       return "";
327     }
328     if (!parsed.has_replica || parsed.replica != 0) {
329       LOG(WARNING) << "Replica ID must be 0 in target: " << target;
330       return "";
331     }
332     int32 task = parsed.has_task ? parsed.task : -1;
333     auto iter = host_ports_.find(task);
334     if (iter == host_ports_.end()) {
335       LOG(WARNING) << "Task " << task << " was not defined in sparse job "
336                    << job_id_ << ": " << target;
337       return "";
338     }
339     return iter->second;
340   }
341 
342  protected:
FindChannelOnce(const string & target)343   SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
344     const string host_port = TranslateTask(target);
345     if (host_port.empty()) {
346       return nullptr;
347     }
348     return channel_func_(host_port);
349   }
350 
351  private:
ToString()352   string ToString() {
353     std::vector<string> task_strings;
354     task_strings.reserve(host_ports_.size());
355     for (const auto& id_host_port : host_ports_) {
356       task_strings.emplace_back(
357           strings::StrCat(id_host_port.first, " -> ", id_host_port.second));
358     }
359     return strings::StrCat(job_id_, " -> {", absl::StrJoin(task_strings, ", "),
360                            "}");
361   }
362 
363   const string job_id_;
364   const std::map<int, string> host_ports_;
365   const ChannelCreationFunction channel_func_;
366   TF_DISALLOW_COPY_AND_ASSIGN(SparseGrpcChannelCache);
367 };
368 
369 }  // namespace
370 
NewGrpcChannelCache(const GrpcChannelSpec & spec,ChannelCreationFunction channel_func)371 GrpcChannelCache* NewGrpcChannelCache(const GrpcChannelSpec& spec,
372                                       ChannelCreationFunction channel_func) {
373   const int num_jobs = spec.host_ports_jobs().size();
374   if (!num_jobs) {
375     LOG(ERROR) << "Empty channel spec.";
376     return nullptr;
377   }
378   std::vector<GrpcChannelCache*> caches;
379   caches.reserve(num_jobs);
380   for (auto& job : spec.host_ports_jobs()) {
381     caches.push_back(
382         new SparseGrpcChannelCache(job.job_id, job.host_ports, channel_func));
383   }
384   return caches.size() == 1 ? caches[0] : new MultiGrpcChannelCache(caches);
385 }
386 
387 }  // end namespace tensorflow
388