1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/node_def_util.h"
17
18 #include <algorithm>
19 #include <unordered_map>
20 #include <vector>
21
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/str_join.h"
24 #include "tensorflow/core/framework/attr_value.pb.h"
25 #include "tensorflow/core/framework/attr_value_util.h"
26 #include "tensorflow/core/framework/node_def.pb.h"
27 #include "tensorflow/core/framework/op_def.pb.h"
28 #include "tensorflow/core/framework/op_def_util.h"
29 #include "tensorflow/core/framework/tensor.h"
30 #include "tensorflow/core/framework/tensor.pb.h"
31 #include "tensorflow/core/framework/tensor_shape.h"
32 #include "tensorflow/core/framework/tensor_shape.pb.h"
33 #include "tensorflow/core/framework/types.h"
34 #include "tensorflow/core/framework/types.pb.h"
35 #include "tensorflow/core/lib/gtl/map_util.h"
36 #include "tensorflow/core/platform/errors.h"
37 #include "tensorflow/core/platform/scanner.h"
38 #include "tensorflow/core/platform/status.h"
39 #include "tensorflow/core/platform/strcat.h"
40 #include "tensorflow/core/platform/stringpiece.h"
41 #include "tensorflow/core/platform/types.h"
42
43 namespace tensorflow {
44
// Attribute name and value prefix used for colocation groups.
// NOTE(review): values are presumably stored as "loc:@<group>" under the
// "_class" attr -- confirm against the code that reads these constants.
const char* const kColocationAttrName = "_class";
const char* const kColocationGroupPrefix = "loc:@";
// For TPU distributed rewrite, TPU args are collected and "staged" on the local
// host using an IdentityN TF op. Some args may result from a remote source.
// When all arg tensors are available, the TPUExecute op can be invoked. See
// DistributedTPURewritePass for more details.
// Op type of the TPU staging node.
const char* const kTpuExecuteStagingOp = "IdentityN";
// NOTE(review): presumably the name (or a name component) given to the staging
// node -- confirm against DistributedTPURewritePass.
const char* const kTpuExecuteStagingNodeName = "_variable_copy";
53
AttrSlice()54 AttrSlice::AttrSlice() : ndef_(nullptr) {
55 static const AttrValueMap* const kEmptyAttrValueMap = new AttrValueMap;
56 attrs_ = kEmptyAttrValueMap;
57 }
58
59 // Do not cache the map field reference because that may be invalidated on
60 // Clear.
AttrSlice(const NodeDef & node_def)61 AttrSlice::AttrSlice(const NodeDef& node_def)
62 : ndef_(&node_def), attrs_(nullptr) {}
63
AttrSlice(const AttrValueMap * a)64 AttrSlice::AttrSlice(const AttrValueMap* a) : ndef_(nullptr), attrs_(a) {}
65
SummarizeAttrsHelper(AttrSlice attrs,StringPiece device)66 string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device) {
67 string ret;
68
69 // We sort the attrs so the output is deterministic.
70 std::vector<string> attr_names;
71 attr_names.reserve(attrs.size());
72 for (const auto& attr : attrs) {
73 attr_names.push_back(attr.first);
74 }
75 std::sort(attr_names.begin(), attr_names.end());
76 bool first = true;
77 for (const string& attr_name : attr_names) {
78 if (!first) strings::StrAppend(&ret, ", ");
79 first = false;
80 strings::StrAppend(&ret, attr_name, "=",
81 SummarizeAttrValue(*attrs.Find(attr_name)));
82 }
83
84 // Consider the device to be a final attr with name "_device".
85 if (!device.empty()) {
86 if (!first) strings::StrAppend(&ret, ", ");
87 first = false;
88 strings::StrAppend(&ret, "_device=\"", device, "\"");
89 }
90 return ret;
91 }
92
SummarizeNode() const93 string AttrSlice::SummarizeNode() const {
94 return ndef_ ? SummarizeNodeDef(*ndef_)
95 : strings::StrCat(
96 "[", SummarizeAttrsHelper(*this, StringPiece()), "]");
97 }
98
DebugString() const99 string AttrSlice::DebugString() const {
100 std::vector<string> attr_key_vals;
101 attr_key_vals.reserve(attrs()->size());
102 for (const auto& it : *this) {
103 const string& name = it.first;
104 const AttrValue& attr_value = it.second;
105 attr_key_vals.push_back(
106 absl::StrCat(name, "=", SummarizeAttrValue(attr_value)));
107 }
108 return absl::StrJoin(attr_key_vals, ", ");
109 }
110
SummarizeNodeDef(const NodeDef & node_def,int max_inputs_in_summary)111 string SummarizeNodeDef(const NodeDef& node_def, int max_inputs_in_summary) {
112 string ret = strings::StrCat(errors::FormatNodeNameForError(node_def.name()),
113 " = ", node_def.op(), "[");
114 strings::StrAppend(&ret, SummarizeAttrsHelper(node_def, node_def.device()));
115 strings::StrAppend(&ret, "](");
116
117 // Output inputs, including control inputs, verbatim.
118 bool first = true;
119 for (const string& input : node_def.input()) {
120 if (!first) strings::StrAppend(&ret, ", ");
121 first = false;
122 if (max_inputs_in_summary-- == 0) {
123 strings::StrAppend(&ret, "...");
124 break;
125 }
126 strings::StrAppend(&ret, input);
127 }
128 strings::StrAppend(&ret, ")");
129 return ret;
130 }
131
SummarizeAttrs(const NodeDef & node_def)132 string SummarizeAttrs(const NodeDef& node_def) {
133 return SummarizeAttrsHelper(node_def, node_def.device());
134 }
135
FormatNodeDefForError(StringPiece node_name,bool has_experimental_debug_info,const NodeDef_ExperimentalDebugInfo & experimental_debug_info)136 string FormatNodeDefForError(
137 StringPiece node_name, bool has_experimental_debug_info,
138 const NodeDef_ExperimentalDebugInfo& experimental_debug_info) {
139 return !has_experimental_debug_info ||
140 experimental_debug_info.original_node_names().empty()
141 ? errors::FormatNodeNameForError(string(node_name))
142 : errors::FormatOriginalNodeLocationForError(
143 experimental_debug_info.original_node_names(),
144 experimental_debug_info.original_func_names());
145 }
146
FormatNodeDefForError(const NodeDef & node_def)147 string FormatNodeDefForError(const NodeDef& node_def) {
148 return FormatNodeDefForError(node_def.name(),
149 node_def.has_experimental_debug_info(),
150 node_def.experimental_debug_info());
151 }
152
Find(StringPiece attr_name) const153 const AttrValue* AttrSlice::Find(StringPiece attr_name) const {
154 // Currently, the collection used for NodeDef::attr() (google::protobuf::Map)
155 // requires that the keys used for lookups have type 'const string&'. Because
156 // this method takes a StringPiece, it is necessary to allocate a temporary
157 // string, copy attr_name to it, and then use that temporary string for the
158 // lookup. This causes an excessive number of short-lived allocations, and for
159 // large graphs, this can be a significant cost.
160 //
161 // Because most nodes have a small number of attributes, a simple linear scan
162 // is generally more efficient than a hashed lookup. If google::protobuf::Map
163 // changes so that it supports efficient lookups using StringPiece instead of
164 // const string&, then this code could be changed to use attrs()->find()
165 // again.
166
167 for (const auto& attr : *attrs()) {
168 if (attr.first == attr_name) {
169 return &attr.second;
170 }
171 }
172 return nullptr;
173 }
174
FindByString(const string & attr_name) const175 const AttrValue* AttrSlice::FindByString(const string& attr_name) const {
176 auto iter = attrs()->find(attr_name);
177 if (iter != attrs()->end()) {
178 return &iter->second;
179 } else {
180 return nullptr;
181 }
182 }
183
CheckFind(StringPiece attr_name,const AttrValue * attr_value) const184 Status AttrSlice::CheckFind(StringPiece attr_name,
185 const AttrValue* attr_value) const {
186 if (attr_value != nullptr) {
187 return OkStatus();
188 }
189 Status s = errors::NotFound("No attr named '", attr_name, "' in NodeDef:");
190 // Skip AttachDef for internal attrs since it is a little bit
191 // expensive and it is common for them to correctly not be included
192 // in a NodeDef.
193 if (!absl::StartsWith(attr_name, "_") && ndef_ != nullptr) {
194 s = AttachDef(s, *ndef_);
195 }
196 return s;
197 }
198
Find(StringPiece attr_name,const AttrValue ** attr_value) const199 Status AttrSlice::Find(StringPiece attr_name,
200 const AttrValue** attr_value) const {
201 *attr_value = Find(attr_name);
202 return CheckFind(attr_name, *attr_value);
203 }
204
FindByString(const string & attr_name,const AttrValue ** attr_value) const205 Status AttrSlice::FindByString(const string& attr_name,
206 const AttrValue** attr_value) const {
207 *attr_value = FindByString(attr_name);
208 return CheckFind(attr_name, *attr_value);
209 }
210
EqualAttrs(AttrSlice other,Scratch * scratch) const211 bool AttrSlice::EqualAttrs(AttrSlice other, Scratch* scratch) const {
212 if (size() != other.size()) return false;
213
214 for (const auto& attr : *other.attrs()) {
215 auto iter = attrs()->find(attr.first);
216 if (iter == attrs()->end()) return false;
217 // TODO(irving): Comparing AttrValues by proto is slightly buggy, since
218 // TensorProto is a nonunique representation of Tensor. This bug will go
219 // away once AttrSlice switches over to NodeInfo.
220 iter->second.SerializeToString(&scratch->a);
221 attr.second.SerializeToString(&scratch->b);
222 if (scratch->a != scratch->b) return false;
223 }
224 return true;
225 }
226
// Defines the two Status-returning GetNodeAttr() overloads for one C++ type:
//   Status GetNodeAttr(attrs, attr_name, TYPE* value)
//   Status GetNodeAttr(attrs, attr_name, std::vector<TYPE>* value)
// Macro parameters:
//   TYPE       C++ type written to `*value`.
//   FIELD      accessor on AttrValue (and on its list()) to read from.
//   ATTR_TYPE  string literal handed to AttrValueHasType(); the list overload
//              checks "list(" ATTR_TYPE ")".
//   APPEND_OP  std::vector member used to append each list element.
//   CAST       expression that yields the result; it may refer to `v`, the
//              value read through FIELD.
// The ... is to allow the caller to inject some value validation code. Use
// just ; if no additional validation code is needed.
// Both overloads fail if the attr is missing or has the wrong type. The list
// overload reserves and appends; it does not clear `*value` first.
#define DEFINE_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...)         \
  Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,           \
                     TYPE* value) {                                           \
    const AttrValue* attr_value;                                              \
    TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));                   \
    TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, ATTR_TYPE));             \
    const auto& v = attr_value->FIELD();                                      \
    __VA_ARGS__;                                                              \
    *value = CAST;                                                            \
    return OkStatus();                                                        \
  }                                                                           \
  Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,           \
                     std::vector<TYPE>* value) {                              \
    const AttrValue* attr_value;                                              \
    TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));                   \
    TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(" ATTR_TYPE ")")); \
    value->reserve(attr_value->list().FIELD().size());                        \
    for (const auto& v : attr_value->list().FIELD()) {                        \
      __VA_ARGS__;                                                            \
      value->APPEND_OP(CAST);                                                 \
    }                                                                         \
    return OkStatus();                                                        \
  }
