• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/core/util/device_name_utils.h"

#include <limits>

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
22 
23 namespace tensorflow {
24 
// Returns true iff `c` is an ASCII letter ('a'-'z' or 'A'-'Z'). Unlike
// std::isalpha this does not consult the current C locale.
static bool IsAlpha(char c) {
  const bool is_lower = 'a' <= c && c <= 'z';
  const bool is_upper = 'A' <= c && c <= 'Z';
  return is_lower || is_upper;
}
28 
IsAlphaNum(char c)29 static bool IsAlphaNum(char c) { return IsAlpha(c) || (c >= '0' && c <= '9'); }
30 
31 // Returns true iff "in" is a valid job name.
IsJobName(StringPiece in)32 static bool IsJobName(StringPiece in) {
33   if (in.empty()) return false;
34   if (!IsAlpha(in[0])) return false;
35   for (size_t i = 1; i < in.size(); ++i) {
36     if (!(IsAlphaNum(in[i]) || in[i] == '_')) return false;
37   }
38   return true;
39 }
40 
41 // Returns true and fills in "*job" iff "*in" starts with a job name.
ConsumeJobName(StringPiece * in,string * job)42 static bool ConsumeJobName(StringPiece* in, string* job) {
43   if (in->empty()) return false;
44   if (!IsAlpha((*in)[0])) return false;
45   size_t i = 1;
46   for (; i < in->size(); ++i) {
47     const char c = (*in)[i];
48     if (c == '/') break;
49     if (!(IsAlphaNum(c) || c == '_')) {
50       return false;
51     }
52   }
53   job->assign(in->data(), i);
54   in->remove_prefix(i);
55   return true;
56 }
57 
58 // Returns true and fills in "*device_type" iff "*in" starts with a device type
59 // name.
ConsumeDeviceType(StringPiece * in,string * device_type)60 static bool ConsumeDeviceType(StringPiece* in, string* device_type) {
61   if (in->empty()) return false;
62   if (!IsAlpha((*in)[0])) return false;
63   size_t i = 1;
64   for (; i < in->size(); ++i) {
65     const char c = (*in)[i];
66     if (c == '/' || c == ':') break;
67     if (!(IsAlphaNum(c) || c == '_')) {
68       return false;
69     }
70   }
71   device_type->assign(in->data(), i);
72   in->remove_prefix(i);
73   return true;
74 }
75 
76 // Returns true and fills in "*val" iff "*in" starts with a decimal
77 // number.
ConsumeNumber(StringPiece * in,int * val)78 static bool ConsumeNumber(StringPiece* in, int* val) {
79   uint64 tmp;
80   if (str_util::ConsumeLeadingDigits(in, &tmp)) {
81     *val = tmp;
82     return true;
83   } else {
84     return false;
85   }
86 }
87 
88 // Returns a fully qualified device name given the parameters.
DeviceName(const string & job,int replica,int task,const string & device_prefix,const string & device_type,int id)89 static string DeviceName(const string& job, int replica, int task,
90                          const string& device_prefix, const string& device_type,
91                          int id) {
92   CHECK(IsJobName(job)) << job;
93   CHECK_LE(0, replica);
94   CHECK_LE(0, task);
95   CHECK(!device_type.empty());
96   CHECK_LE(0, id);
97   return strings::StrCat("/job:", job, "/replica:", replica, "/task:", task,
98                          device_prefix, device_type, ":", id);
99 }
100 
101 /* static */
FullName(const string & job,int replica,int task,const string & type,int id)102 string DeviceNameUtils::FullName(const string& job, int replica, int task,
103                                  const string& type, int id) {
104   return DeviceName(job, replica, task, "/device:", type, id);
105 }
106 
107 namespace {
LegacyName(const string & job,int replica,int task,const string & type,int id)108 string LegacyName(const string& job, int replica, int task, const string& type,
109                   int id) {
110   return DeviceName(job, replica, task, "/", str_util::Lowercase(type), id);
111 }
112 }  // anonymous namespace
113 
ParseFullName(StringPiece fullname,ParsedName * p)114 bool DeviceNameUtils::ParseFullName(StringPiece fullname, ParsedName* p) {
115   p->Clear();
116   if (fullname == "/") {
117     return true;
118   }
119   while (!fullname.empty()) {
120     bool progress = false;
121     if (str_util::ConsumePrefix(&fullname, "/job:")) {
122       p->has_job = !str_util::ConsumePrefix(&fullname, "*");
123       if (p->has_job && !ConsumeJobName(&fullname, &p->job)) {
124         return false;
125       }
126       progress = true;
127     }
128     if (str_util::ConsumePrefix(&fullname, "/replica:")) {
129       p->has_replica = !str_util::ConsumePrefix(&fullname, "*");
130       if (p->has_replica && !ConsumeNumber(&fullname, &p->replica)) {
131         return false;
132       }
133       progress = true;
134     }
135     if (str_util::ConsumePrefix(&fullname, "/task:")) {
136       p->has_task = !str_util::ConsumePrefix(&fullname, "*");
137       if (p->has_task && !ConsumeNumber(&fullname, &p->task)) {
138         return false;
139       }
140       progress = true;
141     }
142     if (str_util::ConsumePrefix(&fullname, "/device:")) {
143       p->has_type = !str_util::ConsumePrefix(&fullname, "*");
144       if (p->has_type && !ConsumeDeviceType(&fullname, &p->type)) {
145         return false;
146       }
147       if (!str_util::ConsumePrefix(&fullname, ":")) {
148         p->has_id = false;
149       } else {
150         p->has_id = !str_util::ConsumePrefix(&fullname, "*");
151         if (p->has_id && !ConsumeNumber(&fullname, &p->id)) {
152           return false;
153         }
154       }
155       progress = true;
156     }
157 
158     // Handle legacy naming convention for cpu and gpu.
159     if (str_util::ConsumePrefix(&fullname, "/cpu:") ||
160         str_util::ConsumePrefix(&fullname, "/CPU:")) {
161       p->has_type = true;
162       p->type = "CPU";  // Treat '/cpu:..' as uppercase '/device:CPU:...'
163       p->has_id = !str_util::ConsumePrefix(&fullname, "*");
164       if (p->has_id && !ConsumeNumber(&fullname, &p->id)) {
165         return false;
166       }
167       progress = true;
168     }
169     if (str_util::ConsumePrefix(&fullname, "/gpu:") ||
170         str_util::ConsumePrefix(&fullname, "/GPU:")) {
171       p->has_type = true;
172       p->type = "GPU";  // Treat '/gpu:..' as uppercase '/device:GPU:...'
173       p->has_id = !str_util::ConsumePrefix(&fullname, "*");
174       if (p->has_id && !ConsumeNumber(&fullname, &p->id)) {
175         return false;
176       }
177       progress = true;
178     }
179 
180     if (!progress) {
181       return false;
182     }
183   }
184   return true;
185 }
186 
187 /* static */
CanonicalizeDeviceName(StringPiece fullname)188 string DeviceNameUtils::CanonicalizeDeviceName(StringPiece fullname) {
189   ParsedName parsed_name;
190   if (ParseLocalName(fullname, &parsed_name)) {
191     return ParsedNameToString(parsed_name);
192   }
193   if (ParseFullName(fullname, &parsed_name)) {
194     return ParsedNameToString(parsed_name);
195   }
196   return "";
197 }
198 
199 /* static */
ParsedNameToString(const ParsedName & pn)200 string DeviceNameUtils::ParsedNameToString(const ParsedName& pn) {
201   string buf;
202   if (pn.has_job) strings::StrAppend(&buf, "/job:", pn.job);
203   if (pn.has_replica) strings::StrAppend(&buf, "/replica:", pn.replica);
204   if (pn.has_task) strings::StrAppend(&buf, "/task:", pn.task);
205   if (pn.has_type) {
206     strings::StrAppend(&buf, "/device:", pn.type, ":");
207     if (pn.has_id) {
208       strings::StrAppend(&buf, pn.id);
209     } else {
210       strings::StrAppend(&buf, "*");
211     }
212   }
213   return buf;
214 }
215 
216 /* static */
IsSpecification(const ParsedName & less_specific,const ParsedName & more_specific)217 bool DeviceNameUtils::IsSpecification(const ParsedName& less_specific,
218                                       const ParsedName& more_specific) {
219   if (less_specific.has_job &&
220       (!more_specific.has_job || (less_specific.job != more_specific.job))) {
221     return false;
222   }
223   if (less_specific.has_replica &&
224       (!more_specific.has_replica ||
225        (less_specific.replica != more_specific.replica))) {
226     return false;
227   }
228   if (less_specific.has_task &&
229       (!more_specific.has_task || (less_specific.task != more_specific.task))) {
230     return false;
231   }
232   if (less_specific.has_type &&
233       (!more_specific.has_type || (less_specific.type != more_specific.type))) {
234     return false;
235   }
236   if (less_specific.has_id &&
237       (!more_specific.has_id || (less_specific.id != more_specific.id))) {
238     return false;
239   }
240   return true;
241 }
242 
243 /* static */
IsCompleteSpecification(const ParsedName & pattern,const ParsedName & name)244 bool DeviceNameUtils::IsCompleteSpecification(const ParsedName& pattern,
245                                               const ParsedName& name) {
246   CHECK(name.has_job && name.has_replica && name.has_task && name.has_type &&
247         name.has_id);
248 
249   if (pattern.has_job && (pattern.job != name.job)) return false;
250   if (pattern.has_replica && (pattern.replica != name.replica)) return false;
251   if (pattern.has_task && (pattern.task != name.task)) return false;
252   if (pattern.has_type && (pattern.type != name.type)) return false;
253   if (pattern.has_id && (pattern.id != name.id)) return false;
254   return true;
255 }
256 
257 /* static */
MergeDevNames(ParsedName * target,const ParsedName & other,bool allow_soft_placement)258 Status DeviceNameUtils::MergeDevNames(ParsedName* target,
259                                       const ParsedName& other,
260                                       bool allow_soft_placement) {
261   if (other.has_job) {
262     if (target->has_job && target->job != other.job) {
263       return errors::InvalidArgument(
264           "Cannot merge devices with incompatible jobs: '",
265           ParsedNameToString(*target), "' and '", ParsedNameToString(other),
266           "'");
267     } else {
268       target->has_job = other.has_job;
269       target->job = other.job;
270     }
271   }
272 
273   if (other.has_replica) {
274     if (target->has_replica && target->replica != other.replica) {
275       return errors::InvalidArgument(
276           "Cannot merge devices with incompatible replicas: '",
277           ParsedNameToString(*target), "' and '", ParsedNameToString(other),
278           "'");
279     } else {
280       target->has_replica = other.has_replica;
281       target->replica = other.replica;
282     }
283   }
284 
285   if (other.has_task) {
286     if (target->has_task && target->task != other.task) {
287       return errors::InvalidArgument(
288           "Cannot merge devices with incompatible tasks: '",
289           ParsedNameToString(*target), "' and '", ParsedNameToString(other),
290           "'");
291     } else {
292       target->has_task = other.has_task;
293       target->task = other.task;
294     }
295   }
296 
297   if (other.has_type) {
298     if (target->has_type && target->type != other.type) {
299       if (!allow_soft_placement) {
300         return errors::InvalidArgument(
301             "Cannot merge devices with incompatible types: '",
302             ParsedNameToString(*target), "' and '", ParsedNameToString(other),
303             "'");
304       } else {
305         target->has_id = false;
306         target->has_type = false;
307         return Status::OK();
308       }
309     } else {
310       target->has_type = other.has_type;
311       target->type = other.type;
312     }
313   }
314 
315   if (other.has_id) {
316     if (target->has_id && target->id != other.id) {
317       if (!allow_soft_placement) {
318         return errors::InvalidArgument(
319             "Cannot merge devices with incompatible ids: '",
320             ParsedNameToString(*target), "' and '", ParsedNameToString(other),
321             "'");
322       } else {
323         target->has_id = false;
324         return Status::OK();
325       }
326     } else {
327       target->has_id = other.has_id;
328       target->id = other.id;
329     }
330   }
331 
332   return Status::OK();
333 }
334 
335 /* static */
IsSameAddressSpace(const ParsedName & a,const ParsedName & b)336 bool DeviceNameUtils::IsSameAddressSpace(const ParsedName& a,
337                                          const ParsedName& b) {
338   return (a.has_job && b.has_job && (a.job == b.job)) &&
339          (a.has_replica && b.has_replica && (a.replica == b.replica)) &&
340          (a.has_task && b.has_task && (a.task == b.task));
341 }
342 
343 /* static */
IsSameAddressSpace(StringPiece src,StringPiece dst)344 bool DeviceNameUtils::IsSameAddressSpace(StringPiece src, StringPiece dst) {
345   ParsedName x;
346   ParsedName y;
347   return ParseFullName(src, &x) && ParseFullName(dst, &y) &&
348          IsSameAddressSpace(x, y);
349 }
350 
351 /* static */
LocalName(StringPiece type,int id)352 string DeviceNameUtils::LocalName(StringPiece type, int id) {
353   return strings::StrCat("/device:", type, ":", id);
354 }
355 
356 namespace {
357 // Returns the legacy local device name given its "type" and "id" (which is
358 // '/device:type:id').
LegacyLocalName(StringPiece type,int id)359 string LegacyLocalName(StringPiece type, int id) {
360   return strings::StrCat(type, ":", id);
361 }
362 }  // anonymous namespace
363 
364 /* static */
LocalName(StringPiece fullname)365 string DeviceNameUtils::LocalName(StringPiece fullname) {
366   ParsedName x;
367   CHECK(ParseFullName(fullname, &x)) << fullname;
368   return LocalName(x.type, x.id);
369 }
370 
371 /* static */
ParseLocalName(StringPiece name,ParsedName * p)372 bool DeviceNameUtils::ParseLocalName(StringPiece name, ParsedName* p) {
373   if (!ConsumeDeviceType(&name, &p->type)) {
374     return false;
375   }
376   p->has_type = true;
377   if (!str_util::ConsumePrefix(&name, ":")) {
378     return false;
379   }
380   if (!ConsumeNumber(&name, &p->id)) {
381     return false;
382   }
383   p->has_id = true;
384   return name.empty();
385 }
386 
387 /* static */
SplitDeviceName(StringPiece name,string * task,string * device)388 bool DeviceNameUtils::SplitDeviceName(StringPiece name, string* task,
389                                       string* device) {
390   ParsedName pn;
391   if (ParseFullName(name, &pn) && pn.has_type && pn.has_id) {
392     task->clear();
393     task->reserve(
394         (pn.has_job ? (5 + pn.job.size()) : 0) +
395         (pn.has_replica ? (9 + 4 /*estimated UB for # replica digits*/) : 0) +
396         (pn.has_task ? (6 + 4 /*estimated UB for # task digits*/) : 0));
397     if (pn.has_job) {
398       strings::StrAppend(task, "/job:", pn.job);
399     }
400     if (pn.has_replica) {
401       strings::StrAppend(task, "/replica:", pn.replica);
402     }
403     if (pn.has_task) {
404       strings::StrAppend(task, "/task:", pn.task);
405     }
406     device->clear();
407     strings::StrAppend(device, pn.type, ":", pn.id);
408     return true;
409   }
410   return false;
411 }
412 
GetNamesForDeviceMappings(const ParsedName & pn)413 std::vector<string> DeviceNameUtils::GetNamesForDeviceMappings(
414     const ParsedName& pn) {
415   if (pn.has_job && pn.has_replica && pn.has_task && pn.has_type && pn.has_id) {
416     return {
417         DeviceNameUtils::FullName(pn.job, pn.replica, pn.task, pn.type, pn.id),
418         LegacyName(pn.job, pn.replica, pn.task, pn.type, pn.id)};
419   } else {
420     return {};
421   }
422 }
423 
GetLocalNamesForDeviceMappings(const ParsedName & pn)424 std::vector<string> DeviceNameUtils::GetLocalNamesForDeviceMappings(
425     const ParsedName& pn) {
426   if (pn.has_type && pn.has_id) {
427     return {DeviceNameUtils::LocalName(pn.type, pn.id),
428             LegacyLocalName(pn.type, pn.id)};
429   } else {
430     return {};
431   }
432 }
433 
434 }  // namespace tensorflow
435