• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/core/util/device_name_utils.h"

#include <algorithm>
#include <limits>

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
24 
25 namespace tensorflow {
26 
// Returns true iff "c" is an ASCII letter ('a'-'z' or 'A'-'Z'). Uses plain
// range comparisons, so no locale is consulted.
static bool IsAlpha(char c) {
  const bool is_lower = ('a' <= c) && (c <= 'z');
  const bool is_upper = ('A' <= c) && (c <= 'Z');
  return is_lower || is_upper;
}
30 
IsAlphaNumOrUnderscore(char c)31 static bool IsAlphaNumOrUnderscore(char c) {
32   return IsAlpha(c) || (c >= '0' && c <= '9') || c == '_';
33 }
34 
35 // Returns true iff "in" is a valid job name.
IsJobName(StringPiece in)36 static bool IsJobName(StringPiece in) {
37   return !in.empty() && IsAlpha(in.front()) &&
38          std::all_of(in.begin(), in.end(), IsAlphaNumOrUnderscore);
39 }
40 
ConsumePrefix(StringPiece * in,string * out,StringPiece prefix_terminators)41 static bool ConsumePrefix(StringPiece* in, string* out,
42                           StringPiece prefix_terminators) {
43   if (in->empty() || !IsAlpha(in->front())) return false;
44   const auto end_it =
45       std::find_first_of(in->begin(), in->end(), prefix_terminators.begin(),
46                          prefix_terminators.end());
47   if (!std::all_of(in->begin(), end_it, IsAlphaNumOrUnderscore)) {
48     return false;
49   }
50   out->assign(in->begin(), end_it);
51   in->remove_prefix(end_it - in->begin());
52   return true;
53 }
54 
55 // Returns true and fills in "*job" iff "*in" starts with a job name.
ConsumeJobName(StringPiece * in,string * job)56 static bool ConsumeJobName(StringPiece* in, string* job) {
57   return ConsumePrefix(in, job, "/");
58 }
59 
60 // Returns true and fills in "*device_type" iff "*in" starts with a device type
61 // name.
ConsumeDeviceType(StringPiece * in,string * device_type)62 static bool ConsumeDeviceType(StringPiece* in, string* device_type) {
63   return ConsumePrefix(in, device_type, "/:");
64 }
65 
66 // Returns true and fills in "*val" iff "*in" starts with a decimal
67 // number.
ConsumeNumber(StringPiece * in,int * val)68 static bool ConsumeNumber(StringPiece* in, int* val) {
69   uint64 tmp;
70   if (str_util::ConsumeLeadingDigits(in, &tmp)) {
71     *val = tmp;
72     return true;
73   } else {
74     return false;
75   }
76 }
77 
78 // Returns a fully qualified device name given the parameters.
DeviceName(const string & job,int replica,int task,const string & device_prefix,const string & device_type,int id)79 static string DeviceName(const string& job, int replica, int task,
80                          const string& device_prefix, const string& device_type,
81                          int id) {
82   CHECK(IsJobName(job)) << job;
83   CHECK_LE(0, replica);
84   CHECK_LE(0, task);
85   CHECK(!device_type.empty());
86   CHECK_LE(0, id);
87   return strings::StrCat("/job:", job, "/replica:", replica, "/task:", task,
88                          device_prefix, device_type, ":", id);
89 }
90 
91 /* static */
FullName(const string & job,int replica,int task,const string & type,int id)92 string DeviceNameUtils::FullName(const string& job, int replica, int task,
93                                  const string& type, int id) {
94   return DeviceName(job, replica, task, "/device:", type, id);
95 }
96 
97 namespace {
LegacyName(const string & job,int replica,int task,const string & type,int id)98 string LegacyName(const string& job, int replica, int task, const string& type,
99                   int id) {
100   return DeviceName(job, replica, task, "/", absl::AsciiStrToLower(type), id);
101 }
102 }  // anonymous namespace
103 
ParseFullName(StringPiece fullname,ParsedName * p)104 bool DeviceNameUtils::ParseFullName(StringPiece fullname, ParsedName* p) {
105   p->Clear();
106   if (fullname == "/") {
107     return true;
108   }
109   while (!fullname.empty()) {
110     bool progress = false;
111     if (absl::ConsumePrefix(&fullname, "/job:")) {
112       p->has_job = !absl::ConsumePrefix(&fullname, "*");
113       if (p->has_job && !ConsumeJobName(&fullname, &p->job)) {
114         return false;
115       }
116       progress = true;
117     }
118     if (absl::ConsumePrefix(&fullname, "/replica:")) {
119       p->has_replica = !absl::ConsumePrefix(&fullname, "*");
120       if (p->has_replica && !ConsumeNumber(&fullname, &p->replica)) {
121         return false;
122       }
123       progress = true;
124     }
125     if (absl::ConsumePrefix(&fullname, "/task:")) {
126       p->has_task = !absl::ConsumePrefix(&fullname, "*");
127       if (p->has_task && !ConsumeNumber(&fullname, &p->task)) {
128         return false;
129       }
130       progress = true;
131     }
132     if (absl::ConsumePrefix(&fullname, "/device:")) {
133       p->has_type = !absl::ConsumePrefix(&fullname, "*");
134       if (p->has_type && !ConsumeDeviceType(&fullname, &p->type)) {
135         return false;
136       }
137       if (!absl::ConsumePrefix(&fullname, ":")) {
138         p->has_id = false;
139       } else {
140         p->has_id = !absl::ConsumePrefix(&fullname, "*");
141         if (p->has_id && !ConsumeNumber(&fullname, &p->id)) {
142           return false;
143         }
144       }
145       progress = true;
146     }
147 
148     // Handle legacy naming convention for cpu and gpu.
149     if (absl::ConsumePrefix(&fullname, "/cpu:") ||
150         absl::ConsumePrefix(&fullname, "/CPU:")) {
151       p->has_type = true;
152       p->type = "CPU";  // Treat '/cpu:..' as uppercase '/device:CPU:...'
153       p->has_id = !absl::ConsumePrefix(&fullname, "*");
154       if (p->has_id && !ConsumeNumber(&fullname, &p->id)) {
155         return false;
156       }
157       progress = true;
158     }
159     if (absl::ConsumePrefix(&fullname, "/gpu:") ||
160         absl::ConsumePrefix(&fullname, "/GPU:")) {
161       p->has_type = true;
162       p->type = "GPU";  // Treat '/gpu:..' as uppercase '/device:GPU:...'
163       p->has_id = !absl::ConsumePrefix(&fullname, "*");
164       if (p->has_id && !ConsumeNumber(&fullname, &p->id)) {
165         return false;
166       }
167       progress = true;
168     }
169 
170     if (!progress) {
171       return false;
172     }
173   }
174   return true;
175 }
176 
177 namespace {
178 
CompleteName(const DeviceNameUtils::ParsedName & parsed_basename,DeviceNameUtils::ParsedName * parsed_name)179 void CompleteName(const DeviceNameUtils::ParsedName& parsed_basename,
180                   DeviceNameUtils::ParsedName* parsed_name) {
181   if (!parsed_name->has_job) {
182     parsed_name->job = parsed_basename.job;
183     parsed_name->has_job = true;
184   }
185   if (!parsed_name->has_replica) {
186     parsed_name->replica = parsed_basename.replica;
187     parsed_name->has_replica = true;
188   }
189   if (!parsed_name->has_task) {
190     parsed_name->task = parsed_basename.task;
191     parsed_name->has_task = true;
192   }
193   if (!parsed_name->has_type) {
194     parsed_name->type = parsed_basename.type;
195     parsed_name->has_type = true;
196   }
197   if (!parsed_name->has_id) {
198     parsed_name->id = parsed_basename.id;
199     parsed_name->has_id = true;
200   }
201 }
202 
203 }  // namespace
204 
205 /* static */
CanonicalizeDeviceName(StringPiece fullname,StringPiece basename,string * canonical_name)206 Status DeviceNameUtils::CanonicalizeDeviceName(StringPiece fullname,
207                                                StringPiece basename,
208                                                string* canonical_name) {
209   *canonical_name = "";
210   ParsedName parsed_basename;
211   if (!ParseFullName(basename, &parsed_basename)) {
212     return errors::InvalidArgument("Could not parse basename: ", basename,
213                                    " into a device specification.");
214   }
215   if (!(parsed_basename.has_job && parsed_basename.has_replica &&
216         parsed_basename.has_task && parsed_basename.has_type &&
217         parsed_basename.has_id)) {
218     return errors::InvalidArgument("Basename: ", basename,
219                                    " should be fully "
220                                    "specified.");
221   }
222   ParsedName parsed_name;
223   if (ParseLocalName(fullname, &parsed_name)) {
224     CompleteName(parsed_basename, &parsed_name);
225     *canonical_name = ParsedNameToString(parsed_name);
226     return Status::OK();
227   }
228   if (ParseFullName(fullname, &parsed_name)) {
229     CompleteName(parsed_basename, &parsed_name);
230     *canonical_name = ParsedNameToString(parsed_name);
231     return Status::OK();
232   }
233   return errors::InvalidArgument("Could not parse ", fullname,
234                                  " into a device "
235                                  "specification.");
236 }
237 
238 /* static */
ParsedNameToString(const ParsedName & pn)239 string DeviceNameUtils::ParsedNameToString(const ParsedName& pn) {
240   string buf;
241   if (pn.has_job) strings::StrAppend(&buf, "/job:", pn.job);
242   if (pn.has_replica) strings::StrAppend(&buf, "/replica:", pn.replica);
243   if (pn.has_task) strings::StrAppend(&buf, "/task:", pn.task);
244   if (pn.has_type) {
245     strings::StrAppend(&buf, "/device:", pn.type, ":");
246     if (pn.has_id) {
247       strings::StrAppend(&buf, pn.id);
248     } else {
249       strings::StrAppend(&buf, "*");
250     }
251   }
252   return buf;
253 }
254 
255 /* static */
IsSpecification(const ParsedName & less_specific,const ParsedName & more_specific)256 bool DeviceNameUtils::IsSpecification(const ParsedName& less_specific,
257                                       const ParsedName& more_specific) {
258   if (less_specific.has_job &&
259       (!more_specific.has_job || (less_specific.job != more_specific.job))) {
260     return false;
261   }
262   if (less_specific.has_replica &&
263       (!more_specific.has_replica ||
264        (less_specific.replica != more_specific.replica))) {
265     return false;
266   }
267   if (less_specific.has_task &&
268       (!more_specific.has_task || (less_specific.task != more_specific.task))) {
269     return false;
270   }
271   if (less_specific.has_type &&
272       (!more_specific.has_type || (less_specific.type != more_specific.type))) {
273     return false;
274   }
275   if (less_specific.has_id &&
276       (!more_specific.has_id || (less_specific.id != more_specific.id))) {
277     return false;
278   }
279   return true;
280 }
281 
282 /* static */
AreCompatibleDevNames(const ParsedName & a,const ParsedName & b)283 bool DeviceNameUtils::AreCompatibleDevNames(const ParsedName& a,
284                                             const ParsedName& b) {
285   if (a.has_job && b.has_job && (a.job != b.job)) {
286     return false;
287   }
288   if (a.has_replica && b.has_replica && (a.replica != b.replica)) {
289     return false;
290   }
291   if (a.has_task && b.has_task && (a.task != b.task)) {
292     return false;
293   }
294   if (a.has_type && b.has_type && (a.type != b.type)) {
295     return false;
296   }
297   if (a.has_id && b.has_id && (a.id != b.id)) {
298     return false;
299   }
300   return true;
301 }
302 
EnsureSpecification(ParsedName * more_specific,const ParsedName & less_specific)303 void DeviceNameUtils::EnsureSpecification(ParsedName* more_specific,
304                                           const ParsedName& less_specific) {
305   if (less_specific.has_job) {
306     more_specific->has_job = true;
307     more_specific->job = less_specific.job;
308   }
309   if (less_specific.has_replica) {
310     more_specific->has_replica = true;
311     more_specific->replica = less_specific.replica;
312   }
313   if (less_specific.has_task) {
314     more_specific->has_task = true;
315     more_specific->task = less_specific.task;
316   }
317   if (less_specific.has_type) {
318     more_specific->has_type = true;
319     more_specific->type = less_specific.type;
320   }
321   if (less_specific.has_id) {
322     more_specific->has_id = true;
323     more_specific->id = less_specific.id;
324   }
325 }
326 
327 /* static */
IsCompleteSpecification(const ParsedName & pattern,const ParsedName & name)328 bool DeviceNameUtils::IsCompleteSpecification(const ParsedName& pattern,
329                                               const ParsedName& name) {
330   CHECK(name.has_job && name.has_replica && name.has_task && name.has_type &&
331         name.has_id);
332 
333   if (pattern.has_job && (pattern.job != name.job)) return false;
334   if (pattern.has_replica && (pattern.replica != name.replica)) return false;
335   if (pattern.has_task && (pattern.task != name.task)) return false;
336   if (pattern.has_type && (pattern.type != name.type)) return false;
337   if (pattern.has_id && (pattern.id != name.id)) return false;
338   return true;
339 }
340 
341 namespace {
MergeDevNamesImpl(DeviceNameUtils::ParsedName * target,const DeviceNameUtils::ParsedName & other,bool allow_soft_placement,bool override_conflicts)342 Status MergeDevNamesImpl(DeviceNameUtils::ParsedName* target,
343                          const DeviceNameUtils::ParsedName& other,
344                          bool allow_soft_placement, bool override_conflicts) {
345   const auto& ParsedNameToString = DeviceNameUtils::ParsedNameToString;
346   if (other.has_job) {
347     if (target->has_job && target->job != other.job) {
348       return errors::InvalidArgument(
349           "Cannot merge devices with incompatible jobs: '",
350           ParsedNameToString(*target), "' and '", ParsedNameToString(other),
351           "'");
352     } else {
353       target->has_job = other.has_job;
354       target->job = other.job;
355     }
356   }
357 
358   if (other.has_replica) {
359     if (target->has_replica && target->replica != other.replica) {
360       return errors::InvalidArgument(
361           "Cannot merge devices with incompatible replicas: '",
362           ParsedNameToString(*target), "' and '", ParsedNameToString(other),
363           "'");
364     } else {
365       target->has_replica = other.has_replica;
366       target->replica = other.replica;
367     }
368   }
369 
370   if (other.has_task) {
371     if (target->has_task && target->task != other.task) {
372       return errors::InvalidArgument(
373           "Cannot merge devices with incompatible tasks: '",
374           ParsedNameToString(*target), "' and '", ParsedNameToString(other),
375           "'");
376     } else {
377       target->has_task = other.has_task;
378       target->task = other.task;
379     }
380   }
381 
382   if (other.has_type) {
383     if (target->has_type && target->type != other.type) {
384       if (!allow_soft_placement) {
385         return errors::InvalidArgument(
386             "Cannot merge devices with incompatible types: '",
387             ParsedNameToString(*target), "' and '", ParsedNameToString(other),
388             "'");
389       } else if (override_conflicts) {
390         target->type = other.type;
391       } else {
392         target->has_id = false;
393         target->has_type = false;
394         return Status::OK();
395       }
396     } else {
397       target->has_type = other.has_type;
398       target->type = other.type;
399     }
400   }
401 
402   if (other.has_id) {
403     if (target->has_id && target->id != other.id) {
404       if (!allow_soft_placement) {
405         return errors::InvalidArgument(
406             "Cannot merge devices with incompatible ids: '",
407             ParsedNameToString(*target), "' and '", ParsedNameToString(other),
408             "'");
409       } else if (override_conflicts) {
410         target->id = other.id;
411       } else {
412         target->has_id = false;
413         return Status::OK();
414       }
415     } else {
416       target->has_id = other.has_id;
417       target->id = other.id;
418     }
419   }
420 
421   return Status::OK();
422 }
423 
424 }  // namespace
425 
426 /* static */
MergeDevNames(ParsedName * target,const ParsedName & other,bool allow_soft_placement)427 Status DeviceNameUtils::MergeDevNames(ParsedName* target,
428                                       const ParsedName& other,
429                                       bool allow_soft_placement) {
430   return MergeDevNamesImpl(target, other, allow_soft_placement,
431                            /*override_conflicts=*/false);
432 }
433 
434 /* static */
MergeOverrideDevNames(ParsedName * target,const ParsedName & other)435 Status DeviceNameUtils::MergeOverrideDevNames(ParsedName* target,
436                                               const ParsedName& other) {
437   return MergeDevNamesImpl(target, other, /*allow_soft_placement=*/true,
438                            /*override_conflicts=*/true);
439 }
440 
441 /* static */
IsSameAddressSpace(const ParsedName & a,const ParsedName & b)442 bool DeviceNameUtils::IsSameAddressSpace(const ParsedName& a,
443                                          const ParsedName& b) {
444   return (a.has_job && b.has_job && (a.job == b.job)) &&
445          (a.has_replica && b.has_replica && (a.replica == b.replica)) &&
446          (a.has_task && b.has_task && (a.task == b.task));
447 }
448 
449 /* static */
IsSameAddressSpace(StringPiece src,StringPiece dst)450 bool DeviceNameUtils::IsSameAddressSpace(StringPiece src, StringPiece dst) {
451   ParsedName x;
452   ParsedName y;
453   return ParseFullName(src, &x) && ParseFullName(dst, &y) &&
454          IsSameAddressSpace(x, y);
455 }
456 
457 /* static */
IsDifferentAddressSpace(const ParsedName & a,const ParsedName & b)458 bool DeviceNameUtils::IsDifferentAddressSpace(const ParsedName& a,
459                                               const ParsedName& b) {
460   return (a.has_job && b.has_job && (a.job != b.job)) ||
461          (a.has_replica && b.has_replica && (a.replica != b.replica)) ||
462          (a.has_task && b.has_task && (a.task != b.task));
463 }
464 
465 /* static */
AddressSpace(const ParsedName & name)466 const DeviceNameUtils::ParsedName DeviceNameUtils::AddressSpace(
467     const ParsedName& name) {
468   ParsedName address_space;
469   address_space.has_job = name.has_job;
470   address_space.has_replica = name.has_replica;
471   address_space.has_task = name.has_task;
472   address_space.job = name.job;
473   address_space.replica = name.replica;
474   address_space.task = name.task;
475   return address_space;
476 }
477 
478 /* static */
LocalName(StringPiece type,int id)479 string DeviceNameUtils::LocalName(StringPiece type, int id) {
480   return strings::StrCat("/device:", type, ":", id);
481 }
482 
483 namespace {
484 // Returns the legacy local device name given its "type" and "id" (which is
485 // '/device:type:id').
LegacyLocalName(StringPiece type,int id)486 string LegacyLocalName(StringPiece type, int id) {
487   return strings::StrCat(type, ":", id);
488 }
489 }  // anonymous namespace
490 
491 /* static */
LocalName(StringPiece fullname)492 string DeviceNameUtils::LocalName(StringPiece fullname) {
493   ParsedName x;
494   CHECK(ParseFullName(fullname, &x)) << fullname;
495   return LocalName(x.type, x.id);
496 }
497 
498 /* static */
ParseLocalName(StringPiece name,ParsedName * p)499 bool DeviceNameUtils::ParseLocalName(StringPiece name, ParsedName* p) {
500   if (!ConsumeDeviceType(&name, &p->type)) {
501     return false;
502   }
503   p->has_type = true;
504   if (!absl::ConsumePrefix(&name, ":")) {
505     return false;
506   }
507   if (!ConsumeNumber(&name, &p->id)) {
508     return false;
509   }
510   p->has_id = true;
511   return name.empty();
512 }
513 
514 /* static */
SplitDeviceName(StringPiece name,string * task,string * device)515 bool DeviceNameUtils::SplitDeviceName(StringPiece name, string* task,
516                                       string* device) {
517   ParsedName pn;
518   if (ParseFullName(name, &pn) && pn.has_type && pn.has_id) {
519     task->clear();
520     task->reserve(
521         (pn.has_job ? (5 + pn.job.size()) : 0) +
522         (pn.has_replica ? (9 + 4 /*estimated UB for # replica digits*/) : 0) +
523         (pn.has_task ? (6 + 4 /*estimated UB for # task digits*/) : 0));
524     if (pn.has_job) {
525       strings::StrAppend(task, "/job:", pn.job);
526     }
527     if (pn.has_replica) {
528       strings::StrAppend(task, "/replica:", pn.replica);
529     }
530     if (pn.has_task) {
531       strings::StrAppend(task, "/task:", pn.task);
532     }
533     device->clear();
534     strings::StrAppend(device, pn.type, ":", pn.id);
535     return true;
536   }
537   return false;
538 }
539 
540 /* static */
GetTaskName(const ParsedName & pn,string * task)541 bool DeviceNameUtils::GetTaskName(const ParsedName& pn, string* task) {
542   if (pn.has_job && pn.has_replica && pn.has_task) {
543     task->clear();
544     task->reserve((5 + pn.job.size()) +
545                   (9 + 4 /*estimated UB for # replica digits*/) +
546                   (6 + 4 /*estimated UB for # task digits*/));
547     strings::StrAppend(task, "/job:", pn.job);
548     strings::StrAppend(task, "/replica:", pn.replica);
549     strings::StrAppend(task, "/task:", pn.task);
550     return true;
551   }
552   return false;
553 }
554 
GetNamesForDeviceMappings(const ParsedName & pn)555 std::vector<string> DeviceNameUtils::GetNamesForDeviceMappings(
556     const ParsedName& pn) {
557   if (pn.has_job && pn.has_replica && pn.has_task && pn.has_type && pn.has_id) {
558     return {
559         DeviceNameUtils::FullName(pn.job, pn.replica, pn.task, pn.type, pn.id),
560         LegacyName(pn.job, pn.replica, pn.task, pn.type, pn.id)};
561   } else {
562     return {};
563   }
564 }
565 
GetLocalNamesForDeviceMappings(const ParsedName & pn)566 std::vector<string> DeviceNameUtils::GetLocalNamesForDeviceMappings(
567     const ParsedName& pn) {
568   if (pn.has_type && pn.has_id) {
569     return {DeviceNameUtils::LocalName(pn.type, pn.id),
570             LegacyLocalName(pn.type, pn.id)};
571   } else {
572     return {};
573   }
574 }
575 
DeviceNameToCpuDeviceName(const string & device_name,string * host_device_name)576 /*static*/ Status DeviceNameUtils::DeviceNameToCpuDeviceName(
577     const string& device_name, string* host_device_name) {
578   DeviceNameUtils::ParsedName device;
579   if (!DeviceNameUtils::ParseFullName(device_name, &device)) {
580     return errors::Internal("Could not parse device name ", device_name);
581   }
582   device.type = "CPU";
583   device.id = 0;
584   *host_device_name = DeviceNameUtils::ParsedNameToString(device);
585   return Status::OK();
586 }
587 
operator <<(std::ostream & os,const DeviceNameUtils::ParsedName & x)588 std::ostream& operator<<(std::ostream& os,
589                          const DeviceNameUtils::ParsedName& x) {
590   os << DeviceNameUtils::ParsedNameToString(x);
591   return os;
592 }
593 
594 }  // namespace tensorflow
595