252
// Defines the two bool-returning TryGetNodeAttr() overloads (scalar and
// std::vector<TYPE>) for one C++ type. These return false, rather than an
// error Status, when the attr is missing or has the wrong type.
// Macro parameters:
//   TYPE       C++ type written to `*value`.
//   FIELD      accessor on AttrValue (and on its list()) to read from.
//   ATTR_TYPE  string literal handed to AttrValueHasType(); the list overload
//              checks "list(" ATTR_TYPE ")".
//   APPEND_OP  std::vector member used to append each list element.
//   CAST       expression that yields the result; it may refer to `v`.
//   ...        validation code run before each CAST. It may `return false`;
//              in the list overload, elements appended before that return
//              stay in `*value`.
#define DEFINE_TRY_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...) \
  bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,      \
                      TYPE* value) {                                      \
    const AttrValue* attr_value = attrs.Find(attr_name);                  \
    if (attr_value == nullptr) {                                          \
      return false;                                                       \
    }                                                                     \
    Status s = AttrValueHasType(*attr_value, ATTR_TYPE);                  \
    if (!s.ok()) {                                                        \
      return false;                                                       \
    }                                                                     \
    const auto& v = attr_value->FIELD();                                  \
    __VA_ARGS__;                                                          \
    *value = CAST;                                                        \
    return true;                                                          \
  }                                                                       \
  bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,      \
                      std::vector<TYPE>* value) {                         \
    const AttrValue* attr_value = attrs.Find(attr_name);                  \
    if (attr_value == nullptr) {                                          \
      return false;                                                       \
    }                                                                     \
    Status s = AttrValueHasType(*attr_value, "list(" ATTR_TYPE ")");      \
    if (!s.ok()) {                                                        \
      return false;                                                       \
    }                                                                     \
    value->reserve(attr_value->list().FIELD().size());                    \
    for (const auto& v : attr_value->list().FIELD()) {                    \
      __VA_ARGS__;                                                        \
      value->APPEND_OP(CAST);                                             \
    }                                                                     \
    return true;                                                          \
  }
