1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/node_def_util.h"
17
#include <algorithm>
#include <atomic>
#include <unordered_map>
#include <vector>
21
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/str_join.h"
24 #include "tensorflow/core/framework/attr_value.pb.h"
25 #include "tensorflow/core/framework/attr_value_util.h"
26 #include "tensorflow/core/framework/node_def.pb.h"
27 #include "tensorflow/core/framework/op_def.pb.h"
28 #include "tensorflow/core/framework/op_def_util.h"
29 #include "tensorflow/core/framework/tensor.h"
30 #include "tensorflow/core/framework/tensor.pb.h"
31 #include "tensorflow/core/framework/tensor_shape.h"
32 #include "tensorflow/core/framework/tensor_shape.pb.h"
33 #include "tensorflow/core/framework/types.h"
34 #include "tensorflow/core/framework/types.pb.h"
35 #include "tensorflow/core/lib/gtl/map_util.h"
36 #include "tensorflow/core/platform/errors.h"
37 #include "tensorflow/core/platform/scanner.h"
38 #include "tensorflow/core/platform/status.h"
39 #include "tensorflow/core/platform/strcat.h"
40 #include "tensorflow/core/platform/stringpiece.h"
41 #include "tensorflow/core/platform/types.h"
42
43 namespace tensorflow {
44
// Attr name used for colocation information (as the constant's name
// indicates); the format of its value is defined by the code that reads it.
const char* const kColocationAttrName = "_class";
// Prefix for colocation group entries, presumably stored in the attr above
// (e.g. "loc:@<node>") -- confirm against users of this constant.
const char* const kColocationGroupPrefix = "loc:@";
47
AttrSlice()48 AttrSlice::AttrSlice() : ndef_(nullptr) {
49 static const AttrValueMap* const kEmptyAttrValueMap = new AttrValueMap;
50 attrs_ = kEmptyAttrValueMap;
51 }
52
AttrSlice(const NodeDef & node_def)53 AttrSlice::AttrSlice(const NodeDef& node_def)
54 : ndef_(&node_def), attrs_(&ndef_->attr()) {}
55
AttrSlice(const AttrValueMap * a)56 AttrSlice::AttrSlice(const AttrValueMap* a) : ndef_(nullptr), attrs_(a) {}
57
SummarizeAttrsHelper(AttrSlice attrs,StringPiece device)58 string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device) {
59 string ret;
60
61 // We sort the attrs so the output is deterministic.
62 std::vector<string> attr_names;
63 attr_names.reserve(attrs.size());
64 for (const auto& attr : attrs) {
65 attr_names.push_back(attr.first);
66 }
67 std::sort(attr_names.begin(), attr_names.end());
68 bool first = true;
69 for (const string& attr_name : attr_names) {
70 if (!first) strings::StrAppend(&ret, ", ");
71 first = false;
72 strings::StrAppend(&ret, attr_name, "=",
73 SummarizeAttrValue(*attrs.Find(attr_name)));
74 }
75
76 // Consider the device to be a final attr with name "_device".
77 if (!device.empty()) {
78 if (!first) strings::StrAppend(&ret, ", ");
79 first = false;
80 strings::StrAppend(&ret, "_device=\"", device, "\"");
81 }
82 return ret;
83 }
84
SummarizeNode() const85 string AttrSlice::SummarizeNode() const {
86 return ndef_ ? SummarizeNodeDef(*ndef_)
87 : strings::StrCat(
88 "[", SummarizeAttrsHelper(*this, StringPiece()), "]");
89 }
90
DebugString() const91 string AttrSlice::DebugString() const {
92 std::vector<string> attr_key_vals;
93 attr_key_vals.reserve(attrs_->size());
94 for (const auto& it : *this) {
95 const string& name = it.first;
96 const AttrValue& attr_value = it.second;
97 attr_key_vals.push_back(
98 absl::StrCat(name, "=", SummarizeAttrValue(attr_value)));
99 }
100 return absl::StrJoin(attr_key_vals, ", ");
101 }
102
SummarizeNodeDef(const NodeDef & node_def,int max_inputs_in_summary)103 string SummarizeNodeDef(const NodeDef& node_def, int max_inputs_in_summary) {
104 string ret = strings::StrCat(errors::FormatNodeNameForError(node_def.name()),
105 " = ", node_def.op(), "[");
106 strings::StrAppend(&ret, SummarizeAttrsHelper(node_def, node_def.device()));
107 strings::StrAppend(&ret, "](");
108
109 // Output inputs, including control inputs, verbatim.
110 bool first = true;
111 for (const string& input : node_def.input()) {
112 if (!first) strings::StrAppend(&ret, ", ");
113 first = false;
114 if (max_inputs_in_summary-- == 0) {
115 strings::StrAppend(&ret, "...");
116 break;
117 }
118 strings::StrAppend(&ret, input);
119 }
120 strings::StrAppend(&ret, ")");
121 return ret;
122 }
123
SummarizeAttrs(const NodeDef & node_def)124 string SummarizeAttrs(const NodeDef& node_def) {
125 return SummarizeAttrsHelper(node_def, node_def.device());
126 }
127
FormatNodeDefForError(StringPiece node_name,bool has_experimental_debug_info,const NodeDef_ExperimentalDebugInfo & experimental_debug_info)128 string FormatNodeDefForError(
129 StringPiece node_name, bool has_experimental_debug_info,
130 const NodeDef_ExperimentalDebugInfo& experimental_debug_info) {
131 return !has_experimental_debug_info ||
132 experimental_debug_info.original_node_names().empty()
133 ? errors::FormatNodeNameForError(string(node_name))
134 : errors::FormatNodeNamesForError(
135 experimental_debug_info.original_node_names());
136 }
137
FormatNodeDefForError(const NodeDef & node_def)138 string FormatNodeDefForError(const NodeDef& node_def) {
139 return FormatNodeDefForError(node_def.name(),
140 node_def.has_experimental_debug_info(),
141 node_def.experimental_debug_info());
142 }
143
Find(StringPiece attr_name) const144 const AttrValue* AttrSlice::Find(StringPiece attr_name) const {
145 // Currently, the collection used for NodeDef::attr() (google::protobuf::Map)
146 // requires that the keys used for lookups have type 'const string&'. Because
147 // this method takes a StringPiece, it is necessary to allocate a temporary
148 // string, copy attr_name to it, and then use that temporary string for the
149 // lookup. This causes an excessive number of short-lived allocations, and for
150 // large graphs, this can be a significant cost.
151 //
152 // Because most nodes have a small number of attributes, a simple linear scan
153 // is generally more efficient than a hashed lookup. If google::protobuf::Map
154 // changes so that it supports efficient lookups using StringPiece instead of
155 // const string&, then this code could be changed to use attrs_->find() again.
156
157 for (const auto& attr : *attrs_) {
158 if (attr.first == attr_name) {
159 return &attr.second;
160 }
161 }
162 return nullptr;
163 }
164
FindByString(const string & attr_name) const165 const AttrValue* AttrSlice::FindByString(const string& attr_name) const {
166 auto iter = attrs_->find(attr_name);
167 if (iter != attrs_->end()) {
168 return &iter->second;
169 } else {
170 return nullptr;
171 }
172 }
173
Find(StringPiece attr_name,const AttrValue ** attr_value) const174 Status AttrSlice::Find(StringPiece attr_name,
175 const AttrValue** attr_value) const {
176 *attr_value = Find(attr_name);
177 if (*attr_value != nullptr) {
178 return Status::OK();
179 }
180 Status s = errors::NotFound("No attr named '", attr_name, "' in NodeDef:");
181 // Skip AttachDef for internal attrs since it is a little bit
182 // expensive and it is common for them to correctly not be included
183 // in a NodeDef.
184 if (!absl::StartsWith(attr_name, "_") && ndef_ != nullptr) {
185 s = AttachDef(s, *ndef_);
186 }
187 return s;
188 }
189
EqualAttrs(AttrSlice other,Scratch * scratch) const190 bool AttrSlice::EqualAttrs(AttrSlice other, Scratch* scratch) const {
191 if (size() != other.size()) return false;
192
193 for (const auto& attr : *other.attrs_) {
194 auto iter = attrs_->find(attr.first);
195 if (iter == attrs_->end()) return false;
196 // TODO(irving): Comparing AttrValues by proto is slightly buggy, since
197 // TensorProto is a nonunique representation of Tensor. This bug will go
198 // away once AttrSlice switches over to NodeInfo.
199 iter->second.SerializeToString(&scratch->a);
200 attr.second.SerializeToString(&scratch->b);
201 if (scratch->a != scratch->b) return false;
202 }
203 return true;
204 }
205
// Defines a pair of GetNodeAttr() overloads for one attr type:
//   Status GetNodeAttr(const AttrSlice&, StringPiece, TYPE*)
//   Status GetNodeAttr(const AttrSlice&, StringPiece, std::vector<TYPE>*)
// Macro parameters:
//   TYPE      - C++ type stored into *value (or appended to the vector).
//   FIELD     - AttrValue accessor holding the scalar (e.g. `i`, `s`); the
//               vector overload reads attr_value->list().FIELD().
//   ATTR_TYPE - attr type string checked with AttrValueHasType ("int", ...);
//               the vector overload checks "list(" ATTR_TYPE ")".
//   APPEND_OP - vector member used to append (emplace_back or push_back).
//   CAST      - expression converting the proto value `v` to TYPE.
//   ...       - caller-supplied validation code, run after `v` is bound and
//               before the converted value is stored; it may `return` an
//               error Status. Use just ; if no additional validation code is
//               needed.
// Both overloads return the lookup or type-check error if the attr is absent
// or mistyped. The vector overload appends to *value without clearing it, so
// a validation failure partway through the list leaves earlier elements in
// place.
#define DEFINE_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...)         \
  Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,            \
                     TYPE* value) {                                            \
    const AttrValue* attr_value;                                               \
    TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));                    \
    TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, ATTR_TYPE));              \
    const auto& v = attr_value->FIELD();                                       \
    __VA_ARGS__;                                                               \
    *value = CAST;                                                             \
    return Status::OK();                                                       \
  }                                                                            \
  Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,            \
                     std::vector<TYPE>* value) {                               \
    const AttrValue* attr_value;                                               \
    TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));                    \
    TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(" ATTR_TYPE ")"));  \
    value->reserve(attr_value->list().FIELD().size());                         \
    for (const auto& v : attr_value->list().FIELD()) {                         \
      __VA_ARGS__;                                                             \
      value->APPEND_OP(CAST);                                                  \
    }                                                                          \
    return Status::OK();                                                       \
  }
