1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/core/util/device_name_utils.h"

#include <limits>

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
22
23 namespace tensorflow {
24
// Returns true iff `c` is an ASCII letter ('a'-'z' or 'A'-'Z'). Uses plain
// range comparisons, so the result does not depend on the current locale.
static bool IsAlpha(char c) {
  const bool is_lower = ('a' <= c) && (c <= 'z');
  const bool is_upper = ('A' <= c) && (c <= 'Z');
  return is_lower || is_upper;
}
28
IsAlphaNum(char c)29 static bool IsAlphaNum(char c) { return IsAlpha(c) || (c >= '0' && c <= '9'); }
30
31 // Returns true iff "in" is a valid job name.
IsJobName(StringPiece in)32 static bool IsJobName(StringPiece in) {
33 if (in.empty()) return false;
34 if (!IsAlpha(in[0])) return false;
35 for (size_t i = 1; i < in.size(); ++i) {
36 if (!(IsAlphaNum(in[i]) || in[i] == '_')) return false;
37 }
38 return true;
39 }
40
41 // Returns true and fills in "*job" iff "*in" starts with a job name.
ConsumeJobName(StringPiece * in,string * job)42 static bool ConsumeJobName(StringPiece* in, string* job) {
43 if (in->empty()) return false;
44 if (!IsAlpha((*in)[0])) return false;
45 size_t i = 1;
46 for (; i < in->size(); ++i) {
47 const char c = (*in)[i];
48 if (c == '/') break;
49 if (!(IsAlphaNum(c) || c == '_')) {
50 return false;
51 }
52 }
53 job->assign(in->data(), i);
54 in->remove_prefix(i);
55 return true;
56 }
57
58 // Returns true and fills in "*device_type" iff "*in" starts with a device type
59 // name.
ConsumeDeviceType(StringPiece * in,string * device_type)60 static bool ConsumeDeviceType(StringPiece* in, string* device_type) {
61 if (in->empty()) return false;
62 if (!IsAlpha((*in)[0])) return false;
63 size_t i = 1;
64 for (; i < in->size(); ++i) {
65 const char c = (*in)[i];
66 if (c == '/' || c == ':') break;
67 if (!(IsAlphaNum(c) || c == '_')) {
68 return false;
69 }
70 }
71 device_type->assign(in->data(), i);
72 in->remove_prefix(i);
73 return true;
74 }
75
76 // Returns true and fills in "*val" iff "*in" starts with a decimal
77 // number.
ConsumeNumber(StringPiece * in,int * val)78 static bool ConsumeNumber(StringPiece* in, int* val) {
79 uint64 tmp;
80 if (str_util::ConsumeLeadingDigits(in, &tmp)) {
81 *val = tmp;
82 return true;
83 } else {
84 return false;
85 }
86 }
87
88 // Returns a fully qualified device name given the parameters.
DeviceName(const string & job,int replica,int task,const string & device_prefix,const string & device_type,int id)89 static string DeviceName(const string& job, int replica, int task,
90 const string& device_prefix, const string& device_type,
91 int id) {
92 CHECK(IsJobName(job)) << job;
93 CHECK_LE(0, replica);
94 CHECK_LE(0, task);
95 CHECK(!device_type.empty());
96 CHECK_LE(0, id);
97 return strings::StrCat("/job:", job, "/replica:", replica, "/task:", task,
98 device_prefix, device_type, ":", id);
99 }
100
101 /* static */
FullName(const string & job,int replica,int task,const string & type,int id)102 string DeviceNameUtils::FullName(const string& job, int replica, int task,
103 const string& type, int id) {
104 return DeviceName(job, replica, task, "/device:", type, id);
105 }
106
107 namespace {
LegacyName(const string & job,int replica,int task,const string & type,int id)108 string LegacyName(const string& job, int replica, int task, const string& type,
109 int id) {
110 return DeviceName(job, replica, task, "/", str_util::Lowercase(type), id);
111 }
112 } // anonymous namespace
113
ParseFullName(StringPiece fullname,ParsedName * p)114 bool DeviceNameUtils::ParseFullName(StringPiece fullname, ParsedName* p) {
115 p->Clear();
116 if (fullname == "/") {
117 return true;
118 }
119 while (!fullname.empty()) {
120 bool progress = false;
121 if (str_util::ConsumePrefix(&fullname, "/job:")) {
122 p->has_job = !str_util::ConsumePrefix(&fullname, "*");
123 if (p->has_job && !ConsumeJobName(&fullname, &p->job)) {
124 return false;
125 }
126 progress = true;
127 }
128 if (str_util::ConsumePrefix(&fullname, "/replica:")) {
129 p->has_replica = !str_util::ConsumePrefix(&fullname, "*");
130 if (p->has_replica && !ConsumeNumber(&fullname, &p->replica)) {
131 return false;
132 }
133 progress = true;
134 }
135 if (str_util::ConsumePrefix(&fullname, "/task:")) {
136 p->has_task = !str_util::ConsumePrefix(&fullname, "*");
137 if (p->has_task && !ConsumeNumber(&fullname, &p->task)) {
138 return false;
139 }
140 progress = true;
141 }
142 if (str_util::ConsumePrefix(&fullname, "/device:")) {
143 p->has_type = !str_util::ConsumePrefix(&fullname, "*");
144 if (p->has_type && !ConsumeDeviceType(&fullname, &p->type)) {
145 return false;
146 }
147 if (!str_util::ConsumePrefix(&fullname, ":")) {
148 p->has_id = false;
149 } else {
150 p->has_id = !str_util::ConsumePrefix(&fullname, "*");
151 if (p->has_id && !ConsumeNumber(&fullname, &p->id)) {
152 return false;
153 }
154 }
155 progress = true;
156 }
157
158 // Handle legacy naming convention for cpu and gpu.
159 if (str_util::ConsumePrefix(&fullname, "/cpu:") ||
160 str_util::ConsumePrefix(&fullname, "/CPU:")) {
161 p->has_type = true;
162 p->type = "CPU"; // Treat '/cpu:..' as uppercase '/device:CPU:...'
163 p->has_id = !str_util::ConsumePrefix(&fullname, "*");
164 if (p->has_id && !ConsumeNumber(&fullname, &p->id)) {
165 return false;
166 }
167 progress = true;
168 }
169 if (str_util::ConsumePrefix(&fullname, "/gpu:") ||
170 str_util::ConsumePrefix(&fullname, "/GPU:")) {
171 p->has_type = true;
172 p->type = "GPU"; // Treat '/gpu:..' as uppercase '/device:GPU:...'
173 p->has_id = !str_util::ConsumePrefix(&fullname, "*");
174 if (p->has_id && !ConsumeNumber(&fullname, &p->id)) {
175 return false;
176 }
177 progress = true;
178 }
179
180 if (!progress) {
181 return false;
182 }
183 }
184 return true;
185 }
186
187 /* static */
CanonicalizeDeviceName(StringPiece fullname)188 string DeviceNameUtils::CanonicalizeDeviceName(StringPiece fullname) {
189 ParsedName parsed_name;
190 if (ParseLocalName(fullname, &parsed_name)) {
191 return ParsedNameToString(parsed_name);
192 }
193 if (ParseFullName(fullname, &parsed_name)) {
194 return ParsedNameToString(parsed_name);
195 }
196 return "";
197 }
198
199 /* static */
ParsedNameToString(const ParsedName & pn)200 string DeviceNameUtils::ParsedNameToString(const ParsedName& pn) {
201 string buf;
202 if (pn.has_job) strings::StrAppend(&buf, "/job:", pn.job);
203 if (pn.has_replica) strings::StrAppend(&buf, "/replica:", pn.replica);
204 if (pn.has_task) strings::StrAppend(&buf, "/task:", pn.task);
205 if (pn.has_type) {
206 strings::StrAppend(&buf, "/device:", pn.type, ":");
207 if (pn.has_id) {
208 strings::StrAppend(&buf, pn.id);
209 } else {
210 strings::StrAppend(&buf, "*");
211 }
212 }
213 return buf;
214 }
215
216 /* static */
IsSpecification(const ParsedName & less_specific,const ParsedName & more_specific)217 bool DeviceNameUtils::IsSpecification(const ParsedName& less_specific,
218 const ParsedName& more_specific) {
219 if (less_specific.has_job &&
220 (!more_specific.has_job || (less_specific.job != more_specific.job))) {
221 return false;
222 }
223 if (less_specific.has_replica &&
224 (!more_specific.has_replica ||
225 (less_specific.replica != more_specific.replica))) {
226 return false;
227 }
228 if (less_specific.has_task &&
229 (!more_specific.has_task || (less_specific.task != more_specific.task))) {
230 return false;
231 }
232 if (less_specific.has_type &&
233 (!more_specific.has_type || (less_specific.type != more_specific.type))) {
234 return false;
235 }
236 if (less_specific.has_id &&
237 (!more_specific.has_id || (less_specific.id != more_specific.id))) {
238 return false;
239 }
240 return true;
241 }
242
243 /* static */
IsCompleteSpecification(const ParsedName & pattern,const ParsedName & name)244 bool DeviceNameUtils::IsCompleteSpecification(const ParsedName& pattern,
245 const ParsedName& name) {
246 CHECK(name.has_job && name.has_replica && name.has_task && name.has_type &&
247 name.has_id);
248
249 if (pattern.has_job && (pattern.job != name.job)) return false;
250 if (pattern.has_replica && (pattern.replica != name.replica)) return false;
251 if (pattern.has_task && (pattern.task != name.task)) return false;
252 if (pattern.has_type && (pattern.type != name.type)) return false;
253 if (pattern.has_id && (pattern.id != name.id)) return false;
254 return true;
255 }
256
257 /* static */
MergeDevNames(ParsedName * target,const ParsedName & other,bool allow_soft_placement)258 Status DeviceNameUtils::MergeDevNames(ParsedName* target,
259 const ParsedName& other,
260 bool allow_soft_placement) {
261 if (other.has_job) {
262 if (target->has_job && target->job != other.job) {
263 return errors::InvalidArgument(
264 "Cannot merge devices with incompatible jobs: '",
265 ParsedNameToString(*target), "' and '", ParsedNameToString(other),
266 "'");
267 } else {
268 target->has_job = other.has_job;
269 target->job = other.job;
270 }
271 }
272
273 if (other.has_replica) {
274 if (target->has_replica && target->replica != other.replica) {
275 return errors::InvalidArgument(
276 "Cannot merge devices with incompatible replicas: '",
277 ParsedNameToString(*target), "' and '", ParsedNameToString(other),
278 "'");
279 } else {
280 target->has_replica = other.has_replica;
281 target->replica = other.replica;
282 }
283 }
284
285 if (other.has_task) {
286 if (target->has_task && target->task != other.task) {
287 return errors::InvalidArgument(
288 "Cannot merge devices with incompatible tasks: '",
289 ParsedNameToString(*target), "' and '", ParsedNameToString(other),
290 "'");
291 } else {
292 target->has_task = other.has_task;
293 target->task = other.task;
294 }
295 }
296
297 if (other.has_type) {
298 if (target->has_type && target->type != other.type) {
299 if (!allow_soft_placement) {
300 return errors::InvalidArgument(
301 "Cannot merge devices with incompatible types: '",
302 ParsedNameToString(*target), "' and '", ParsedNameToString(other),
303 "'");
304 } else {
305 target->has_id = false;
306 target->has_type = false;
307 return Status::OK();
308 }
309 } else {
310 target->has_type = other.has_type;
311 target->type = other.type;
312 }
313 }
314
315 if (other.has_id) {
316 if (target->has_id && target->id != other.id) {
317 if (!allow_soft_placement) {
318 return errors::InvalidArgument(
319 "Cannot merge devices with incompatible ids: '",
320 ParsedNameToString(*target), "' and '", ParsedNameToString(other),
321 "'");
322 } else {
323 target->has_id = false;
324 return Status::OK();
325 }
326 } else {
327 target->has_id = other.has_id;
328 target->id = other.id;
329 }
330 }
331
332 return Status::OK();
333 }
334
335 /* static */
IsSameAddressSpace(const ParsedName & a,const ParsedName & b)336 bool DeviceNameUtils::IsSameAddressSpace(const ParsedName& a,
337 const ParsedName& b) {
338 return (a.has_job && b.has_job && (a.job == b.job)) &&
339 (a.has_replica && b.has_replica && (a.replica == b.replica)) &&
340 (a.has_task && b.has_task && (a.task == b.task));
341 }
342
343 /* static */
IsSameAddressSpace(StringPiece src,StringPiece dst)344 bool DeviceNameUtils::IsSameAddressSpace(StringPiece src, StringPiece dst) {
345 ParsedName x;
346 ParsedName y;
347 return ParseFullName(src, &x) && ParseFullName(dst, &y) &&
348 IsSameAddressSpace(x, y);
349 }
350
351 /* static */
LocalName(StringPiece type,int id)352 string DeviceNameUtils::LocalName(StringPiece type, int id) {
353 return strings::StrCat("/device:", type, ":", id);
354 }
355
356 namespace {
357 // Returns the legacy local device name given its "type" and "id" (which is
358 // '/device:type:id').
LegacyLocalName(StringPiece type,int id)359 string LegacyLocalName(StringPiece type, int id) {
360 return strings::StrCat(type, ":", id);
361 }
362 } // anonymous namespace
363
364 /* static */
LocalName(StringPiece fullname)365 string DeviceNameUtils::LocalName(StringPiece fullname) {
366 ParsedName x;
367 CHECK(ParseFullName(fullname, &x)) << fullname;
368 return LocalName(x.type, x.id);
369 }
370
371 /* static */
ParseLocalName(StringPiece name,ParsedName * p)372 bool DeviceNameUtils::ParseLocalName(StringPiece name, ParsedName* p) {
373 if (!ConsumeDeviceType(&name, &p->type)) {
374 return false;
375 }
376 p->has_type = true;
377 if (!str_util::ConsumePrefix(&name, ":")) {
378 return false;
379 }
380 if (!ConsumeNumber(&name, &p->id)) {
381 return false;
382 }
383 p->has_id = true;
384 return name.empty();
385 }
386
387 /* static */
SplitDeviceName(StringPiece name,string * task,string * device)388 bool DeviceNameUtils::SplitDeviceName(StringPiece name, string* task,
389 string* device) {
390 ParsedName pn;
391 if (ParseFullName(name, &pn) && pn.has_type && pn.has_id) {
392 task->clear();
393 task->reserve(
394 (pn.has_job ? (5 + pn.job.size()) : 0) +
395 (pn.has_replica ? (9 + 4 /*estimated UB for # replica digits*/) : 0) +
396 (pn.has_task ? (6 + 4 /*estimated UB for # task digits*/) : 0));
397 if (pn.has_job) {
398 strings::StrAppend(task, "/job:", pn.job);
399 }
400 if (pn.has_replica) {
401 strings::StrAppend(task, "/replica:", pn.replica);
402 }
403 if (pn.has_task) {
404 strings::StrAppend(task, "/task:", pn.task);
405 }
406 device->clear();
407 strings::StrAppend(device, pn.type, ":", pn.id);
408 return true;
409 }
410 return false;
411 }
412
GetNamesForDeviceMappings(const ParsedName & pn)413 std::vector<string> DeviceNameUtils::GetNamesForDeviceMappings(
414 const ParsedName& pn) {
415 if (pn.has_job && pn.has_replica && pn.has_task && pn.has_type && pn.has_id) {
416 return {
417 DeviceNameUtils::FullName(pn.job, pn.replica, pn.task, pn.type, pn.id),
418 LegacyName(pn.job, pn.replica, pn.task, pn.type, pn.id)};
419 } else {
420 return {};
421 }
422 }
423
GetLocalNamesForDeviceMappings(const ParsedName & pn)424 std::vector<string> DeviceNameUtils::GetLocalNamesForDeviceMappings(
425 const ParsedName& pn) {
426 if (pn.has_type && pn.has_id) {
427 return {DeviceNameUtils::LocalName(pn.type, pn.id),
428 LegacyLocalName(pn.type, pn.id)};
429 } else {
430 return {};
431 }
432 }
433
434 } // namespace tensorflow
435