286 DEFINE_GET_ATTR(tstring, s, "string", emplace_back, v, ;)
287 DEFINE_TRY_GET_ATTR(tstring, s, "string", emplace_back, v, ;)
288 DEFINE_GET_ATTR(string, s, "string", emplace_back, v, ;)
289 DEFINE_TRY_GET_ATTR(string, s, "string", emplace_back, v, ;)
290 DEFINE_GET_ATTR(int64_t, i, "int", emplace_back, v, ;)
291 DEFINE_TRY_GET_ATTR(int64_t, i, "int", emplace_back, v, ;)
292 DEFINE_GET_ATTR(
293 int32, i, "int", emplace_back, static_cast<int32>(v),
294 if (static_cast<int64_t>(static_cast<int32>(v)) != v) {
295 return errors::InvalidArgument("Attr ", attr_name, " has value ", v,
296 " out of range for an int32");
297 })
298 DEFINE_TRY_GET_ATTR(
299 int32, i, "int", emplace_back, static_cast<int32>(v),
300 if (static_cast<int64_t>(static_cast<int32>(v)) != v) {
301 static int log_counter = 0;
302 if (log_counter < 10) {
303 log_counter++;
304 LOG(WARNING) << "Attr " << attr_name << " has value " << v
305 << " out of range for an int32";
306 }
307 return false;
308 })
309 DEFINE_GET_ATTR(float, f, "float", emplace_back, v, ;)
310 DEFINE_TRY_GET_ATTR(float, f, "float", emplace_back, v, ;)
311 DEFINE_GET_ATTR(bool, b, "bool", emplace_back, v, ;)
312 DEFINE_TRY_GET_ATTR(bool, b, "bool", emplace_back, v, ;)
313 DEFINE_GET_ATTR(DataType, type, "type", emplace_back, static_cast<DataType>(v),
314 ;)
315 DEFINE_TRY_GET_ATTR(DataType, type, "type", emplace_back,
316 static_cast<DataType>(v),
317 ;)
318 DEFINE_GET_ATTR(TensorShapeProto, shape, "shape", emplace_back, v, ;)
319 DEFINE_GET_ATTR(TensorShape, shape, "shape", emplace_back, TensorShape(v),
320 TF_RETURN_IF_ERROR(TensorShape::IsValidShape(v));)
321 DEFINE_TRY_GET_ATTR(
322 TensorShape, shape, "shape", emplace_back, TensorShape(v),
323 if (!TensorShape::IsValidShape(v).ok()) {
324 static int log_counter = 0;
325 if (log_counter < 10) {
326 log_counter++;
327 LOG(WARNING) << "Attr " << attr_name << " has invalid shape value "
328 << v.DebugString();
329 }
330 return false;
331 })
332 DEFINE_GET_ATTR(PartialTensorShape, shape, "shape", emplace_back,
333 PartialTensorShape(v),
334 TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(v));)
335 DEFINE_GET_ATTR(
336 Tensor, tensor, "tensor", emplace_back, t, Tensor t; if (!t.FromProto(v)) {
337 return errors::InvalidArgument("Attr ", attr_name, " has value ",
338 v.ShortDebugString(),
339 " that can't be converted to a Tensor");
340 })
341 DEFINE_GET_ATTR(NameAttrList, func, "func", emplace_back, v, ;);
342 #undef DEFINE_GET_ATTR
343
HasNodeAttr(const NodeDef & node_def,StringPiece attr_name)344 bool HasNodeAttr(const NodeDef& node_def, StringPiece attr_name) {
345 return node_def.attr().find(string(attr_name)) != node_def.attr().end();
346 }
347
348 static const string& kEmptyString = *new string();
349
GetNodeAttrString(const AttrSlice & attrs,StringPiece attr_name)350 const string& GetNodeAttrString(const AttrSlice& attrs, StringPiece attr_name) {
351 const AttrValue* attr_value = attrs.Find(attr_name);
352 if (attr_value == nullptr) {
353 return kEmptyString;
354 }
355 Status s = AttrValueHasType(*attr_value, "string");
356 if (!s.ok()) {
357 return kEmptyString;
358 }
359 return attr_value->s();
360 }
361
TryGetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,std::vector<const string * > * value)362 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
363 std::vector<const string*>* value) {
364 const AttrValue* attr_value = attrs.Find(attr_name);
365 if (attr_value == nullptr) {
366 return false;
367 }
368 Status s = AttrValueHasType(*attr_value, "list(string)");
369 if (!s.ok()) {
370 return false;
371 }
372 value->reserve(attr_value->list().s().size());
373 for (const auto& v : attr_value->list().s()) {
374 value->push_back(&v);
375 }
376 return true;
377 }
378
TryGetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,std::vector<const TensorShapeProto * > * value)379 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
380 std::vector<const TensorShapeProto*>* value) {
381 const AttrValue* attr_value = attrs.Find(attr_name);
382 if (attr_value == nullptr) {
383 return false;
384 }
385 Status s = AttrValueHasType(*attr_value, "list(shape)");
386 if (!s.ok()) {
387 return false;
388 }
389 value->reserve(attr_value->list().shape().size());
390 for (const auto& v : attr_value->list().shape()) {
391 value->push_back(&v);
392 }
393 return true;
394 }
395
GetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,DataTypeVector * value)396 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
397 DataTypeVector* value) {
398 const AttrValue* attr_value;
399 TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
400 TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(type)"));
401 for (const auto& v : attr_value->list().type()) {
402 value->push_back(static_cast<DataType>(v));
403 }
404 return OkStatus();
405 }
406
GetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,const TensorProto ** value)407 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
408 const TensorProto** value) {
409 const AttrValue* attr_value;
410 TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
411 TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "tensor"));
412 *value = &attr_value->tensor();
413 return OkStatus();
414 }
415
TryGetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,const TensorProto ** value)416 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
417 const TensorProto** value) {
418 const AttrValue* attr_value = attrs.Find(attr_name);
419 if (attr_value == nullptr) {
420 return false;
421 }
422 Status s = AttrValueHasType(*attr_value, "tensor");
423 if (!s.ok()) {
424 return false;
425 }
426 *value = &attr_value->tensor();
427 return true;
428 }
429
GetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,const NameAttrList ** value)430 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
431 const NameAttrList** value) {
432 const AttrValue* attr_value;
433 TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
434 TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "func"));
435 *value = &attr_value->func();
436 return OkStatus();
437 }
438
TryGetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,const NameAttrList ** value)439 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
440 const NameAttrList** value) {
441 const AttrValue* attr_value = attrs.Find(attr_name);
442 if (attr_value == nullptr) {
443 return false;
444 }
445 Status s = AttrValueHasType(*attr_value, "func");
446 if (!s.ok()) {
447 return false;
448 }
449 *value = &attr_value->func();
450 return true;
451 }
452
GetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,Padding * value)453 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
454 Padding* value) {
455 string str_value;
456 TF_RETURN_IF_ERROR(GetNodeAttr(attrs, attr_name, &str_value));
457 return GetPaddingFromString(str_value, value);
458 }
459
460 namespace { // Helper for InOutTypesForNode().
461
462 template <class NodeDefOrAttrSlice>
AddArgToSig(const NodeDefOrAttrSlice & node_or_attrs,const OpDef::ArgDef & arg_def,DataTypeVector * sig)463 Status AddArgToSig(const NodeDefOrAttrSlice& node_or_attrs,
464 const OpDef::ArgDef& arg_def, DataTypeVector* sig) {
465 const int original_size = sig->size();
466 if (!arg_def.number_attr().empty()) {
467 // Same type repeated "repeats" times.
468 int64_t repeats = -1;
469 TF_RETURN_IF_ERROR(
470 GetNodeAttr(node_or_attrs, arg_def.number_attr(), &repeats));
471 // We can't handle outputs that are larger than int32 sizes.
472 if (static_cast<int64_t>(static_cast<int32>(repeats)) != repeats) {
473 return errors::InvalidArgument("Number of outputs is too big: ", repeats);
474 }
475 if (repeats < 0) {
476 return errors::InvalidArgument("Value for number_attr() ", repeats,
477 " < 0");
478 }
479
480 if (!arg_def.type_attr().empty()) {
481 DataType dtype;
482 TF_RETURN_IF_ERROR(
483 GetNodeAttr(node_or_attrs, arg_def.type_attr(), &dtype));
484 for (int i = 0; i < repeats; ++i) {
485 sig->push_back(dtype);
486 }
487 } else if (arg_def.type() != DT_INVALID) {
488 for (int i = 0; i < repeats; ++i) {
489 sig->push_back(arg_def.type());
490 }
491 } else {
492 return errors::InvalidArgument("Missing type or type_attr field in ",
493 arg_def.ShortDebugString());
494 }
495 } else if (!arg_def.type_attr().empty()) {
496 const AttrValue* attr_value;
497 TF_RETURN_IF_ERROR(AttrSlice(node_or_attrs)
498 .FindByString(arg_def.type_attr(), &attr_value));
499 sig->push_back(attr_value->type());
500 } else if (!arg_def.type_list_attr().empty()) {
501 const AttrValue* attr_value;
502 TF_RETURN_IF_ERROR(
503 AttrSlice(node_or_attrs)
504 .FindByString(arg_def.type_list_attr(), &attr_value));
505 for (int dtype : attr_value->list().type()) {
506 sig->push_back(static_cast<DataType>(dtype));
507 }
508 } else if (arg_def.type() != DT_INVALID) {
509 sig->push_back(arg_def.type());
510 } else {
511 return errors::InvalidArgument("No type fields in ",
512 arg_def.ShortDebugString());
513 }
514 if (arg_def.is_ref()) {
515 // For all types that were added by this function call, make them refs.
516 for (size_t i = original_size; i < sig->size(); ++i) {
517 if (IsRefType((*sig)[i])) {
518 return errors::InvalidArgument(
519 "Requested reference to a reference type: ",
520 arg_def.ShortDebugString());
521 }
522 (*sig)[i] = MakeRefType((*sig)[i]);
523 }
524 }
525 return OkStatus();
526 }
527
528 } // namespace
529
InputTypeForNode(const NodeDef & node_def,const OpDef & op_def,int input_port,DataType * input_type)530 Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
531 int input_port, DataType* input_type) {
532 DataTypeVector input_types;
533 for (const auto& arg : op_def.input_arg()) {
534 TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, &input_types));
535 int input_types_size = input_types.size();
536 if (input_types_size > input_port) {
537 const DataType dtype = input_types[input_port];
538 *input_type = dtype;
539 return OkStatus();
540 }
541 }
542 return errors::InvalidArgument("Input ", input_port, " not found for node ",
543 node_def.name());
544 }
545
InputTypesForNode(const NodeDef & node_def,const OpDef & op_def,DataTypeVector * inputs)546 Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
547 DataTypeVector* inputs) {
548 for (const auto& arg : op_def.input_arg()) {
549 TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs));
550 }
551 return OkStatus();
552 }
553
OutputTypeForNode(const NodeDef & node_def,const OpDef & op_def,int output_port,DataType * output_type)554 Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
555 int output_port, DataType* output_type) {
556 DataTypeVector output_types;
557 for (const auto& arg : op_def.output_arg()) {
558 TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, &output_types));
559 int output_types_size = output_types.size();
560 if (output_types_size > output_port) {
561 const DataType dtype = output_types[output_port];
562 *output_type = dtype;
563 return OkStatus();
564 }
565 }
566 return errors::InvalidArgument("Output ", output_port, " not found for node ",
567 node_def.name());
568 }
569
OutputTypesForNode(const NodeDef & node_def,const OpDef & op_def,DataTypeVector * outputs)570 Status OutputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
571 DataTypeVector* outputs) {
572 for (const auto& arg : op_def.output_arg()) {
573 TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, outputs));
574 }
575 return OkStatus();
576 }
577
OutputTypesForNode(const AttrSlice & attrs,const OpDef & op_def,DataTypeVector * outputs)578 Status OutputTypesForNode(const AttrSlice& attrs, const OpDef& op_def,
579 DataTypeVector* outputs) {
580 for (const auto& arg : op_def.output_arg()) {
581 TF_RETURN_IF_ERROR(AddArgToSig(attrs, arg, outputs));
582 }
583 return OkStatus();
584 }
585
InOutTypesForNode(const NodeDef & node_def,const OpDef & op_def,DataTypeVector * inputs,DataTypeVector * outputs)586 Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
587 DataTypeVector* inputs, DataTypeVector* outputs) {
588 TF_RETURN_IF_ERROR(InputTypesForNode(node_def, op_def, inputs));
589 return OutputTypesForNode(node_def, op_def, outputs);
590 }
591
NumOutputsForNode(const NodeDef & node_def,const OpDef & op_def,int * num_outputs)592 Status NumOutputsForNode(const NodeDef& node_def, const OpDef& op_def,
593 int* num_outputs) {
594 DataTypeVector outputs;
595 TF_RETURN_IF_ERROR(OutputTypesForNode(node_def, op_def, &outputs));
596 *num_outputs = outputs.size();
597 return OkStatus();
598 }
599
OpPortIdToArgId(const NodeDef & node,const protobuf::RepeatedPtrField<OpDef::ArgDef> & args,int port_id)600 int OpPortIdToArgId(const NodeDef& node,
601 const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
602 int port_id) {
603 for (int arg_id = 0; arg_id < args.size(); ++arg_id) {
604 if (port_id < 0) {
605 return -1;
606 } else if (port_id == 0) {
607 return arg_id;
608 }
609
610 // Default is 1 port per arg.
611 int n = 1;
612
613 const auto& arg = args.Get(arg_id);
614 if (!arg.number_attr().empty()) {
615 n = node.attr().at(arg.number_attr()).i();
616 } else if (!arg.type_list_attr().empty()) {
617 n = node.attr().at(arg.type_list_attr()).list().type_size();
618 }
619
620 if (n < 0) {
621 // This should never happen.
622 DCHECK_GE(n, 0);
623 return -1;
624 } else if (port_id < n) {
625 return arg_id;
626 }
627 port_id -= n;
628 }
629
630 return -1;
631 }
632
ValidateNodeDef(const NodeDef & node_def,const OpDef & op_def)633 Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) {
634 if (node_def.op() != op_def.name()) {
635 return errors::InvalidArgument(
636 "NodeDef op '", node_def.op(), "' does not match ",
637 SummarizeOpDef(op_def), "; NodeDef: ", FormatNodeDefForError(node_def));
638 }
639
640 bool seen_control = false;
641 size_t num_inputs = 0;
642 // TODO(josh11b): Unify the input field validation.
643 for (const string& input : node_def.input()) {
644 if (absl::StartsWith(input, "^")) {
645 seen_control = true;
646 if (input.find(':') != string::npos) {
647 return errors::InvalidArgument("Control input '", input,
648 "' must not have ':' in NodeDef: ",
649 FormatNodeDefForError(node_def));
650 }
651 } else if (seen_control) {
652 return errors::InvalidArgument("Non-control input '", input,
653 "' after control input in NodeDef: ",
654 FormatNodeDefForError(node_def));
655 } else {
656 ++num_inputs;
657 }
658 }
659
660 std::unordered_map<string, const OpDef::AttrDef*> op_attrs;
661 for (const auto& attr : op_def.attr()) {
662 if (!gtl::InsertIfNotPresent(&op_attrs, attr.name(), &attr)) {
663 return errors::InvalidArgument("OpDef has duplicate attr name '",
664 attr.name(),
665 "': ", SummarizeOpDef(op_def));
666 }
667 }
668 for (const auto& attr : node_def.attr()) {
669 // Allow internal optional attributes with names starting with "_".
670 if (absl::StartsWith(attr.first, "_")) {
671 continue;
672 }
673 auto iter = op_attrs.find(attr.first);
674 if (iter == op_attrs.end()) {
675 LOG_EVERY_N_SEC(ERROR, 5)
676 << "NodeDef mentions attribute " << attr.first
677 << " which is not in the op definition: " << SummarizeOpDef(op_def)
678 << " This may be expected if your graph generating binary is newer "
679 << " than this binary. Unknown attributes will be ignored."
680 << " NodeDef: " << FormatNodeDefForError(node_def);
681 continue;
682 }
683
684 // If attr value is placeholder, do not check it.
685 if (attr.second.placeholder().empty()) {
686 TF_RETURN_WITH_CONTEXT_IF_ERROR(
687 ValidateAttrValue(attr.second, *iter->second),
688 "; NodeDef: ", FormatNodeDefForError(node_def), "; ",
689 SummarizeOpDef(op_def));
690 }
691
692 // Keep track of which attr names have (not) been found in the NodeDef.
693 op_attrs.erase(iter);
694 }
695
696 // Were all attrs in the OpDef found in the NodeDef?
697 if (!op_attrs.empty()) {
698 string attrs;
699 for (const auto& attr_pair : op_attrs) {
700 if (!attrs.empty()) strings::StrAppend(&attrs, "', '");
701 strings::StrAppend(&attrs, attr_pair.first);
702 }
703 return errors::InvalidArgument(
704 "NodeDef missing attr", op_attrs.size() == 1 ? " '" : "s '", attrs,
705 "' from ", SummarizeOpDef(op_def),
706 "; NodeDef: ", FormatNodeDefForError(node_def));
707 }
708
709 // Validate the number of inputs.
710 DataTypeVector inputs, outputs;
711 TF_RETURN_IF_ERROR(InOutTypesForNode(node_def, op_def, &inputs, &outputs));
712
713 if (num_inputs != inputs.size()) {
714 return errors::InvalidArgument(
715 "NodeDef expected inputs '", DataTypeVectorString(inputs),
716 "' do not match ", num_inputs, " inputs specified; ",
717 SummarizeOpDef(op_def), "; NodeDef: ", FormatNodeDefForError(node_def));
718 }
719
720 return OkStatus();
721 }
722
723 namespace { // Helpers for NameRangesForNode()
724
ComputeArgRange(const AttrSlice & attrs,const OpDef::ArgDef & arg_def,const OpDef & op_def,int * num)725 Status ComputeArgRange(const AttrSlice& attrs, const OpDef::ArgDef& arg_def,
726 const OpDef& op_def, int* num) {
727 if (!arg_def.number_attr().empty()) {
728 // Same type repeated "num" times.
729 return GetNodeAttr(attrs, arg_def.number_attr(), num);
730 } else if (!arg_def.type_list_attr().empty()) {
731 const AttrValue* attr_value;
732 TF_RETURN_IF_ERROR(attrs.Find(arg_def.type_list_attr(), &attr_value));
733 *num = attr_value->list().type_size();
734 } else if (!arg_def.type_attr().empty() || arg_def.type() != DT_INVALID) {
735 *num = 1;
736 } else {
737 return errors::InvalidArgument(
738 "Argument '", arg_def.name(),
739 "' incorrectly specified in op definition: ", SummarizeOpDef(op_def));
740 }
741 return OkStatus();
742 }
743
NameRangesHelper(const AttrSlice & attrs,const protobuf::RepeatedPtrField<OpDef::ArgDef> & args,const OpDef & op_def,NameRangeMap * result)744 Status NameRangesHelper(const AttrSlice& attrs,
745 const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
746 const OpDef& op_def, NameRangeMap* result) {
747 int start = 0;
748 int num;
749 for (const auto& arg : args) {
750 TF_RETURN_IF_ERROR(ComputeArgRange(attrs, arg, op_def, &num));
751 (*result)[arg.name()] = std::make_pair(start, start + num);
752 start += num;
753 }
754 return OkStatus();
755 }
756
757 } // namespace
758
NameRangesForNode(const AttrSlice & attrs,const OpDef & op_def,NameRangeMap * inputs,NameRangeMap * outputs)759 Status NameRangesForNode(const AttrSlice& attrs, const OpDef& op_def,
760 NameRangeMap* inputs, NameRangeMap* outputs) {
761 if (inputs != nullptr) {
762 TF_RETURN_IF_ERROR(
763 NameRangesHelper(attrs, op_def.input_arg(), op_def, inputs));
764 }
765 if (outputs != nullptr) {
766 return NameRangesHelper(attrs, op_def.output_arg(), op_def, outputs);
767 }
768 return OkStatus();
769 }
770
AddDefaultsToNodeDef(const OpDef & op_def,NodeDef * node_def)771 void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def) {
772 for (const auto& attr_def : op_def.attr()) {
773 AttrSlice attrs(*node_def);
774 if (attr_def.has_default_value() && !attrs.Find(attr_def.name())) {
775 AddNodeAttr(attr_def.name(), attr_def.default_value(), node_def);
776 }
777 }
778 }
779
StripDefaultsFromNodeDef(const OpDef & op_def,NodeDef * node_def)780 void StripDefaultsFromNodeDef(const OpDef& op_def, NodeDef* node_def) {
781 AttrSlice attrs(*node_def);
782 for (const auto& attr_def : op_def.attr()) {
783 if (attr_def.has_default_value()) {
784 const AttrValue* attr = attrs.Find(attr_def.name());
785 if (attr && AreAttrValuesEqual(*attr, attr_def.default_value()))
786 node_def->mutable_attr()->erase(attr_def.name());
787 }
788 }
789 }
790
791 namespace {
792
793 using ::tensorflow::tstring;
794 using ::tensorflow::strings::Scanner;
795
IsValidNodeName(StringPiece sp)796 bool IsValidNodeName(StringPiece sp) {
797 Scanner scanner(sp);
798 scanner.One(Scanner::LETTER_DIGIT_DOT)
799 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
800
801 while (true) {
802 if (!scanner.GetResult()) // Some error in previous iteration.
803 return false;
804 if (scanner.empty()) // No error, but nothing left, good.
805 return true;
806
807 // Absorb another name/namespace, starting with a '>'
808 scanner.One(Scanner::RANGLE)
809 .One(Scanner::LETTER_DIGIT_DOT)
810 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
811 }
812 }
813
IsValidDataInputName(StringPiece sp)814 bool IsValidDataInputName(StringPiece sp) {
815 // Data inputs are op_name, op_name:0, or op_name:12345.
816 Scanner scan(sp);
817 scan.One(Scanner::LETTER_DIGIT_DOT)
818 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
819
820 while (true) {
821 if (!scan.GetResult()) // Some error in previous iteration.
822 return false;
823 if (scan.empty()) // No error, but nothing left, good.
824 return true;
825
826 if (scan.Peek() == ':') { // Absorb identifier after the colon
827 scan.OneLiteral(":");
828 if (scan.Peek() == '0') {
829 scan.OneLiteral("0"); // :0
830 } else {
831 scan.Many(Scanner::DIGIT); // :[1-9][0-9]*
832 }
833 } else {
834 // Absorb another name/namespace, starting with a '>'
835 scan.One(Scanner::RANGLE)
836 .One(Scanner::LETTER_DIGIT_DOT)
837 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
838 }
839 }
840 }
841
IsValidControlInputName(StringPiece sp)842 bool IsValidControlInputName(StringPiece sp) {
843 Scanner scan(sp);
844 scan.OneLiteral("^")
845 .One(Scanner::LETTER_DIGIT_DOT)
846 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
847
848 while (true) {
849 if (!scan.GetResult()) // Some error in previous iteration.
850 return false;
851 if (scan.empty()) // No error, but nothing left, good.
852 return true;
853
854 // Absorb another name/namespace, starting with a '>'
855 scan.One(Scanner::RANGLE)
856 .One(Scanner::LETTER_DIGIT_DOT)
857 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
858 }
859 }
860
861 const StringPiece kColocationGroupPrefixStringPiece(kColocationGroupPrefix);
862
863 } // namespace
864
ValidateOpInput(const string & input_name,bool * is_control_input)865 Status ValidateOpInput(const string& input_name, bool* is_control_input) {
866 *is_control_input = false;
867 if (IsValidDataInputName(input_name)) {
868 return OkStatus();
869 } else if (IsValidControlInputName(input_name)) {
870 *is_control_input = true;
871 return OkStatus();
872 } else {
873 return errors::InvalidArgument("Illegal op input name '", input_name, "'");
874 }
875 }
876
ValidateNodeName(const string & node_name)877 Status ValidateNodeName(const string& node_name) {
878 if (IsValidNodeName(node_name)) {
879 return OkStatus();
880 } else {
881 return errors::InvalidArgument("Illegal op name '", node_name, "'");
882 }
883 }
884
ValidateExternalNodeDefSyntax(const NodeDef & node_def)885 Status ValidateExternalNodeDefSyntax(const NodeDef& node_def) {
886 Status s = ValidateNodeName(node_def.name());
887 if (!s.ok()) {
888 return AttachDef(s, node_def);
889 }
890 bool in_control_inputs = false;
891 for (const string& input_name : node_def.input()) {
892 bool is_control_input;
893 s = ValidateOpInput(input_name, &is_control_input);
894 if (!s.ok()) {
895 return AttachDef(s, node_def);
896 }
897
898 if (in_control_inputs && !is_control_input) {
899 return AttachDef(errors::InvalidArgument(
900 "All control inputs must follow all data inputs"),
901 node_def);
902 }
903 in_control_inputs = is_control_input;
904 }
905 return OkStatus();
906 }
907
// Returns a copy of `status` whose message has " [[<node>]]" appended.
//
// <node> is the full FormatNodeDefForError(node_def) text, unless
// `allow_multiple_formatted_node` is false and the message already contains
// a "{{node " marker, in which case only the node's name is appended.
Status AttachDef(const Status& status, const NodeDef& node_def,
                 bool allow_multiple_formatted_node) {
  const bool append_name_only =
      !allow_multiple_formatted_node &&
      status.error_message().find("{{node ") != string::npos;

  const string node_error =
      append_name_only ? node_def.name() : FormatNodeDefForError(node_def);

  Status annotated = status;
  errors::AppendToMessage(&annotated,
                          strings::StrCat(" [[", node_error, "]]"));
  return annotated;
}
921
AddNodeAttr(StringPiece name,const AttrValue & value,NodeDef * node_def)922 void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def) {
923 node_def->mutable_attr()->insert(
924 AttrValueMap::value_type(string(name), value));
925 }
926
AddNodeAttr(StringPiece name,AttrValue && value,NodeDef * node_def)927 void AddNodeAttr(StringPiece name, AttrValue&& value, NodeDef* node_def) {
928 (*node_def->mutable_attr())[string(name)] = std::move(value);
929 }
930
931 #define ADD_NODE_ATTR(T) \
932 void AddNodeAttr(StringPiece name, T value, NodeDef* node_def) { \
933 AttrValue attr_value; \
934 SetAttrValue(value, &attr_value); \
935 AddNodeAttr(name, attr_value, node_def); \
936 }
937 ADD_NODE_ATTR(StringPiece)
ADD_NODE_ATTR(const char *)938 ADD_NODE_ATTR(const char*)
939 ADD_NODE_ATTR(int32_t)
940 ADD_NODE_ATTR(int64_t)
941 ADD_NODE_ATTR(float)
942 ADD_NODE_ATTR(double)
943 ADD_NODE_ATTR(bool)
944 ADD_NODE_ATTR(DataType)
945 ADD_NODE_ATTR(const PartialTensorShape&)
946 ADD_NODE_ATTR(const Tensor&)
947 ADD_NODE_ATTR(const TensorProto&)
948 ADD_NODE_ATTR(const NameAttrList&)
949 ADD_NODE_ATTR(gtl::ArraySlice<StringPiece>)
950 ADD_NODE_ATTR(gtl::ArraySlice<const char*>)
951 ADD_NODE_ATTR(gtl::ArraySlice<string>)
952 ADD_NODE_ATTR(gtl::ArraySlice<int32>)
953 ADD_NODE_ATTR(gtl::ArraySlice<int64_t>)
954 ADD_NODE_ATTR(gtl::ArraySlice<float>)
955 ADD_NODE_ATTR(gtl::ArraySlice<bool>)
956 ADD_NODE_ATTR(const std::vector<bool>&)
957 ADD_NODE_ATTR(gtl::ArraySlice<DataType>)
958 ADD_NODE_ATTR(gtl::ArraySlice<TensorShape>)
959 ADD_NODE_ATTR(gtl::ArraySlice<PartialTensorShape>)
960 ADD_NODE_ATTR(gtl::ArraySlice<TensorShapeProto>)
961 ADD_NODE_ATTR(gtl::ArraySlice<Tensor>)
962 ADD_NODE_ATTR(gtl::ArraySlice<NameAttrList>)
963 #undef ADD_NODE_ATTR
964
965 void AddAttr(StringPiece name, const AttrValue& value, AttrValueMap* map) {
966 map->insert(AttrValueMap::value_type(string(name), value));
967 }
968
969 #define ADD_ATTR(T) \
970 void AddAttr(StringPiece name, T value, AttrValueMap* map) { \
971 AttrValue attr_value; \
972 SetAttrValue(value, &attr_value); \
973 AddAttr(name, attr_value, map); \
974 }
ADD_ATTR(bool)975 ADD_ATTR(bool)
976 #undef ADD_ATTR
977
978 Status AddPrefixAndSuffixToNode(StringPiece prefix, StringPiece suffix,
979 NodeDef* node_def, bool uniquify_frame_name) {
980 node_def->set_name(strings::StrCat(prefix, node_def->name(), suffix));
981
982 // Update frame name to avoid multiple LoopCond nodes in one frame.
983 if (uniquify_frame_name &&
984 (node_def->op() == "Enter" || node_def->op() == "RefEnter")) {
985 string frame_name;
986 TF_RETURN_IF_ERROR(GetNodeAttr(*node_def, "frame_name", &frame_name));
987 AttrValue& attr = (*node_def->mutable_attr())["frame_name"];
988 frame_name = strings::StrCat(prefix, frame_name, suffix);
989 attr.set_s(frame_name);
990 }
991
992 return OkStatus();
993 }
994
MaybeAddPrefixToColocationConstraints(const std::unordered_set<string> & match,StringPiece prefix,NodeDef * node_def)995 Status MaybeAddPrefixToColocationConstraints(
996 const std::unordered_set<string>& match, StringPiece prefix,
997 NodeDef* node_def) {
998 auto attr = node_def->mutable_attr()->find(kColocationAttrName);
999 if (attr == node_def->mutable_attr()->end()) {
1000 return OkStatus();
1001 }
1002 auto constraints_list = attr->second.mutable_list();
1003 auto constraints_size = constraints_list->s_size();
1004 for (size_t i = 0; i < constraints_size; ++i) {
1005 StringPiece original(constraints_list->s(i));
1006 if (absl::ConsumePrefix(&original, kColocationGroupPrefixStringPiece)) {
1007 if (match.find(string(original)) != match.end()) {
1008 (*constraints_list->mutable_s(i)) =
1009 strings::StrCat(kColocationGroupPrefix, prefix, original);
1010 }
1011 }
1012 }
1013 return OkStatus();
1014 }
1015
MaybeUpdateColocationConstraintsWithMap(const std::map<absl::string_view,absl::string_view> & node_name_map,NodeDef * node_def)1016 Status MaybeUpdateColocationConstraintsWithMap(
1017 const std::map<absl::string_view, absl::string_view>& node_name_map,
1018 NodeDef* node_def) {
1019 auto attr = node_def->mutable_attr()->find(kColocationAttrName);
1020 if (attr == node_def->mutable_attr()->end()) {
1021 return OkStatus();
1022 }
1023 auto constraints_list = attr->second.mutable_list();
1024 auto constraints_size = constraints_list->s_size();
1025 for (size_t i = 0; i < constraints_size; ++i) {
1026 StringPiece original(constraints_list->s(i));
1027 if (absl::ConsumePrefix(&original, kColocationGroupPrefixStringPiece)) {
1028 if (node_name_map.find(original) != node_name_map.end()) {
1029 (*constraints_list->mutable_s(i)) =
1030 strings::StrCat(kColocationGroupPrefix, node_name_map.at(original));
1031 }
1032 }
1033 }
1034 return OkStatus();
1035 }
1036
1037 } // namespace tensorflow