231
// Counterpart of DEFINE_GET_ATTR that defines the TryGetNodeAttr() pair:
//   bool TryGetNodeAttr(const AttrSlice&, StringPiece, TYPE*)
//   bool TryGetNodeAttr(const AttrSlice&, StringPiece, std::vector<TYPE>*)
// The parameters have the same meaning as in DEFINE_GET_ATTR, but these
// functions return false instead of an error Status when the attr is missing
// or has the wrong type, so injected validation code must `return false` on
// failure. The vector overload appends to *value without clearing it, so a
// validation failure partway through the list leaves earlier elements in
// place.
#define DEFINE_TRY_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...) \
  bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,      \
                      TYPE* value) {                                      \
    const AttrValue* attr_value = attrs.Find(attr_name);                  \
    if (attr_value == nullptr) {                                          \
      return false;                                                       \
    }                                                                     \
    Status s = AttrValueHasType(*attr_value, ATTR_TYPE);                  \
    if (!s.ok()) {                                                        \
      return false;                                                       \
    }                                                                     \
    const auto& v = attr_value->FIELD();                                  \
    __VA_ARGS__;                                                          \
    *value = CAST;                                                        \
    return true;                                                          \
  }                                                                       \
  bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,      \
                      std::vector<TYPE>* value) {                         \
    const AttrValue* attr_value = attrs.Find(attr_name);                  \
    if (attr_value == nullptr) {                                          \
      return false;                                                       \
    }                                                                     \
    Status s = AttrValueHasType(*attr_value, "list(" ATTR_TYPE ")");      \
    if (!s.ok()) {                                                        \
      return false;                                                       \
    }                                                                     \
    value->reserve(attr_value->list().FIELD().size());                    \
    for (const auto& v : attr_value->list().FIELD()) {                    \
      __VA_ARGS__;                                                        \
      value->APPEND_OP(CAST);                                             \
    }                                                                     \
    return true;                                                          \
  }
