1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/core/util/device_name_utils.h"

#include <algorithm>
#include <limits>

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
24
25 namespace tensorflow {
26
// Returns true iff `c` is an ASCII letter ('a'-'z' or 'A'-'Z').
// Unlike std::isalpha, the result does not depend on the current locale.
static bool IsAlpha(char c) {
  const bool is_lower = 'a' <= c && c <= 'z';
  const bool is_upper = 'A' <= c && c <= 'Z';
  return is_lower || is_upper;
}
30
IsAlphaNumOrUnderscore(char c)31 static bool IsAlphaNumOrUnderscore(char c) {
32 return IsAlpha(c) || (c >= '0' && c <= '9') || c == '_';
33 }
34
35 // Returns true iff "in" is a valid job name.
IsJobName(StringPiece in)36 static bool IsJobName(StringPiece in) {
37 return !in.empty() && IsAlpha(in.front()) &&
38 std::all_of(in.begin(), in.end(), IsAlphaNumOrUnderscore);
39 }
40
ConsumePrefix(StringPiece * in,string * out,StringPiece prefix_terminators)41 static bool ConsumePrefix(StringPiece* in, string* out,
42 StringPiece prefix_terminators) {
43 if (in->empty() || !IsAlpha(in->front())) return false;
44 const auto end_it =
45 std::find_first_of(in->begin(), in->end(), prefix_terminators.begin(),
46 prefix_terminators.end());
47 if (!std::all_of(in->begin(), end_it, IsAlphaNumOrUnderscore)) {
48 return false;
49 }
50 out->assign(in->begin(), end_it);
51 in->remove_prefix(end_it - in->begin());
52 return true;
53 }
54
55 // Returns true and fills in "*job" iff "*in" starts with a job name.
ConsumeJobName(StringPiece * in,string * job)56 static bool ConsumeJobName(StringPiece* in, string* job) {
57 return ConsumePrefix(in, job, "/");
58 }
59
60 // Returns true and fills in "*device_type" iff "*in" starts with a device type
61 // name.
ConsumeDeviceType(StringPiece * in,string * device_type)62 static bool ConsumeDeviceType(StringPiece* in, string* device_type) {
63 return ConsumePrefix(in, device_type, "/:");
64 }
65
66 // Returns true and fills in "*val" iff "*in" starts with a decimal
67 // number.
ConsumeNumber(StringPiece * in,int * val)68 static bool ConsumeNumber(StringPiece* in, int* val) {
69 uint64 tmp;
70 if (str_util::ConsumeLeadingDigits(in, &tmp)) {
71 *val = tmp;
72 return true;
73 } else {
74 return false;
75 }
76 }
77
78 // Returns a fully qualified device name given the parameters.
DeviceName(const string & job,int replica,int task,const string & device_prefix,const string & device_type,int id)79 static string DeviceName(const string& job, int replica, int task,
80 const string& device_prefix, const string& device_type,
81 int id) {
82 CHECK(IsJobName(job)) << job;
83 CHECK_LE(0, replica);
84 CHECK_LE(0, task);
85 CHECK(!device_type.empty());
86 CHECK_LE(0, id);
87 return strings::StrCat("/job:", job, "/replica:", replica, "/task:", task,
88 device_prefix, device_type, ":", id);
89 }
90
91 /* static */
FullName(const string & job,int replica,int task,const string & type,int id)92 string DeviceNameUtils::FullName(const string& job, int replica, int task,
93 const string& type, int id) {
94 return DeviceName(job, replica, task, "/device:", type, id);
95 }
96
97 namespace {
LegacyName(const string & job,int replica,int task,const string & type,int id)98 string LegacyName(const string& job, int replica, int task, const string& type,
99 int id) {
100 return DeviceName(job, replica, task, "/", absl::AsciiStrToLower(type), id);
101 }
102 } // anonymous namespace
103
ParseFullName(StringPiece fullname,ParsedName * p)104 bool DeviceNameUtils::ParseFullName(StringPiece fullname, ParsedName* p) {
105 p->Clear();
106 if (fullname == "/") {
107 return true;
108 }
109 while (!fullname.empty()) {
110 bool progress = false;
111 if (absl::ConsumePrefix(&fullname, "/job:")) {
112 p->has_job = !absl::ConsumePrefix(&fullname, "*");
113 if (p->has_job && !ConsumeJobName(&fullname, &p->job)) {
114 return false;
115 }
116 progress = true;
117 }
118 if (absl::ConsumePrefix(&fullname, "/replica:")) {
119 p->has_replica = !absl::ConsumePrefix(&fullname, "*");
120 if (p->has_replica && !ConsumeNumber(&fullname, &p->replica)) {
121 return false;
122 }
123 progress = true;
124 }
125 if (absl::ConsumePrefix(&fullname, "/task:")) {
126 p->has_task = !absl::ConsumePrefix(&fullname, "*");
127 if (p->has_task && !ConsumeNumber(&fullname, &p->task)) {
128 return false;
129 }
130 progress = true;
131 }
132 if (absl::ConsumePrefix(&fullname, "/device:")) {
133 p->has_type = !absl::ConsumePrefix(&fullname, "*");
134 if (p->has_type && !ConsumeDeviceType(&fullname, &p->type)) {
135 return false;
136 }
137 if (!absl::ConsumePrefix(&fullname, ":")) {
138 p->has_id = false;
139 } else {
140 p->has_id = !absl::ConsumePrefix(&fullname, "*");
141 if (p->has_id && !ConsumeNumber(&fullname, &p->id)) {
142 return false;
143 }
144 }
145 progress = true;
146 }
147
148 // Handle legacy naming convention for cpu and gpu.
149 if (absl::ConsumePrefix(&fullname, "/cpu:") ||
150 absl::ConsumePrefix(&fullname, "/CPU:")) {
151 p->has_type = true;
152 p->type = "CPU"; // Treat '/cpu:..' as uppercase '/device:CPU:...'
153 p->has_id = !absl::ConsumePrefix(&fullname, "*");
154 if (p->has_id && !ConsumeNumber(&fullname, &p->id)) {
155 return false;
156 }
157 progress = true;
158 }
159 if (absl::ConsumePrefix(&fullname, "/gpu:") ||
160 absl::ConsumePrefix(&fullname, "/GPU:")) {
161 p->has_type = true;
162 p->type = "GPU"; // Treat '/gpu:..' as uppercase '/device:GPU:...'
163 p->has_id = !absl::ConsumePrefix(&fullname, "*");
164 if (p->has_id && !ConsumeNumber(&fullname, &p->id)) {
165 return false;
166 }
167 progress = true;
168 }
169
170 if (!progress) {
171 return false;
172 }
173 }
174 return true;
175 }
176
177 namespace {
178
CompleteName(const DeviceNameUtils::ParsedName & parsed_basename,DeviceNameUtils::ParsedName * parsed_name)179 void CompleteName(const DeviceNameUtils::ParsedName& parsed_basename,
180 DeviceNameUtils::ParsedName* parsed_name) {
181 if (!parsed_name->has_job) {
182 parsed_name->job = parsed_basename.job;
183 parsed_name->has_job = true;
184 }
185 if (!parsed_name->has_replica) {
186 parsed_name->replica = parsed_basename.replica;
187 parsed_name->has_replica = true;
188 }
189 if (!parsed_name->has_task) {
190 parsed_name->task = parsed_basename.task;
191 parsed_name->has_task = true;
192 }
193 if (!parsed_name->has_type) {
194 parsed_name->type = parsed_basename.type;
195 parsed_name->has_type = true;
196 }
197 if (!parsed_name->has_id) {
198 parsed_name->id = parsed_basename.id;
199 parsed_name->has_id = true;
200 }
201 }
202
203 } // namespace
204
205 /* static */
CanonicalizeDeviceName(StringPiece fullname,StringPiece basename,string * canonical_name)206 Status DeviceNameUtils::CanonicalizeDeviceName(StringPiece fullname,
207 StringPiece basename,
208 string* canonical_name) {
209 *canonical_name = "";
210 ParsedName parsed_basename;
211 if (!ParseFullName(basename, &parsed_basename)) {
212 return errors::InvalidArgument("Could not parse basename: ", basename,
213 " into a device specification.");
214 }
215 if (!(parsed_basename.has_job && parsed_basename.has_replica &&
216 parsed_basename.has_task && parsed_basename.has_type &&
217 parsed_basename.has_id)) {
218 return errors::InvalidArgument("Basename: ", basename,
219 " should be fully "
220 "specified.");
221 }
222 ParsedName parsed_name;
223 if (ParseLocalName(fullname, &parsed_name)) {
224 CompleteName(parsed_basename, &parsed_name);
225 *canonical_name = ParsedNameToString(parsed_name);
226 return Status::OK();
227 }
228 if (ParseFullName(fullname, &parsed_name)) {
229 CompleteName(parsed_basename, &parsed_name);
230 *canonical_name = ParsedNameToString(parsed_name);
231 return Status::OK();
232 }
233 return errors::InvalidArgument("Could not parse ", fullname,
234 " into a device "
235 "specification.");
236 }
237
238 /* static */
ParsedNameToString(const ParsedName & pn)239 string DeviceNameUtils::ParsedNameToString(const ParsedName& pn) {
240 string buf;
241 if (pn.has_job) strings::StrAppend(&buf, "/job:", pn.job);
242 if (pn.has_replica) strings::StrAppend(&buf, "/replica:", pn.replica);
243 if (pn.has_task) strings::StrAppend(&buf, "/task:", pn.task);
244 if (pn.has_type) {
245 strings::StrAppend(&buf, "/device:", pn.type, ":");
246 if (pn.has_id) {
247 strings::StrAppend(&buf, pn.id);
248 } else {
249 strings::StrAppend(&buf, "*");
250 }
251 }
252 return buf;
253 }
254
255 /* static */
IsSpecification(const ParsedName & less_specific,const ParsedName & more_specific)256 bool DeviceNameUtils::IsSpecification(const ParsedName& less_specific,
257 const ParsedName& more_specific) {
258 if (less_specific.has_job &&
259 (!more_specific.has_job || (less_specific.job != more_specific.job))) {
260 return false;
261 }
262 if (less_specific.has_replica &&
263 (!more_specific.has_replica ||
264 (less_specific.replica != more_specific.replica))) {
265 return false;
266 }
267 if (less_specific.has_task &&
268 (!more_specific.has_task || (less_specific.task != more_specific.task))) {
269 return false;
270 }
271 if (less_specific.has_type &&
272 (!more_specific.has_type || (less_specific.type != more_specific.type))) {
273 return false;
274 }
275 if (less_specific.has_id &&
276 (!more_specific.has_id || (less_specific.id != more_specific.id))) {
277 return false;
278 }
279 return true;
280 }
281
282 /* static */
AreCompatibleDevNames(const ParsedName & a,const ParsedName & b)283 bool DeviceNameUtils::AreCompatibleDevNames(const ParsedName& a,
284 const ParsedName& b) {
285 if (a.has_job && b.has_job && (a.job != b.job)) {
286 return false;
287 }
288 if (a.has_replica && b.has_replica && (a.replica != b.replica)) {
289 return false;
290 }
291 if (a.has_task && b.has_task && (a.task != b.task)) {
292 return false;
293 }
294 if (a.has_type && b.has_type && (a.type != b.type)) {
295 return false;
296 }
297 if (a.has_id && b.has_id && (a.id != b.id)) {
298 return false;
299 }
300 return true;
301 }
302
EnsureSpecification(ParsedName * more_specific,const ParsedName & less_specific)303 void DeviceNameUtils::EnsureSpecification(ParsedName* more_specific,
304 const ParsedName& less_specific) {
305 if (less_specific.has_job) {
306 more_specific->has_job = true;
307 more_specific->job = less_specific.job;
308 }
309 if (less_specific.has_replica) {
310 more_specific->has_replica = true;
311 more_specific->replica = less_specific.replica;
312 }
313 if (less_specific.has_task) {
314 more_specific->has_task = true;
315 more_specific->task = less_specific.task;
316 }
317 if (less_specific.has_type) {
318 more_specific->has_type = true;
319 more_specific->type = less_specific.type;
320 }
321 if (less_specific.has_id) {
322 more_specific->has_id = true;
323 more_specific->id = less_specific.id;
324 }
325 }
326
327 /* static */
IsCompleteSpecification(const ParsedName & pattern,const ParsedName & name)328 bool DeviceNameUtils::IsCompleteSpecification(const ParsedName& pattern,
329 const ParsedName& name) {
330 CHECK(name.has_job && name.has_replica && name.has_task && name.has_type &&
331 name.has_id);
332
333 if (pattern.has_job && (pattern.job != name.job)) return false;
334 if (pattern.has_replica && (pattern.replica != name.replica)) return false;
335 if (pattern.has_task && (pattern.task != name.task)) return false;
336 if (pattern.has_type && (pattern.type != name.type)) return false;
337 if (pattern.has_id && (pattern.id != name.id)) return false;
338 return true;
339 }
340
341 namespace {
MergeDevNamesImpl(DeviceNameUtils::ParsedName * target,const DeviceNameUtils::ParsedName & other,bool allow_soft_placement,bool override_conflicts)342 Status MergeDevNamesImpl(DeviceNameUtils::ParsedName* target,
343 const DeviceNameUtils::ParsedName& other,
344 bool allow_soft_placement, bool override_conflicts) {
345 const auto& ParsedNameToString = DeviceNameUtils::ParsedNameToString;
346 if (other.has_job) {
347 if (target->has_job && target->job != other.job) {
348 return errors::InvalidArgument(
349 "Cannot merge devices with incompatible jobs: '",
350 ParsedNameToString(*target), "' and '", ParsedNameToString(other),
351 "'");
352 } else {
353 target->has_job = other.has_job;
354 target->job = other.job;
355 }
356 }
357
358 if (other.has_replica) {
359 if (target->has_replica && target->replica != other.replica) {
360 return errors::InvalidArgument(
361 "Cannot merge devices with incompatible replicas: '",
362 ParsedNameToString(*target), "' and '", ParsedNameToString(other),
363 "'");
364 } else {
365 target->has_replica = other.has_replica;
366 target->replica = other.replica;
367 }
368 }
369
370 if (other.has_task) {
371 if (target->has_task && target->task != other.task) {
372 return errors::InvalidArgument(
373 "Cannot merge devices with incompatible tasks: '",
374 ParsedNameToString(*target), "' and '", ParsedNameToString(other),
375 "'");
376 } else {
377 target->has_task = other.has_task;
378 target->task = other.task;
379 }
380 }
381
382 if (other.has_type) {
383 if (target->has_type && target->type != other.type) {
384 if (!allow_soft_placement) {
385 return errors::InvalidArgument(
386 "Cannot merge devices with incompatible types: '",
387 ParsedNameToString(*target), "' and '", ParsedNameToString(other),
388 "'");
389 } else if (override_conflicts) {
390 target->type = other.type;
391 } else {
392 target->has_id = false;
393 target->has_type = false;
394 return Status::OK();
395 }
396 } else {
397 target->has_type = other.has_type;
398 target->type = other.type;
399 }
400 }
401
402 if (other.has_id) {
403 if (target->has_id && target->id != other.id) {
404 if (!allow_soft_placement) {
405 return errors::InvalidArgument(
406 "Cannot merge devices with incompatible ids: '",
407 ParsedNameToString(*target), "' and '", ParsedNameToString(other),
408 "'");
409 } else if (override_conflicts) {
410 target->id = other.id;
411 } else {
412 target->has_id = false;
413 return Status::OK();
414 }
415 } else {
416 target->has_id = other.has_id;
417 target->id = other.id;
418 }
419 }
420
421 return Status::OK();
422 }
423
424 } // namespace
425
426 /* static */
MergeDevNames(ParsedName * target,const ParsedName & other,bool allow_soft_placement)427 Status DeviceNameUtils::MergeDevNames(ParsedName* target,
428 const ParsedName& other,
429 bool allow_soft_placement) {
430 return MergeDevNamesImpl(target, other, allow_soft_placement,
431 /*override_conflicts=*/false);
432 }
433
434 /* static */
MergeOverrideDevNames(ParsedName * target,const ParsedName & other)435 Status DeviceNameUtils::MergeOverrideDevNames(ParsedName* target,
436 const ParsedName& other) {
437 return MergeDevNamesImpl(target, other, /*allow_soft_placement=*/true,
438 /*override_conflicts=*/true);
439 }
440
441 /* static */
IsSameAddressSpace(const ParsedName & a,const ParsedName & b)442 bool DeviceNameUtils::IsSameAddressSpace(const ParsedName& a,
443 const ParsedName& b) {
444 return (a.has_job && b.has_job && (a.job == b.job)) &&
445 (a.has_replica && b.has_replica && (a.replica == b.replica)) &&
446 (a.has_task && b.has_task && (a.task == b.task));
447 }
448
449 /* static */
IsSameAddressSpace(StringPiece src,StringPiece dst)450 bool DeviceNameUtils::IsSameAddressSpace(StringPiece src, StringPiece dst) {
451 ParsedName x;
452 ParsedName y;
453 return ParseFullName(src, &x) && ParseFullName(dst, &y) &&
454 IsSameAddressSpace(x, y);
455 }
456
457 /* static */
IsDifferentAddressSpace(const ParsedName & a,const ParsedName & b)458 bool DeviceNameUtils::IsDifferentAddressSpace(const ParsedName& a,
459 const ParsedName& b) {
460 return (a.has_job && b.has_job && (a.job != b.job)) ||
461 (a.has_replica && b.has_replica && (a.replica != b.replica)) ||
462 (a.has_task && b.has_task && (a.task != b.task));
463 }
464
465 /* static */
AddressSpace(const ParsedName & name)466 const DeviceNameUtils::ParsedName DeviceNameUtils::AddressSpace(
467 const ParsedName& name) {
468 ParsedName address_space;
469 address_space.has_job = name.has_job;
470 address_space.has_replica = name.has_replica;
471 address_space.has_task = name.has_task;
472 address_space.job = name.job;
473 address_space.replica = name.replica;
474 address_space.task = name.task;
475 return address_space;
476 }
477
478 /* static */
LocalName(StringPiece type,int id)479 string DeviceNameUtils::LocalName(StringPiece type, int id) {
480 return strings::StrCat("/device:", type, ":", id);
481 }
482
483 namespace {
484 // Returns the legacy local device name given its "type" and "id" (which is
485 // '/device:type:id').
LegacyLocalName(StringPiece type,int id)486 string LegacyLocalName(StringPiece type, int id) {
487 return strings::StrCat(type, ":", id);
488 }
489 } // anonymous namespace
490
491 /* static */
LocalName(StringPiece fullname)492 string DeviceNameUtils::LocalName(StringPiece fullname) {
493 ParsedName x;
494 CHECK(ParseFullName(fullname, &x)) << fullname;
495 return LocalName(x.type, x.id);
496 }
497
498 /* static */
ParseLocalName(StringPiece name,ParsedName * p)499 bool DeviceNameUtils::ParseLocalName(StringPiece name, ParsedName* p) {
500 if (!ConsumeDeviceType(&name, &p->type)) {
501 return false;
502 }
503 p->has_type = true;
504 if (!absl::ConsumePrefix(&name, ":")) {
505 return false;
506 }
507 if (!ConsumeNumber(&name, &p->id)) {
508 return false;
509 }
510 p->has_id = true;
511 return name.empty();
512 }
513
514 /* static */
SplitDeviceName(StringPiece name,string * task,string * device)515 bool DeviceNameUtils::SplitDeviceName(StringPiece name, string* task,
516 string* device) {
517 ParsedName pn;
518 if (ParseFullName(name, &pn) && pn.has_type && pn.has_id) {
519 task->clear();
520 task->reserve(
521 (pn.has_job ? (5 + pn.job.size()) : 0) +
522 (pn.has_replica ? (9 + 4 /*estimated UB for # replica digits*/) : 0) +
523 (pn.has_task ? (6 + 4 /*estimated UB for # task digits*/) : 0));
524 if (pn.has_job) {
525 strings::StrAppend(task, "/job:", pn.job);
526 }
527 if (pn.has_replica) {
528 strings::StrAppend(task, "/replica:", pn.replica);
529 }
530 if (pn.has_task) {
531 strings::StrAppend(task, "/task:", pn.task);
532 }
533 device->clear();
534 strings::StrAppend(device, pn.type, ":", pn.id);
535 return true;
536 }
537 return false;
538 }
539
540 /* static */
GetTaskName(const ParsedName & pn,string * task)541 bool DeviceNameUtils::GetTaskName(const ParsedName& pn, string* task) {
542 if (pn.has_job && pn.has_replica && pn.has_task) {
543 task->clear();
544 task->reserve((5 + pn.job.size()) +
545 (9 + 4 /*estimated UB for # replica digits*/) +
546 (6 + 4 /*estimated UB for # task digits*/));
547 strings::StrAppend(task, "/job:", pn.job);
548 strings::StrAppend(task, "/replica:", pn.replica);
549 strings::StrAppend(task, "/task:", pn.task);
550 return true;
551 }
552 return false;
553 }
554
GetNamesForDeviceMappings(const ParsedName & pn)555 std::vector<string> DeviceNameUtils::GetNamesForDeviceMappings(
556 const ParsedName& pn) {
557 if (pn.has_job && pn.has_replica && pn.has_task && pn.has_type && pn.has_id) {
558 return {
559 DeviceNameUtils::FullName(pn.job, pn.replica, pn.task, pn.type, pn.id),
560 LegacyName(pn.job, pn.replica, pn.task, pn.type, pn.id)};
561 } else {
562 return {};
563 }
564 }
565
GetLocalNamesForDeviceMappings(const ParsedName & pn)566 std::vector<string> DeviceNameUtils::GetLocalNamesForDeviceMappings(
567 const ParsedName& pn) {
568 if (pn.has_type && pn.has_id) {
569 return {DeviceNameUtils::LocalName(pn.type, pn.id),
570 LegacyLocalName(pn.type, pn.id)};
571 } else {
572 return {};
573 }
574 }
575
DeviceNameToCpuDeviceName(const string & device_name,string * host_device_name)576 /*static*/ Status DeviceNameUtils::DeviceNameToCpuDeviceName(
577 const string& device_name, string* host_device_name) {
578 DeviceNameUtils::ParsedName device;
579 if (!DeviceNameUtils::ParseFullName(device_name, &device)) {
580 return errors::Internal("Could not parse device name ", device_name);
581 }
582 device.type = "CPU";
583 device.id = 0;
584 *host_device_name = DeviceNameUtils::ParsedNameToString(device);
585 return Status::OK();
586 }
587
operator <<(std::ostream & os,const DeviceNameUtils::ParsedName & x)588 std::ostream& operator<<(std::ostream& os,
589 const DeviceNameUtils::ParsedName& x) {
590 os << DeviceNameUtils::ParsedNameToString(x);
591 return os;
592 }
593
594 } // namespace tensorflow
595