• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
17 
18 #include <cstdlib>
19 #include <limits>
20 #include <map>
21 #include <unordered_map>
22 
23 #include "grpcpp/create_channel.h"
24 #include "absl/strings/escaping.h"
25 #include "absl/strings/str_split.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/lib/core/status.h"
28 #include "tensorflow/core/lib/gtl/map_util.h"
29 #include "tensorflow/core/lib/strings/numbers.h"
30 #include "tensorflow/core/lib/strings/str_util.h"
31 #include "tensorflow/core/lib/strings/strcat.h"
32 #include "tensorflow/core/platform/logging.h"
33 #include "tensorflow/core/platform/macros.h"
34 #include "tensorflow/core/platform/mutex.h"
35 #include "tensorflow/core/platform/thread_annotations.h"
36 #include "tensorflow/core/platform/types.h"
37 #include "tensorflow/core/util/device_name_utils.h"
38 
39 namespace tensorflow {
40 
41 namespace {
42 
// Builds the fully-qualified task name "/job:<job>/replica:0/task:<task>".
// The replica component is always 0.
string MakeAddress(const string& job, int task) {
  string address;
  strings::StrAppend(&address, "/job:", job, "/replica:0/task:", task);
  return address;
}
46 
47 // Allows the host to be a raw IP (either v4 or v6).
ValidateHostPortPair(const string & host_port)48 Status ValidateHostPortPair(const string& host_port) {
49   string bns_prefix = "/bns/";
50   if (host_port.substr(0, bns_prefix.length()) == bns_prefix) {
51     return Status::OK();
52   }
53   uint32 port;
54   auto colon_index = host_port.find_last_of(':');
55   if (!strings::safe_strtou32(host_port.substr(colon_index + 1), &port) ||
56       host_port.substr(0, colon_index).find('/') != string::npos) {
57     return errors::InvalidArgument("Could not interpret \"", host_port,
58                                    "\" as a host-port pair.");
59   }
60   return Status::OK();
61 }
62 
CreateDefaultChannelArguments()63 ::grpc::ChannelArguments* CreateDefaultChannelArguments() {
64   ::grpc::ChannelArguments* args = new ::grpc::ChannelArguments();
65   const char* env = std::getenv("TF_GRPC_DEFAULT_OPTIONS");
66   if (env != nullptr) {
67     for (auto& grpc_option : absl::StrSplit(env, ',')) {
68       std::vector<string> name_value = absl::StrSplit(grpc_option, '=');
69       if (name_value.size() != 2) {
70         LOG(ERROR) << "Invalid GRPC options format: " << grpc_option;
71         continue;
72       }
73       VLOG(3) << "Setting GRPC default for '" << name_value[0] << "' to '"
74               << name_value[1] << "'";
75       if (name_value[1].size() >= 2 && name_value[1][0] == '"') {
76         string ue_value = name_value[1].substr(1, name_value[1].size() - 2);
77         string value;
78         string error;
79         if (!absl::CUnescape(ue_value, &value, &error)) {
80           LOG(ERROR) << "Failed to parse escaped string for " << grpc_option
81                      << ": " << error;
82         } else {
83           args->SetString(name_value[0], value);
84         }
85       } else {
86         int64 value;
87         if (strings::safe_strto64(name_value[1], &value)) {
88           args->SetInt(name_value[0], value);
89         } else {
90           LOG(ERROR) << "Invalid integer value: " << grpc_option;
91         }
92       }
93     }
94   }
95   return args;
96 }
97 
GetDefaultChannelArguments()98 const ::grpc::ChannelArguments* GetDefaultChannelArguments() {
99   static const ::grpc::ChannelArguments* args = CreateDefaultChannelArguments();
100   return args;
101 }
102 
103 }  // namespace
104 
GetChannelArguments(const RPCOptions * rpc_options)105 ::grpc::ChannelArguments GetChannelArguments(const RPCOptions* rpc_options) {
106   // TODO(mrry): Implement secure channels.
107   ::grpc::ChannelArguments args = *GetDefaultChannelArguments();
108   args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, std::numeric_limits<int32>::max());
109   // NOTE(mrry): Some versions of gRPC use a 20-second minimum backoff
110   // on connection failure, which makes our tests time out.
111   args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 1000);
112   if (rpc_options != nullptr) {
113     if (rpc_options->compression_algorithm() == "deflate") {
114       args.SetCompressionAlgorithm(GRPC_COMPRESS_DEFLATE);
115       args.SetInt(GRPC_COMPRESSION_CHANNEL_DEFAULT_LEVEL,
116                   rpc_options->compression_level());
117       VLOG(5) << "Setting GRPC compression : algo='"
118               << rpc_options->compression_algorithm()
119               << "' level=" << rpc_options->compression_level();
120     } else if (rpc_options->compression_algorithm() == "gzip") {
121       args.SetCompressionAlgorithm(GRPC_COMPRESS_GZIP);
122       args.SetInt(GRPC_COMPRESSION_CHANNEL_DEFAULT_LEVEL,
123                   rpc_options->compression_level());
124       VLOG(5) << "Setting GRPC compression : algo='"
125               << rpc_options->compression_algorithm()
126               << "' level=" << rpc_options->compression_level();
127     } else if (!rpc_options->compression_algorithm().empty()) {
128       LOG(ERROR) << "Invalid compression algorithm: "
129                  << rpc_options->compression_algorithm();
130     }
131     if (rpc_options->disable_session_connection_sharing()) {
132       VLOG(5) << "Disabling TCP connection sharing";
133       args.SetInt(GRPC_ARG_USE_LOCAL_SUBCHANNEL_POOL, true);
134     }
135   }
136   return args;
137 }
138 
NewHostPortGrpcChannel(const string & target,const RPCOptions * rpc_options,SharedGrpcChannelPtr * channel_pointer)139 Status NewHostPortGrpcChannel(const string& target,
140                               const RPCOptions* rpc_options,
141                               SharedGrpcChannelPtr* channel_pointer) {
142   // Minimally ensure that the target is valid
143   TF_RETURN_IF_ERROR(ValidateHostPortPair(target));
144 
145   ::grpc::ChannelArguments args = GetChannelArguments(rpc_options);
146   *channel_pointer = ::grpc::CreateCustomChannel(
147       "dns:///" + target, ::grpc::InsecureChannelCredentials(), args);
148   return Status::OK();
149 }
150 
ConvertToChannelCreationFunction(const std::function<Status (string,const RPCOptions *,SharedGrpcChannelPtr *)> & new_channel_func_ptr)151 ChannelCreationFunction ConvertToChannelCreationFunction(
152     const std::function<Status(string, const RPCOptions*,
153                                SharedGrpcChannelPtr*)>& new_channel_func_ptr) {
154   return [new_channel_func_ptr](const string& target) -> SharedGrpcChannelPtr {
155     SharedGrpcChannelPtr channel_ptr;
156     if (new_channel_func_ptr(target, /*rpc_options=*/nullptr, &channel_ptr)
157             .ok()) {
158       return channel_ptr;
159     } else {
160       return nullptr;
161     }
162   };
163 }
164 
AddHostPortsJob(const string & job_id,const std::vector<string> & host_ports)165 Status GrpcChannelSpec::AddHostPortsJob(const string& job_id,
166                                         const std::vector<string>& host_ports) {
167   std::map<int, string> host_ports_map;
168   for (size_t i = 0; i < host_ports.size(); ++i) {
169     host_ports_map[i] = host_ports[i];
170   }
171   return AddHostPortsJob(job_id, host_ports_map);
172 }
173 
AddHostPortsJob(const string & job_id,const std::map<int,string> & host_ports)174 Status GrpcChannelSpec::AddHostPortsJob(
175     const string& job_id, const std::map<int, string>& host_ports) {
176   if (!job_ids_.insert(job_id).second) {
177     return errors::InvalidArgument(
178         "Duplicate job ID in cluster specification: ", job_id);
179   }
180   for (const auto& id_host_port : host_ports) {
181     TF_RETURN_IF_ERROR(ValidateHostPortPair(id_host_port.second));
182   }
183   host_ports_jobs_.emplace_back(job_id, host_ports);
184   return Status::OK();
185 }
186 
187 namespace {
188 
189 // GrpcChannelCache that caches results to FindWorkerChannel() calls.
190 class CachingGrpcChannelCache : public GrpcChannelCache {
191  public:
CachingGrpcChannelCache()192   CachingGrpcChannelCache() {}
193 
~CachingGrpcChannelCache()194   ~CachingGrpcChannelCache() override {}
195 
FindWorkerChannel(const string & target)196   SharedGrpcChannelPtr FindWorkerChannel(const string& target) override {
197     SharedGrpcChannelPtr ch = nullptr;
198     {
199       mutex_lock l(mu_);  // could use reader lock
200       ch = gtl::FindPtrOrNull(channels_, target);
201       if (ch) {
202         return ch;
203       }
204     }
205     ch = FindChannelOnce(target);
206     if (ch) {
207       mutex_lock l(mu_);
208       channels_.insert({target, ch});
209     }
210     return ch;
211   }
212 
213  protected:
214   // Find the ClientChannel for "target".  Only called when no channel was
215   // found in the channels_ cache for "target".  A non nullptr result will be
216   // cached in channels_.
217   virtual SharedGrpcChannelPtr FindChannelOnce(const string& target) = 0;
218 
219  private:
220   // TODO(zhifengc): Eviction when the map becomes too big.
221   mutex mu_;
222   std::unordered_map<string, SharedGrpcChannelPtr> channels_ TF_GUARDED_BY(mu_);
223 };
224 
225 // A ChannelCache that is the union of multiple ChannelCaches.
226 // Takes ownership of the caches passed to the constructor.
227 class MultiGrpcChannelCache : public CachingGrpcChannelCache {
228  public:
MultiGrpcChannelCache(const std::vector<GrpcChannelCache * > & caches)229   explicit MultiGrpcChannelCache(const std::vector<GrpcChannelCache*>& caches)
230       : CachingGrpcChannelCache(), caches_(caches) {}
231 
~MultiGrpcChannelCache()232   ~MultiGrpcChannelCache() override {
233     for (GrpcChannelCache* cache : caches_) {
234       delete cache;
235     }
236   }
237 
ListWorkers(std::vector<string> * workers)238   void ListWorkers(std::vector<string>* workers) override {
239     for (GrpcChannelCache* cache : caches_) {
240       cache->ListWorkers(workers);
241     }
242   }
243 
ListWorkersInJob(const string & job_name,std::vector<string> * workers)244   void ListWorkersInJob(const string& job_name,
245                         std::vector<string>* workers) override {
246     for (GrpcChannelCache* cache : caches_) {
247       cache->ListWorkersInJob(job_name, workers);
248     }
249   }
250 
TranslateTask(const string & target)251   string TranslateTask(const string& target) override {
252     mutex_lock l(mu_);  // could use reader lock
253     GrpcChannelCache* cache = gtl::FindPtrOrNull(target_caches_, target);
254     if (cache == nullptr) {
255       for (GrpcChannelCache* c : caches_) {
256         string r = c->TranslateTask(target);
257         if (!r.empty()) {
258           target_caches_.insert({target, c});
259           cache = c;
260           break;
261         }
262       }
263     }
264     CHECK(cache) << "Could not find GrpcChannelCache holding channel for "
265                  << target;
266     return cache->TranslateTask(target);
267   }
268 
269  protected:
FindChannelOnce(const string & target)270   SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
271     for (GrpcChannelCache* cache : caches_) {
272       SharedGrpcChannelPtr ch(cache->FindWorkerChannel(target));
273       if (ch) {
274         mutex_lock l(mu_);
275         target_caches_.insert({target, cache});
276         return ch;
277       }
278     }
279     return nullptr;
280   }
281 
282  private:
283   // List of channels used by this MultiGrpcChannelCache.
284   const std::vector<GrpcChannelCache*> caches_;
285 
286   mutex mu_;
287   // Cache of channels keyed by the target they are handling.
288   // The same GrpcChannelCache can appear multiple times in the cache.
289   std::unordered_map<string, GrpcChannelCache*> target_caches_
290       TF_GUARDED_BY(mu_);
291 };
292 
293 class SparseGrpcChannelCache : public CachingGrpcChannelCache {
294  public:
SparseGrpcChannelCache(const string & job_id,const std::map<int,string> & host_ports,ChannelCreationFunction channel_func)295   SparseGrpcChannelCache(const string& job_id,
296                          const std::map<int, string>& host_ports,
297                          ChannelCreationFunction channel_func)
298       : job_id_(job_id),
299         host_ports_(host_ports),
300         channel_func_(std::move(channel_func)) {
301     LOG(INFO) << "Initialize GrpcChannelCache for job " << ToString();
302   }
~SparseGrpcChannelCache()303   ~SparseGrpcChannelCache() override {}
304 
ListWorkers(std::vector<string> * workers)305   void ListWorkers(std::vector<string>* workers) override {
306     workers->reserve(workers->size() + host_ports_.size());
307     for (const auto& id_host_port : host_ports_) {
308       workers->emplace_back(MakeAddress(job_id_, id_host_port.first));
309     }
310   }
311 
ListWorkersInJob(const string & job_name,std::vector<string> * workers)312   void ListWorkersInJob(const string& job_name,
313                         std::vector<string>* workers) override {
314     if (job_name == job_id_) {
315       ListWorkers(workers);
316     }
317   }
318 
TranslateTask(const string & target)319   string TranslateTask(const string& target) override {
320     DeviceNameUtils::ParsedName parsed;
321     if (!DeviceNameUtils::ParseFullName(target, &parsed)) {
322       LOG(WARNING) << "Invalid target: " << target;
323       return "";
324     }
325 
326     if (!parsed.has_job || parsed.job != job_id_) {
327       return "";
328     }
329     if (!parsed.has_replica || parsed.replica != 0) {
330       LOG(WARNING) << "Replica ID must be 0 in target: " << target;
331       return "";
332     }
333     int32 task = parsed.has_task ? parsed.task : -1;
334     auto iter = host_ports_.find(task);
335     if (iter == host_ports_.end()) {
336       LOG(WARNING) << "Task " << task << " was not defined in sparse job "
337                    << job_id_ << ": " << target;
338       return "";
339     }
340     return iter->second;
341   }
342 
343  protected:
FindChannelOnce(const string & target)344   SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
345     const string host_port = TranslateTask(target);
346     if (host_port.empty()) {
347       return nullptr;
348     }
349     return channel_func_(host_port);
350   }
351 
352  private:
ToString()353   string ToString() {
354     std::vector<string> task_strings;
355     task_strings.reserve(host_ports_.size());
356     for (const auto& id_host_port : host_ports_) {
357       task_strings.emplace_back(
358           strings::StrCat(id_host_port.first, " -> ", id_host_port.second));
359     }
360     return strings::StrCat(job_id_, " -> {", absl::StrJoin(task_strings, ", "),
361                            "}");
362   }
363 
364   const string job_id_;
365   const std::map<int, string> host_ports_;
366   const ChannelCreationFunction channel_func_;
367   TF_DISALLOW_COPY_AND_ASSIGN(SparseGrpcChannelCache);
368 };
369 
370 }  // namespace
371 
NewGrpcChannelCache(const GrpcChannelSpec & spec,ChannelCreationFunction channel_func)372 GrpcChannelCache* NewGrpcChannelCache(const GrpcChannelSpec& spec,
373                                       ChannelCreationFunction channel_func) {
374   const int num_jobs = spec.host_ports_jobs().size();
375   if (!num_jobs) {
376     LOG(ERROR) << "Empty channel spec.";
377     return nullptr;
378   }
379   std::vector<GrpcChannelCache*> caches;
380   caches.reserve(num_jobs);
381   for (auto& job : spec.host_ports_jobs()) {
382     caches.push_back(
383         new SparseGrpcChannelCache(job.job_id, job.host_ports, channel_func));
384   }
385   return caches.size() == 1 ? caches[0] : new MultiGrpcChannelCache(caches);
386 }
387 
388 }  // end namespace tensorflow
389