265 DEFINE_GET_ATTR(tstring, s, "string", emplace_back, v, ;)
266 DEFINE_TRY_GET_ATTR(tstring, s, "string", emplace_back, v, ;)
267 DEFINE_GET_ATTR(string, s, "string", emplace_back, v, ;)
268 DEFINE_TRY_GET_ATTR(string, s, "string", emplace_back, v, ;)
269 DEFINE_GET_ATTR(int64, i, "int", emplace_back, v, ;)
270 DEFINE_TRY_GET_ATTR(int64, i, "int", emplace_back, v, ;)
271 DEFINE_GET_ATTR(
272 int32, i, "int", emplace_back, static_cast<int32>(v),
273 if (static_cast<int64>(static_cast<int32>(v)) != v) {
274 return errors::InvalidArgument("Attr ", attr_name, " has value ", v,
275 " out of range for an int32");
276 })
277 DEFINE_TRY_GET_ATTR(
278 int32, i, "int", emplace_back, static_cast<int32>(v),
279 if (static_cast<int64>(static_cast<int32>(v)) != v) {
280 static int log_counter = 0;
281 if (log_counter < 10) {
282 log_counter++;
283 LOG(WARNING) << "Attr " << attr_name << " has value " << v
284 << " out of range for an int32";
285 }
286 return false;
287 })
288 DEFINE_GET_ATTR(float, f, "float", emplace_back, v, ;)
289 DEFINE_TRY_GET_ATTR(float, f, "float", emplace_back, v, ;)
290 // std::vector<bool> specialization does not have emplace_back until
291 // c++14, so we have to use push_back (see
292 // http://en.cppreference.com/w/cpp/container/vector/emplace_back)
293 DEFINE_GET_ATTR(bool, b, "bool", push_back, v, ;)
294 DEFINE_TRY_GET_ATTR(bool, b, "bool", push_back, v, ;)
295 DEFINE_GET_ATTR(DataType, type, "type", emplace_back, static_cast<DataType>(v),
296 ;)
297 DEFINE_TRY_GET_ATTR(DataType, type, "type", emplace_back,
298 static_cast<DataType>(v),
299 ;)
300 DEFINE_GET_ATTR(TensorShapeProto, shape, "shape", emplace_back, v, ;)
301 DEFINE_GET_ATTR(TensorShape, shape, "shape", emplace_back, TensorShape(v),
302 TF_RETURN_IF_ERROR(TensorShape::IsValidShape(v));)
303 DEFINE_TRY_GET_ATTR(
304 TensorShape, shape, "shape", emplace_back, TensorShape(v),
305 if (!TensorShape::IsValidShape(v).ok()) {
306 static int log_counter = 0;
307 if (log_counter < 10) {
308 log_counter++;
309 LOG(WARNING) << "Attr " << attr_name << " has invalid shape value "
310 << v.DebugString();
311 }
312 return false;
313 })
314 DEFINE_GET_ATTR(PartialTensorShape, shape, "shape", emplace_back,
315 PartialTensorShape(v),
316 TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(v));)
317 DEFINE_GET_ATTR(
318 Tensor, tensor, "tensor", emplace_back, t, Tensor t; if (!t.FromProto(v)) {
319 return errors::InvalidArgument("Attr ", attr_name, " has value ",
320 v.ShortDebugString(),
321 " that can't be converted to a Tensor");
322 })
323 DEFINE_GET_ATTR(NameAttrList, func, "func", emplace_back, v, ;);
324 #undef DEFINE_GET_ATTR
325
HasNodeAttr(const NodeDef & node_def,StringPiece attr_name)326 bool HasNodeAttr(const NodeDef& node_def, StringPiece attr_name) {
327 return node_def.attr().find(string(attr_name)) != node_def.attr().end();
328 }
329
// Shared empty string returned by reference (e.g. by GetNodeAttrString())
// when no real value is available. It is heap-allocated and never freed, so
// the reference remains valid for the life of the process.
static const string& kEmptyString = *new string();
331
GetNodeAttrString(const AttrSlice & attrs,StringPiece attr_name)332 const string& GetNodeAttrString(const AttrSlice& attrs, StringPiece attr_name) {
333 const AttrValue* attr_value = attrs.Find(attr_name);
334 if (attr_value == nullptr) {
335 return kEmptyString;
336 }
337 Status s = AttrValueHasType(*attr_value, "string");
338 if (!s.ok()) {
339 return kEmptyString;
340 }
341 return attr_value->s();
342 }
343
TryGetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,std::vector<const string * > * value)344 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
345 std::vector<const string*>* value) {
346 const AttrValue* attr_value = attrs.Find(attr_name);
347 if (attr_value == nullptr) {
348 return false;
349 }
350 Status s = AttrValueHasType(*attr_value, "list(string)");
351 if (!s.ok()) {
352 return false;
353 }
354 value->reserve(attr_value->list().s().size());
355 for (const auto& v : attr_value->list().s()) {
356 value->push_back(&v);
357 }
358 return true;
359 }
360
TryGetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,std::vector<const TensorShapeProto * > * value)361 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
362 std::vector<const TensorShapeProto*>* value) {
363 const AttrValue* attr_value = attrs.Find(attr_name);
364 if (attr_value == nullptr) {
365 return false;
366 }
367 Status s = AttrValueHasType(*attr_value, "list(shape)");
368 if (!s.ok()) {
369 return false;
370 }
371 value->reserve(attr_value->list().shape().size());
372 for (const auto& v : attr_value->list().shape()) {
373 value->push_back(&v);
374 }
375 return true;
376 }
377
GetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,DataTypeVector * value)378 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
379 DataTypeVector* value) {
380 const AttrValue* attr_value;
381 TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
382 TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(type)"));
383 for (const auto& v : attr_value->list().type()) {
384 value->push_back(static_cast<DataType>(v));
385 }
386 return Status::OK();
387 }
388
GetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,const TensorProto ** value)389 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
390 const TensorProto** value) {
391 const AttrValue* attr_value;
392 TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
393 TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "tensor"));
394 *value = &attr_value->tensor();
395 return Status::OK();
396 }
397
TryGetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,const TensorProto ** value)398 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
399 const TensorProto** value) {
400 const AttrValue* attr_value = attrs.Find(attr_name);
401 if (attr_value == nullptr) {
402 return false;
403 }
404 Status s = AttrValueHasType(*attr_value, "tensor");
405 if (!s.ok()) {
406 return false;
407 }
408 *value = &attr_value->tensor();
409 return true;
410 }
411
GetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,const NameAttrList ** value)412 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
413 const NameAttrList** value) {
414 const AttrValue* attr_value;
415 TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
416 TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "func"));
417 *value = &attr_value->func();
418 return Status::OK();
419 }
420
TryGetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,const NameAttrList ** value)421 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
422 const NameAttrList** value) {
423 const AttrValue* attr_value = attrs.Find(attr_name);
424 if (attr_value == nullptr) {
425 return false;
426 }
427 Status s = AttrValueHasType(*attr_value, "func");
428 if (!s.ok()) {
429 return false;
430 }
431 *value = &attr_value->func();
432 return true;
433 }
434
GetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,Padding * value)435 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
436 Padding* value) {
437 string str_value;
438 TF_RETURN_IF_ERROR(GetNodeAttr(attrs, attr_name, &str_value));
439 return GetPaddingFromString(str_value, value);
440 }
441
namespace {  // Helper for InOutTypesForNode() and the other *TypeForNode fns.

// Appends to *sig the DataType(s) that `arg_def` expands to, given the attrs
// in `node_or_attrs` (a NodeDef or an AttrSlice).
//
// Branches, in the order they are tried:
//   * number_attr set: that int attr gives a repeat count N, which must fit
//     in int32 and be >= 0. The element type comes from type_attr if set,
//     otherwise from the fixed `type`; N copies are appended.
//   * type_attr set: one entry, read from that attr's `type` field.
//   * type_list_attr set: one entry per element of that attr's type list.
//   * fixed `type` set: one entry.
// If arg_def.is_ref(), every entry appended by this call is converted to its
// ref type; requesting a ref to a type that is already a ref is an error.
//
// Returns InvalidArgument for malformed arg defs or counts, or propagates
// attr lookup errors. Entries appended before an error are not removed.
//
// NOTE(review): the type_attr / type_list_attr branches read the AttrValue
// without checking its declared type (unlike the number_attr branch, which
// uses GetNodeAttr). ValidateNodeDef validates non-placeholder attr values
// before calling InOutTypesForNode; confirm other callers do too.
template <class NodeDefOrAttrSlice>
Status AddArgToSig(const NodeDefOrAttrSlice& node_or_attrs,
                   const OpDef::ArgDef& arg_def, DataTypeVector* sig) {
  // Start of this arg's entries, so the is_ref pass below only touches
  // types appended by this call.
  const int original_size = sig->size();
  if (!arg_def.number_attr().empty()) {
    // Same type repeated "repeats" times.
    int64 repeats = -1;
    TF_RETURN_IF_ERROR(
        GetNodeAttr(node_or_attrs, arg_def.number_attr(), &repeats));
    // We can't handle outputs that are larger than int32 sizes.
    if (static_cast<int64>(static_cast<int32>(repeats)) != repeats) {
      return errors::InvalidArgument("Number of outputs is too big: ", repeats);
    }
    if (repeats < 0) {
      return errors::InvalidArgument("Value for number_attr() ", repeats,
                                     " < 0");
    }

    if (!arg_def.type_attr().empty()) {
      DataType dtype;
      TF_RETURN_IF_ERROR(
          GetNodeAttr(node_or_attrs, arg_def.type_attr(), &dtype));
      for (int i = 0; i < repeats; ++i) {
        sig->push_back(dtype);
      }
    } else if (arg_def.type() != DT_INVALID) {
      for (int i = 0; i < repeats; ++i) {
        sig->push_back(arg_def.type());
      }
    } else {
      return errors::InvalidArgument("Missing type or type_attr field in ",
                                     arg_def.ShortDebugString());
    }
  } else if (!arg_def.type_attr().empty()) {
    // Single tensor; its type is read from the attr named by type_attr.
    const AttrValue* attr_value;
    TF_RETURN_IF_ERROR(
        AttrSlice(node_or_attrs).Find(arg_def.type_attr(), &attr_value));
    sig->push_back(attr_value->type());
  } else if (!arg_def.type_list_attr().empty()) {
    // One tensor per entry of the attr named by type_list_attr.
    const AttrValue* attr_value;
    TF_RETURN_IF_ERROR(
        AttrSlice(node_or_attrs).Find(arg_def.type_list_attr(), &attr_value));
    for (int dtype : attr_value->list().type()) {
      sig->push_back(static_cast<DataType>(dtype));
    }
  } else if (arg_def.type() != DT_INVALID) {
    sig->push_back(arg_def.type());
  } else {
    return errors::InvalidArgument("No type fields in ",
                                   arg_def.ShortDebugString());
  }
  if (arg_def.is_ref()) {
    // For all types that were added by this function call, make them refs.
    for (size_t i = original_size; i < sig->size(); ++i) {
      if (IsRefType((*sig)[i])) {
        return errors::InvalidArgument(
            "Requested reference to a reference type: ",
            arg_def.ShortDebugString());
      }
      (*sig)[i] = MakeRefType((*sig)[i]);
    }
  }
  return Status::OK();
}

}  // namespace
510
InputTypeForNode(const NodeDef & node_def,const OpDef & op_def,int input_port,DataType * input_type)511 Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
512 int input_port, DataType* input_type) {
513 DataTypeVector input_types;
514 for (const auto& arg : op_def.input_arg()) {
515 TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, &input_types));
516 int input_types_size = input_types.size();
517 if (input_types_size > input_port) {
518 const DataType dtype = input_types[input_port];
519 *input_type = dtype;
520 return Status::OK();
521 }
522 }
523 return errors::InvalidArgument("Input ", input_port, " not found for node ",
524 node_def.name());
525 }
526
InputTypesForNode(const NodeDef & node_def,const OpDef & op_def,DataTypeVector * inputs)527 Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
528 DataTypeVector* inputs) {
529 for (const auto& arg : op_def.input_arg()) {
530 TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs));
531 }
532 return Status::OK();
533 }
534
OutputTypeForNode(const NodeDef & node_def,const OpDef & op_def,int output_port,DataType * output_type)535 Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
536 int output_port, DataType* output_type) {
537 DataTypeVector output_types;
538 for (const auto& arg : op_def.output_arg()) {
539 TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, &output_types));
540 int output_types_size = output_types.size();
541 if (output_types_size > output_port) {
542 const DataType dtype = output_types[output_port];
543 *output_type = dtype;
544 return Status::OK();
545 }
546 }
547 return errors::InvalidArgument("Output ", output_port, " not found for node ",
548 node_def.name());
549 }
550
OutputTypesForNode(const NodeDef & node_def,const OpDef & op_def,DataTypeVector * outputs)551 Status OutputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
552 DataTypeVector* outputs) {
553 for (const auto& arg : op_def.output_arg()) {
554 TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, outputs));
555 }
556 return Status::OK();
557 }
558
OutputTypesForNode(const AttrSlice & attrs,const OpDef & op_def,DataTypeVector * outputs)559 Status OutputTypesForNode(const AttrSlice& attrs, const OpDef& op_def,
560 DataTypeVector* outputs) {
561 for (const auto& arg : op_def.output_arg()) {
562 TF_RETURN_IF_ERROR(AddArgToSig(attrs, arg, outputs));
563 }
564 return Status::OK();
565 }
566
InOutTypesForNode(const NodeDef & node_def,const OpDef & op_def,DataTypeVector * inputs,DataTypeVector * outputs)567 Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
568 DataTypeVector* inputs, DataTypeVector* outputs) {
569 TF_RETURN_IF_ERROR(InputTypesForNode(node_def, op_def, inputs));
570 return OutputTypesForNode(node_def, op_def, outputs);
571 }
572
NumOutputsForNode(const NodeDef & node_def,const OpDef & op_def,int * num_outputs)573 Status NumOutputsForNode(const NodeDef& node_def, const OpDef& op_def,
574 int* num_outputs) {
575 DataTypeVector outputs;
576 TF_RETURN_IF_ERROR(OutputTypesForNode(node_def, op_def, &outputs));
577 *num_outputs = outputs.size();
578 return Status::OK();
579 }
580
ValidateNodeDef(const NodeDef & node_def,const OpDef & op_def)581 Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) {
582 if (node_def.op() != op_def.name()) {
583 return errors::InvalidArgument(
584 "NodeDef op '", node_def.op(), "' does not match ",
585 SummarizeOpDef(op_def), "; NodeDef: ", FormatNodeDefForError(node_def));
586 }
587
588 bool seen_control = false;
589 size_t num_inputs = 0;
590 // TODO(josh11b): Unify the input field validation.
591 for (const string& input : node_def.input()) {
592 if (absl::StartsWith(input, "^")) {
593 seen_control = true;
594 if (input.find(':') != string::npos) {
595 return errors::InvalidArgument("Control input '", input,
596 "' must not have ':' in NodeDef: ",
597 FormatNodeDefForError(node_def));
598 }
599 } else if (seen_control) {
600 return errors::InvalidArgument("Non-control input '", input,
601 "' after control input in NodeDef: ",
602 FormatNodeDefForError(node_def));
603 } else {
604 ++num_inputs;
605 }
606 }
607
608 std::unordered_map<string, const OpDef::AttrDef*> op_attrs;
609 for (const auto& attr : op_def.attr()) {
610 if (!gtl::InsertIfNotPresent(&op_attrs, attr.name(), &attr)) {
611 return errors::InvalidArgument("OpDef has duplicate attr name '",
612 attr.name(),
613 "': ", SummarizeOpDef(op_def));
614 }
615 }
616 for (const auto& attr : node_def.attr()) {
617 // Allow internal optional attributes with names starting with "_".
618 if (absl::StartsWith(attr.first, "_")) {
619 continue;
620 }
621 auto iter = op_attrs.find(attr.first);
622 if (iter == op_attrs.end()) {
623 // A common cause of this error is that TensorFlow has made a
624 // backwards-compatible change to the NodeDef (e.g., adding a
625 // new attr with a default value), but the binary consuming the
626 // NodeDef does not know about the new attribute; the solution
627 // in these cases is to ensure that the binary consuming the
628 // NodeDef is built with a version of TensorFlow no earlier than
629 // the binary producing it.
630 return errors::InvalidArgument(
631 "NodeDef mentions attr '", attr.first, "' not in ",
632 SummarizeOpDef(op_def),
633 "; NodeDef: ", FormatNodeDefForError(node_def),
634 ". (Check whether your GraphDef-interpreting binary is up to date "
635 "with your GraphDef-generating binary.).");
636 }
637 // If attr value is placeholder, do not check it.
638 if (attr.second.placeholder().empty()) {
639 TF_RETURN_WITH_CONTEXT_IF_ERROR(
640 ValidateAttrValue(attr.second, *iter->second),
641 "; NodeDef: ", FormatNodeDefForError(node_def), "; ",
642 SummarizeOpDef(op_def));
643 }
644 // Keep track of which attr names have (not) been found in the NodeDef.
645 op_attrs.erase(iter);
646 }
647
648 // Were all attrs in the OpDef found in the NodeDef?
649 if (!op_attrs.empty()) {
650 string attrs;
651 for (const auto& attr_pair : op_attrs) {
652 if (!attrs.empty()) strings::StrAppend(&attrs, "', '");
653 strings::StrAppend(&attrs, attr_pair.first);
654 }
655 return errors::InvalidArgument(
656 "NodeDef missing attr", op_attrs.size() == 1 ? " '" : "s '", attrs,
657 "' from ", SummarizeOpDef(op_def),
658 "; NodeDef: ", FormatNodeDefForError(node_def));
659 }
660
661 // Validate the number of inputs.
662 DataTypeVector inputs, outputs;
663 TF_RETURN_IF_ERROR(InOutTypesForNode(node_def, op_def, &inputs, &outputs));
664
665 if (num_inputs != inputs.size()) {
666 return errors::InvalidArgument(
667 "NodeDef expected inputs '", DataTypeVectorString(inputs),
668 "' do not match ", num_inputs, " inputs specified; ",
669 SummarizeOpDef(op_def), "; NodeDef: ", FormatNodeDefForError(node_def));
670 }
671
672 return Status::OK();
673 }
674
675 namespace { // Helpers for NameRangesForNode()
676
ComputeArgRange(const AttrSlice & attrs,const OpDef::ArgDef & arg_def,const OpDef & op_def,int * num)677 Status ComputeArgRange(const AttrSlice& attrs, const OpDef::ArgDef& arg_def,
678 const OpDef& op_def, int* num) {
679 if (!arg_def.number_attr().empty()) {
680 // Same type repeated "num" times.
681 return GetNodeAttr(attrs, arg_def.number_attr(), num);
682 } else if (!arg_def.type_list_attr().empty()) {
683 const AttrValue* attr_value;
684 TF_RETURN_IF_ERROR(attrs.Find(arg_def.type_list_attr(), &attr_value));
685 *num = attr_value->list().type_size();
686 } else if (!arg_def.type_attr().empty() || arg_def.type() != DT_INVALID) {
687 *num = 1;
688 } else {
689 return errors::InvalidArgument(
690 "Argument '", arg_def.name(),
691 "' incorrectly specified in op definition: ", SummarizeOpDef(op_def));
692 }
693 return Status::OK();
694 }
695
NameRangesHelper(const AttrSlice & attrs,const protobuf::RepeatedPtrField<OpDef::ArgDef> & args,const OpDef & op_def,NameRangeMap * result)696 Status NameRangesHelper(const AttrSlice& attrs,
697 const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
698 const OpDef& op_def, NameRangeMap* result) {
699 int start = 0;
700 int num;
701 for (const auto& arg : args) {
702 TF_RETURN_IF_ERROR(ComputeArgRange(attrs, arg, op_def, &num));
703 (*result)[arg.name()] = std::make_pair(start, start + num);
704 start += num;
705 }
706 return Status::OK();
707 }
708
709 } // namespace
710
NameRangesForNode(const AttrSlice & attrs,const OpDef & op_def,NameRangeMap * inputs,NameRangeMap * outputs)711 Status NameRangesForNode(const AttrSlice& attrs, const OpDef& op_def,
712 NameRangeMap* inputs, NameRangeMap* outputs) {
713 if (inputs != nullptr) {
714 TF_RETURN_IF_ERROR(
715 NameRangesHelper(attrs, op_def.input_arg(), op_def, inputs));
716 }
717 if (outputs != nullptr) {
718 return NameRangesHelper(attrs, op_def.output_arg(), op_def, outputs);
719 }
720 return Status::OK();
721 }
722
AddDefaultsToNodeDef(const OpDef & op_def,NodeDef * node_def)723 void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def) {
724 for (const auto& attr_def : op_def.attr()) {
725 AttrSlice attrs(*node_def);
726 if (attr_def.has_default_value() && !attrs.Find(attr_def.name())) {
727 AddNodeAttr(attr_def.name(), attr_def.default_value(), node_def);
728 }
729 }
730 }
731
732 namespace {
733
734 using ::tensorflow::tstring;
735 using ::tensorflow::strings::Scanner;
736
IsValidNodeName(StringPiece sp)737 bool IsValidNodeName(StringPiece sp) {
738 Scanner scanner(sp);
739 scanner.One(Scanner::LETTER_DIGIT_DOT)
740 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
741
742 while (true) {
743 if (!scanner.GetResult()) // Some error in previous iteration.
744 return false;
745 if (scanner.empty()) // No error, but nothing left, good.
746 return true;
747
748 // Absorb another name/namespace, starting with a '>'
749 scanner.One(Scanner::RANGLE)
750 .One(Scanner::LETTER_DIGIT_DOT)
751 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
752 }
753 }
754
IsValidDataInputName(StringPiece sp)755 bool IsValidDataInputName(StringPiece sp) {
756 // Data inputs are op_name, op_name:0, or op_name:12345.
757 Scanner scan(sp);
758 scan.One(Scanner::LETTER_DIGIT_DOT)
759 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
760
761 while (true) {
762 if (!scan.GetResult()) // Some error in previous iteration.
763 return false;
764 if (scan.empty()) // No error, but nothing left, good.
765 return true;
766
767 if (scan.Peek() == ':') { // Absorb identifier after the colon
768 scan.OneLiteral(":");
769 if (scan.Peek() == '0') {
770 scan.OneLiteral("0"); // :0
771 } else {
772 scan.Many(Scanner::DIGIT); // :[1-9][0-9]*
773 }
774 } else {
775 // Absorb another name/namespace, starting with a '>'
776 scan.One(Scanner::RANGLE)
777 .One(Scanner::LETTER_DIGIT_DOT)
778 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
779 }
780 }
781 }
782
IsValidControlInputName(StringPiece sp)783 bool IsValidControlInputName(StringPiece sp) {
784 Scanner scan(sp);
785 scan.OneLiteral("^")
786 .One(Scanner::LETTER_DIGIT_DOT)
787 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
788
789 while (true) {
790 if (!scan.GetResult()) // Some error in previous iteration.
791 return false;
792 if (scan.empty()) // No error, but nothing left, good.
793 return true;
794
795 // Absorb another name/namespace, starting with a '>'
796 scan.One(Scanner::RANGLE)
797 .One(Scanner::LETTER_DIGIT_DOT)
798 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
799 }
800 }
801
802 const StringPiece kColocationGroupPrefixStringPiece(kColocationGroupPrefix);
803
804 } // namespace
805
ValidateOpInput(const string & input_name,bool * is_control_input)806 Status ValidateOpInput(const string& input_name, bool* is_control_input) {
807 *is_control_input = false;
808 if (IsValidDataInputName(input_name)) {
809 return Status::OK();
810 } else if (IsValidControlInputName(input_name)) {
811 *is_control_input = true;
812 return Status::OK();
813 } else {
814 return errors::InvalidArgument("Illegal op input name '", input_name, "'");
815 }
816 }
817
ValidateNodeName(const string & node_name)818 Status ValidateNodeName(const string& node_name) {
819 if (IsValidNodeName(node_name)) {
820 return Status::OK();
821 } else {
822 return errors::InvalidArgument("Illegal op name '", node_name, "'");
823 }
824 }
825
ValidateExternalNodeDefSyntax(const NodeDef & node_def)826 Status ValidateExternalNodeDefSyntax(const NodeDef& node_def) {
827 Status s = ValidateNodeName(node_def.name());
828 if (!s.ok()) {
829 return AttachDef(s, node_def);
830 }
831 bool in_control_inputs = false;
832 for (const string& input_name : node_def.input()) {
833 bool is_control_input;
834 s = ValidateOpInput(input_name, &is_control_input);
835 if (!s.ok()) {
836 return AttachDef(s, node_def);
837 }
838
839 if (in_control_inputs && !is_control_input) {
840 return AttachDef(errors::InvalidArgument(
841 "All control inputs must follow all data inputs"),
842 node_def);
843 }
844 in_control_inputs = is_control_input;
845 }
846 return Status::OK();
847 }
848
// Returns a copy of `status` whose message has " [[<node>]]" appended, where
// <node> describes `node_def`.
//
// If the message already contains a formatted node ("{{node ") and
// `allow_multiple_formatted_node` is false, only the node's name is appended.
// This avoids stacking several fully formatted node descriptions. Otherwise
// FormatNodeDefForError(node_def) is appended.
//
// An OK status carries no error message to annotate and is returned as-is.
Status AttachDef(const Status& status, const NodeDef& node_def,
                 bool allow_multiple_formatted_node) {
  // AppendToMessage rebuilds the status from its code plus a new message.
  // Doing that for code OK would yield an "OK with message" status, which
  // Status is not designed to represent, so leave OK statuses untouched.
  if (status.ok()) {
    return status;
  }

  const bool use_name_only =
      !allow_multiple_formatted_node &&
      status.error_message().find("{{node ") != string::npos;
  const string node_error =
      use_name_only ? node_def.name() : FormatNodeDefForError(node_def);

  Status ret = status;
  errors::AppendToMessage(&ret, strings::StrCat(" [[", node_error, "]]"));
  return ret;
}
862
AddNodeAttr(StringPiece name,const AttrValue & value,NodeDef * node_def)863 void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def) {
864 node_def->mutable_attr()->insert(
865 AttrValueMap::value_type(string(name), value));
866 }
867
AddNodeAttr(StringPiece name,AttrValue && value,NodeDef * node_def)868 void AddNodeAttr(StringPiece name, AttrValue&& value, NodeDef* node_def) {
869 (*node_def->mutable_attr())[string(name)] = std::move(value);
870 }
871
872 #define ADD_NODE_ATTR(T) \
873 void AddNodeAttr(StringPiece name, T value, NodeDef* node_def) { \
874 AttrValue attr_value; \
875 SetAttrValue(value, &attr_value); \
876 AddNodeAttr(name, attr_value, node_def); \
877 }
878 ADD_NODE_ATTR(StringPiece)
ADD_NODE_ATTR(const char *)879 ADD_NODE_ATTR(const char*)
880 ADD_NODE_ATTR(int32)
881 ADD_NODE_ATTR(int64)
882 ADD_NODE_ATTR(float)
883 ADD_NODE_ATTR(double)
884 ADD_NODE_ATTR(bool)
885 ADD_NODE_ATTR(DataType)
886 ADD_NODE_ATTR(const PartialTensorShape&)
887 ADD_NODE_ATTR(const Tensor&)
888 ADD_NODE_ATTR(const TensorProto&)
889 ADD_NODE_ATTR(const NameAttrList&)
890 ADD_NODE_ATTR(gtl::ArraySlice<StringPiece>)
891 ADD_NODE_ATTR(gtl::ArraySlice<const char*>)
892 ADD_NODE_ATTR(gtl::ArraySlice<string>)
893 ADD_NODE_ATTR(gtl::ArraySlice<int32>)
894 ADD_NODE_ATTR(gtl::ArraySlice<int64>)
895 ADD_NODE_ATTR(gtl::ArraySlice<float>)
896 ADD_NODE_ATTR(gtl::ArraySlice<bool>)
897 ADD_NODE_ATTR(const std::vector<bool>&)
898 ADD_NODE_ATTR(gtl::ArraySlice<DataType>)
899 ADD_NODE_ATTR(gtl::ArraySlice<TensorShape>)
900 ADD_NODE_ATTR(gtl::ArraySlice<PartialTensorShape>)
901 ADD_NODE_ATTR(gtl::ArraySlice<TensorShapeProto>)
902 ADD_NODE_ATTR(gtl::ArraySlice<Tensor>)
903 ADD_NODE_ATTR(gtl::ArraySlice<NameAttrList>)
904 #undef ADD_NODE_ATTR
905
906 void AddAttr(StringPiece name, const AttrValue& value, AttrValueMap* map) {
907 map->insert(AttrValueMap::value_type(string(name), value));
908 }
909
910 #define ADD_ATTR(T) \
911 void AddAttr(StringPiece name, T value, AttrValueMap* map) { \
912 AttrValue attr_value; \
913 SetAttrValue(value, &attr_value); \
914 AddAttr(name, attr_value, map); \
915 }
ADD_ATTR(bool)916 ADD_ATTR(bool)
917 #undef ADD_ATTR
918
919 Status AddPrefixAndSuffixToNode(StringPiece prefix, StringPiece suffix,
920 NodeDef* node_def, bool uniquify_frame_name) {
921 node_def->set_name(strings::StrCat(prefix, node_def->name(), suffix));
922
923 // Update frame name to avoid multiple LoopCond nodes in one frame.
924 if (uniquify_frame_name &&
925 (node_def->op() == "Enter" || node_def->op() == "RefEnter")) {
926 string frame_name;
927 TF_RETURN_IF_ERROR(GetNodeAttr(*node_def, "frame_name", &frame_name));
928 AttrValue& attr = (*node_def->mutable_attr())["frame_name"];
929 frame_name = strings::StrCat(prefix, frame_name, suffix);
930 attr.set_s(frame_name);
931 }
932
933 return Status::OK();
934 }
935
MaybeAddPrefixToColocationConstraints(const std::unordered_set<string> & match,StringPiece prefix,NodeDef * node_def)936 Status MaybeAddPrefixToColocationConstraints(
937 const std::unordered_set<string>& match, StringPiece prefix,
938 NodeDef* node_def) {
939 auto attr = node_def->mutable_attr()->find(kColocationAttrName);
940 if (attr == node_def->mutable_attr()->end()) {
941 return Status::OK();
942 }
943 auto constraints_list = attr->second.mutable_list();
944 auto constraints_size = constraints_list->s_size();
945 for (size_t i = 0; i < constraints_size; ++i) {
946 StringPiece original(constraints_list->s(i));
947 if (absl::ConsumePrefix(&original, kColocationGroupPrefixStringPiece)) {
948 if (match.find(string(original)) != match.end()) {
949 (*constraints_list->mutable_s(i)) =
950 strings::StrCat(kColocationGroupPrefix, prefix, original);
951 }
952 }
953 }
954 return Status::OK();
955 }
956
957 } // namespace tensorflow
958