• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/core/util/device_name_utils.h"

#include <algorithm>
#include <limits>

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
23 #include "tensorflow/core/platform/logging.h"
24 
25 namespace tensorflow {
26 
// True for ASCII letters [A-Za-z]. Unlike std::isalpha, this does not consult
// the current locale.
static bool IsAlpha(char c) {
  const bool lower = 'a' <= c && c <= 'z';
  const bool upper = 'A' <= c && c <= 'Z';
  return lower || upper;
}

// True for the characters allowed after the first one in job and device-type
// names: ASCII letters, ASCII digits and '_'.
static bool IsAlphaNumOrUnderscore(char c) {
  if (c == '_') return true;
  if ('0' <= c && c <= '9') return true;
  return IsAlpha(c);
}
34 
35 // Returns true iff "in" is a valid job name.
IsJobName(StringPiece in)36 static bool IsJobName(StringPiece in) {
37   return !in.empty() && IsAlpha(in.front()) &&
38          std::all_of(in.begin(), in.end(), IsAlphaNumOrUnderscore);
39 }
40 
ConsumePrefix(StringPiece * in,string * out,StringPiece prefix_terminators)41 static bool ConsumePrefix(StringPiece* in, string* out,
42                           StringPiece prefix_terminators) {
43   if (in->empty() || !IsAlpha(in->front())) return false;
44   const auto end_it =
45       std::find_first_of(in->begin(), in->end(), prefix_terminators.begin(),
46                          prefix_terminators.end());
47   if (!std::all_of(in->begin(), end_it, IsAlphaNumOrUnderscore)) {
48     return false;
49   }
50   out->assign(in->begin(), end_it);
51   in->remove_prefix(end_it - in->begin());
52   return true;
53 }
54 
55 // Returns true and fills in "*job" iff "*in" starts with a job name.
ConsumeJobName(StringPiece * in,string * job)56 static bool ConsumeJobName(StringPiece* in, string* job) {
57   return ConsumePrefix(in, job, "/");
58 }
59 
60 // Returns true and fills in "*device_type" iff "*in" starts with a device type
61 // name.
ConsumeDeviceType(StringPiece * in,string * device_type)62 static bool ConsumeDeviceType(StringPiece* in, string* device_type) {
63   return ConsumePrefix(in, device_type, "/:");
64 }
65 
66 // Returns true and fills in "*val" iff "*in" starts with a decimal
67 // number.
ConsumeNumber(StringPiece * in,int * val)68 static bool ConsumeNumber(StringPiece* in, int* val) {
69   uint64 tmp;
70   if (str_util::ConsumeLeadingDigits(in, &tmp)) {
71     *val = tmp;
72     return true;
73   } else {
74     return false;
75   }
76 }
77 
78 // Returns a fully qualified device name given the parameters.
DeviceName(const string & job,int replica,int task,const string & device_prefix,const string & device_type,int id)79 static string DeviceName(const string& job, int replica, int task,
80                          const string& device_prefix, const string& device_type,
81                          int id) {
82   CHECK(IsJobName(job)) << job;
83   CHECK_LE(0, replica);
84   CHECK_LE(0, task);
85   CHECK(!device_type.empty());
86   CHECK_LE(0, id);
87   return strings::StrCat("/job:", job, "/replica:", replica, "/task:", task,
88                          device_prefix, device_type, ":", id);
89 }
90 
91 /* static */
FullName(const string & job,int replica,int task,const string & type,int id)92 string DeviceNameUtils::FullName(const string& job, int replica, int task,
93                                  const string& type, int id) {
94   return DeviceName(job, replica, task, "/device:", type, id);
95 }
96 
97 namespace {
LegacyName(const string & job,int replica,int task,const string & type,int id)98 string LegacyName(const string& job, int replica, int task, const string& type,
99                   int id) {
100   return DeviceName(job, replica, task, "/", absl::AsciiStrToLower(type), id);
101 }
102 }  // anonymous namespace
103 
ParseFullName(StringPiece fullname,ParsedName * p)104 bool DeviceNameUtils::ParseFullName(StringPiece fullname, ParsedName* p) {
105   p->Clear();
106   if (fullname == "/") {
107     return true;
108   }
109   while (!fullname.empty()) {
110     bool progress = false;
111     if (absl::ConsumePrefix(&fullname, "/job:")) {
112       p->has_job = !absl::ConsumePrefix(&fullname, "*");
113       if (p->has_job && !ConsumeJobName(&fullname, &p->job)) {
114         return false;
115       }
116       progress = true;
117     }
118     if (absl::ConsumePrefix(&fullname, "/replica:")) {
119       p->has_replica = !absl::ConsumePrefix(&fullname, "*");
120       if (p->has_replica && !ConsumeNumber(&fullname, &p->replica)) {
121         return false;
122       }
123       progress = true;
124     }
125     if (absl::ConsumePrefix(&fullname, "/task:")) {
126       p->has_task = !absl::ConsumePrefix(&fullname, "*");
127       if (p->has_task && !ConsumeNumber(&fullname, &p->task)) {
128         return false;
129       }
130       progress = true;
131     }
132     if (absl::ConsumePrefix(&fullname, "/device:")) {
133       p->has_type = !absl::ConsumePrefix(&fullname, "*");
134       if (p->has_type && !ConsumeDeviceType(&fullname, &p->type)) {
135         return false;
136       }
137       if (!absl::ConsumePrefix(&fullname, ":")) {
138         p->has_id = false;
139       } else {
140         p->has_id = !absl::ConsumePrefix(&fullname, "*");
141         if (p->has_id && !ConsumeNumber(&fullname, &p->id)) {
142           return false;
143         }
144       }
145       progress = true;
146     }
147 
148     // Handle legacy naming convention for cpu and gpu.
149     if (absl::ConsumePrefix(&fullname, "/cpu:") ||
150         absl::ConsumePrefix(&fullname, "/CPU:")) {
151       p->has_type = true;
152       p->type = "CPU";  // Treat '/cpu:..' as uppercase '/device:CPU:...'
153       p->has_id = !absl::ConsumePrefix(&fullname, "*");
154       if (p->has_id && !ConsumeNumber(&fullname, &p->id)) {
155         return false;
156       }
157       progress = true;
158     }
159     if (absl::ConsumePrefix(&fullname, "/gpu:") ||
160         absl::ConsumePrefix(&fullname, "/GPU:")) {
161       p->has_type = true;
162       p->type = "GPU";  // Treat '/gpu:..' as uppercase '/device:GPU:...'
163       p->has_id = !absl::ConsumePrefix(&fullname, "*");
164       if (p->has_id && !ConsumeNumber(&fullname, &p->id)) {
165         return false;
166       }
167       progress = true;
168     }
169 
170     if (!progress) {
171       return false;
172     }
173   }
174   return true;
175 }
176 
ParseFullOrLocalName(StringPiece fullname,ParsedName * p)177 bool DeviceNameUtils::ParseFullOrLocalName(StringPiece fullname,
178                                            ParsedName* p) {
179   return ParseFullName(fullname, p) || ParseLocalName(fullname, p);
180 }
181 
182 namespace {
183 
CompleteName(const DeviceNameUtils::ParsedName & parsed_basename,DeviceNameUtils::ParsedName * parsed_name)184 void CompleteName(const DeviceNameUtils::ParsedName& parsed_basename,
185                   DeviceNameUtils::ParsedName* parsed_name) {
186   if (!parsed_name->has_job) {
187     parsed_name->job = parsed_basename.job;
188     parsed_name->has_job = true;
189   }
190   if (!parsed_name->has_replica) {
191     parsed_name->replica = parsed_basename.replica;
192     parsed_name->has_replica = true;
193   }
194   if (!parsed_name->has_task) {
195     parsed_name->task = parsed_basename.task;
196     parsed_name->has_task = true;
197   }
198   if (!parsed_name->has_type) {
199     parsed_name->type = parsed_basename.type;
200     parsed_name->has_type = true;
201   }
202   if (!parsed_name->has_id) {
203     parsed_name->id = parsed_basename.id;
204     parsed_name->has_id = true;
205   }
206 }
207 
208 }  // namespace
209 
210 /* static */
CanonicalizeDeviceName(StringPiece fullname,StringPiece basename,string * canonical_name)211 Status DeviceNameUtils::CanonicalizeDeviceName(StringPiece fullname,
212                                                StringPiece basename,
213                                                string* canonical_name) {
214   *canonical_name = "";
215   ParsedName parsed_basename;
216   if (!ParseFullName(basename, &parsed_basename)) {
217     return errors::InvalidArgument("Could not parse basename: ", basename,
218                                    " into a device specification.");
219   }
220   if (!(parsed_basename.has_job && parsed_basename.has_replica &&
221         parsed_basename.has_task && parsed_basename.has_type &&
222         parsed_basename.has_id)) {
223     return errors::InvalidArgument("Basename: ", basename,
224                                    " should be fully "
225                                    "specified.");
226   }
227   ParsedName parsed_name;
228   if (ParseLocalName(fullname, &parsed_name)) {
229     CompleteName(parsed_basename, &parsed_name);
230     *canonical_name = ParsedNameToString(parsed_name);
231     return Status::OK();
232   }
233   if (ParseFullName(fullname, &parsed_name)) {
234     CompleteName(parsed_basename, &parsed_name);
235     *canonical_name = ParsedNameToString(parsed_name);
236     return Status::OK();
237   }
238   return errors::InvalidArgument("Could not parse ", fullname,
239                                  " into a device "
240                                  "specification.");
241 }
242 
243 /* static */
ParsedNameToString(const ParsedName & pn)244 string DeviceNameUtils::ParsedNameToString(const ParsedName& pn) {
245   string buf;
246   if (pn.has_job) strings::StrAppend(&buf, "/job:", pn.job);
247   if (pn.has_replica) strings::StrAppend(&buf, "/replica:", pn.replica);
248   if (pn.has_task) strings::StrAppend(&buf, "/task:", pn.task);
249   if (pn.has_type) {
250     strings::StrAppend(&buf, "/device:", pn.type, ":");
251     if (pn.has_id) {
252       strings::StrAppend(&buf, pn.id);
253     } else {
254       strings::StrAppend(&buf, "*");
255     }
256   }
257   return buf;
258 }
259 
260 /* static */
IsSpecification(const ParsedName & less_specific,const ParsedName & more_specific)261 bool DeviceNameUtils::IsSpecification(const ParsedName& less_specific,
262                                       const ParsedName& more_specific) {
263   if (less_specific.has_job &&
264       (!more_specific.has_job || (less_specific.job != more_specific.job))) {
265     return false;
266   }
267   if (less_specific.has_replica &&
268       (!more_specific.has_replica ||
269        (less_specific.replica != more_specific.replica))) {
270     return false;
271   }
272   if (less_specific.has_task &&
273       (!more_specific.has_task || (less_specific.task != more_specific.task))) {
274     return false;
275   }
276   if (less_specific.has_type &&
277       (!more_specific.has_type || (less_specific.type != more_specific.type))) {
278     return false;
279   }
280   if (less_specific.has_id &&
281       (!more_specific.has_id || (less_specific.id != more_specific.id))) {
282     return false;
283   }
284   return true;
285 }
286 
287 /* static */
AreCompatibleDevNames(const ParsedName & a,const ParsedName & b)288 bool DeviceNameUtils::AreCompatibleDevNames(const ParsedName& a,
289                                             const ParsedName& b) {
290   if (a.has_job && b.has_job && (a.job != b.job)) {
291     return false;
292   }
293   if (a.has_replica && b.has_replica && (a.replica != b.replica)) {
294     return false;
295   }
296   if (a.has_task && b.has_task && (a.task != b.task)) {
297     return false;
298   }
299   if (a.has_type && b.has_type && (a.type != b.type)) {
300     return false;
301   }
302   if (a.has_id && b.has_id && (a.id != b.id)) {
303     return false;
304   }
305   return true;
306 }
307 
EnsureSpecification(ParsedName * more_specific,const ParsedName & less_specific)308 void DeviceNameUtils::EnsureSpecification(ParsedName* more_specific,
309                                           const ParsedName& less_specific) {
310   if (less_specific.has_job) {
311     more_specific->has_job = true;
312     more_specific->job = less_specific.job;
313   }
314   if (less_specific.has_replica) {
315     more_specific->has_replica = true;
316     more_specific->replica = less_specific.replica;
317   }
318   if (less_specific.has_task) {
319     more_specific->has_task = true;
320     more_specific->task = less_specific.task;
321   }
322   if (less_specific.has_type) {
323     more_specific->has_type = true;
324     more_specific->type = less_specific.type;
325   }
326   if (less_specific.has_id) {
327     more_specific->has_id = true;
328     more_specific->id = less_specific.id;
329   }
330 }
331 
332 /* static */
IsCompleteSpecification(const ParsedName & pattern,const ParsedName & name)333 bool DeviceNameUtils::IsCompleteSpecification(const ParsedName& pattern,
334                                               const ParsedName& name) {
335   CHECK(name.has_job && name.has_replica && name.has_task && name.has_type &&
336         name.has_id);
337 
338   if (pattern.has_job && (pattern.job != name.job)) return false;
339   if (pattern.has_replica && (pattern.replica != name.replica)) return false;
340   if (pattern.has_task && (pattern.task != name.task)) return false;
341   if (pattern.has_type && (pattern.type != name.type)) return false;
342   if (pattern.has_id && (pattern.id != name.id)) return false;
343   return true;
344 }
345 
346 namespace {
MergeDevNamesImpl(DeviceNameUtils::ParsedName * target,const DeviceNameUtils::ParsedName & other,bool allow_soft_placement,bool override_conflicts)347 Status MergeDevNamesImpl(DeviceNameUtils::ParsedName* target,
348                          const DeviceNameUtils::ParsedName& other,
349                          bool allow_soft_placement, bool override_conflicts) {
350   const auto& ParsedNameToString = DeviceNameUtils::ParsedNameToString;
351   if (other.has_job) {
352     if (target->has_job && target->job != other.job) {
353       return errors::InvalidArgument(
354           "Cannot merge devices with incompatible jobs: '",
355           ParsedNameToString(*target), "' and '", ParsedNameToString(other),
356           "'");
357     } else {
358       target->has_job = other.has_job;
359       target->job = other.job;
360     }
361   }
362 
363   if (other.has_replica) {
364     if (target->has_replica && target->replica != other.replica) {
365       return errors::InvalidArgument(
366           "Cannot merge devices with incompatible replicas: '",
367           ParsedNameToString(*target), "' and '", ParsedNameToString(other),
368           "'");
369     } else {
370       target->has_replica = other.has_replica;
371       target->replica = other.replica;
372     }
373   }
374 
375   if (other.has_task) {
376     if (target->has_task && target->task != other.task) {
377       return errors::InvalidArgument(
378           "Cannot merge devices with incompatible tasks: '",
379           ParsedNameToString(*target), "' and '", ParsedNameToString(other),
380           "'");
381     } else {
382       target->has_task = other.has_task;
383       target->task = other.task;
384     }
385   }
386 
387   if (other.has_type) {
388     if (target->has_type && target->type != other.type) {
389       if (!allow_soft_placement) {
390         return errors::InvalidArgument(
391             "Cannot merge devices with incompatible types: '",
392             ParsedNameToString(*target), "' and '", ParsedNameToString(other),
393             "'");
394       } else if (override_conflicts) {
395         target->type = other.type;
396       } else {
397         target->has_id = false;
398         target->has_type = false;
399         return Status::OK();
400       }
401     } else {
402       target->has_type = other.has_type;
403       target->type = other.type;
404     }
405   }
406 
407   if (other.has_id) {
408     if (target->has_id && target->id != other.id) {
409       if (!allow_soft_placement) {
410         return errors::InvalidArgument(
411             "Cannot merge devices with incompatible ids: '",
412             ParsedNameToString(*target), "' and '", ParsedNameToString(other),
413             "'");
414       } else if (override_conflicts) {
415         target->id = other.id;
416       } else {
417         target->has_id = false;
418         return Status::OK();
419       }
420     } else {
421       target->has_id = other.has_id;
422       target->id = other.id;
423     }
424   }
425 
426   return Status::OK();
427 }
428 
429 }  // namespace
430 
431 /* static */
MergeDevNames(ParsedName * target,const ParsedName & other,bool allow_soft_placement)432 Status DeviceNameUtils::MergeDevNames(ParsedName* target,
433                                       const ParsedName& other,
434                                       bool allow_soft_placement) {
435   return MergeDevNamesImpl(target, other, allow_soft_placement,
436                            /*override_conflicts=*/false);
437 }
438 
439 /* static */
MergeOverrideDevNames(ParsedName * target,const ParsedName & other)440 Status DeviceNameUtils::MergeOverrideDevNames(ParsedName* target,
441                                               const ParsedName& other) {
442   return MergeDevNamesImpl(target, other, /*allow_soft_placement=*/true,
443                            /*override_conflicts=*/true);
444 }
445 
446 /* static */
IsSameAddressSpace(const ParsedName & a,const ParsedName & b)447 bool DeviceNameUtils::IsSameAddressSpace(const ParsedName& a,
448                                          const ParsedName& b) {
449   return (a.has_job && b.has_job && (a.job == b.job)) &&
450          (a.has_replica && b.has_replica && (a.replica == b.replica)) &&
451          (a.has_task && b.has_task && (a.task == b.task));
452 }
453 
454 /* static */
IsSameAddressSpace(StringPiece src,StringPiece dst)455 bool DeviceNameUtils::IsSameAddressSpace(StringPiece src, StringPiece dst) {
456   ParsedName x;
457   ParsedName y;
458   return ParseFullName(src, &x) && ParseFullName(dst, &y) &&
459          IsSameAddressSpace(x, y);
460 }
461 
462 /* static */
IsDifferentAddressSpace(const ParsedName & a,const ParsedName & b)463 bool DeviceNameUtils::IsDifferentAddressSpace(const ParsedName& a,
464                                               const ParsedName& b) {
465   return (a.has_job && b.has_job && (a.job != b.job)) ||
466          (a.has_replica && b.has_replica && (a.replica != b.replica)) ||
467          (a.has_task && b.has_task && (a.task != b.task));
468 }
469 
470 /* static */
AddressSpace(const ParsedName & name)471 const DeviceNameUtils::ParsedName DeviceNameUtils::AddressSpace(
472     const ParsedName& name) {
473   ParsedName address_space;
474   address_space.has_job = name.has_job;
475   address_space.has_replica = name.has_replica;
476   address_space.has_task = name.has_task;
477   address_space.job = name.job;
478   address_space.replica = name.replica;
479   address_space.task = name.task;
480   return address_space;
481 }
482 
483 /* static */
LocalName(StringPiece type,int id)484 string DeviceNameUtils::LocalName(StringPiece type, int id) {
485   return strings::StrCat("/device:", type, ":", id);
486 }
487 
488 namespace {
489 // Returns the legacy local device name given its "type" and "id" (which is
490 // '/device:type:id').
LegacyLocalName(StringPiece type,int id)491 string LegacyLocalName(StringPiece type, int id) {
492   return strings::StrCat(type, ":", id);
493 }
494 }  // anonymous namespace
495 
496 /* static */
LocalName(StringPiece fullname)497 string DeviceNameUtils::LocalName(StringPiece fullname) {
498   ParsedName x;
499   CHECK(ParseFullName(fullname, &x)) << fullname;
500   return LocalName(x.type, x.id);
501 }
502 
503 /* static */
ParseLocalName(StringPiece name,ParsedName * p)504 bool DeviceNameUtils::ParseLocalName(StringPiece name, ParsedName* p) {
505   if (!ConsumeDeviceType(&name, &p->type)) {
506     return false;
507   }
508   p->has_type = true;
509   if (!absl::ConsumePrefix(&name, ":")) {
510     return false;
511   }
512   if (!ConsumeNumber(&name, &p->id)) {
513     return false;
514   }
515   p->has_id = true;
516   return name.empty();
517 }
518 
519 /* static */
SplitDeviceName(StringPiece name,string * task,string * device)520 bool DeviceNameUtils::SplitDeviceName(StringPiece name, string* task,
521                                       string* device) {
522   ParsedName pn;
523   if (ParseFullName(name, &pn) && pn.has_type && pn.has_id) {
524     task->clear();
525     task->reserve(
526         (pn.has_job ? (5 + pn.job.size()) : 0) +
527         (pn.has_replica ? (9 + 4 /*estimated UB for # replica digits*/) : 0) +
528         (pn.has_task ? (6 + 4 /*estimated UB for # task digits*/) : 0));
529     if (pn.has_job) {
530       strings::StrAppend(task, "/job:", pn.job);
531     }
532     if (pn.has_replica) {
533       strings::StrAppend(task, "/replica:", pn.replica);
534     }
535     if (pn.has_task) {
536       strings::StrAppend(task, "/task:", pn.task);
537     }
538     device->clear();
539     strings::StrAppend(device, pn.type, ":", pn.id);
540     return true;
541   }
542   return false;
543 }
544 
545 /* static */
GetTaskName(const ParsedName & pn,string * task)546 bool DeviceNameUtils::GetTaskName(const ParsedName& pn, string* task) {
547   if (pn.has_job && pn.has_replica && pn.has_task) {
548     task->clear();
549     task->reserve((5 + pn.job.size()) +
550                   (9 + 4 /*estimated UB for # replica digits*/) +
551                   (6 + 4 /*estimated UB for # task digits*/));
552     strings::StrAppend(task, "/job:", pn.job);
553     strings::StrAppend(task, "/replica:", pn.replica);
554     strings::StrAppend(task, "/task:", pn.task);
555     return true;
556   }
557   return false;
558 }
559 
GetNamesForDeviceMappings(const ParsedName & pn)560 std::vector<string> DeviceNameUtils::GetNamesForDeviceMappings(
561     const ParsedName& pn) {
562   if (pn.has_job && pn.has_replica && pn.has_task && pn.has_type && pn.has_id) {
563     return {
564         DeviceNameUtils::FullName(pn.job, pn.replica, pn.task, pn.type, pn.id),
565         LegacyName(pn.job, pn.replica, pn.task, pn.type, pn.id)};
566   } else {
567     return {};
568   }
569 }
570 
GetLocalNamesForDeviceMappings(const ParsedName & pn)571 std::vector<string> DeviceNameUtils::GetLocalNamesForDeviceMappings(
572     const ParsedName& pn) {
573   if (pn.has_type && pn.has_id) {
574     return {DeviceNameUtils::LocalName(pn.type, pn.id),
575             LegacyLocalName(pn.type, pn.id)};
576   } else {
577     return {};
578   }
579 }
580 
DeviceNameToCpuDeviceName(const string & device_name,string * host_device_name)581 /*static*/ Status DeviceNameUtils::DeviceNameToCpuDeviceName(
582     const string& device_name, string* host_device_name) {
583   DeviceNameUtils::ParsedName device;
584   if (!DeviceNameUtils::ParseFullName(device_name, &device)) {
585     return errors::Internal("Could not parse device name ", device_name);
586   }
587   device.type = "CPU";
588   device.has_type = true;
589   device.id = 0;
590   device.has_id = true;
591   *host_device_name = DeviceNameUtils::ParsedNameToString(device);
592   return Status::OK();
593 }
594 
operator <<(std::ostream & os,const DeviceNameUtils::ParsedName & x)595 std::ostream& operator<<(std::ostream& os,
596                          const DeviceNameUtils::ParsedName& x) {
597   os << DeviceNameUtils::ParsedNameToString(x);
598   return os;
599 }
600 
601 }  // namespace tensorflow
602