1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/node_def_util.h"
17
18 #include <algorithm>
19 #include <unordered_map>
20 #include <vector>
21
22 #include "tensorflow/core/framework/attr_value_util.h"
23 #include "tensorflow/core/framework/graph.pb_text.h"
24 #include "tensorflow/core/framework/node_def.pb.h"
25 #include "tensorflow/core/framework/op.h"
26 #include "tensorflow/core/framework/op_def.pb_text.h"
27 #include "tensorflow/core/framework/op_def_util.h"
28 #include "tensorflow/core/framework/tensor.pb_text.h"
29 #include "tensorflow/core/framework/tensor_shape.pb.h"
30 #include "tensorflow/core/graph/graph.h"
31 #include "tensorflow/core/lib/core/errors.h"
32 #include "tensorflow/core/lib/gtl/map_util.h"
33 #include "tensorflow/core/lib/strings/scanner.h"
34 #include "tensorflow/core/lib/strings/str_util.h"
35 #include "tensorflow/core/lib/strings/strcat.h"
36 #include "tensorflow/core/platform/protobuf.h"
37
38 namespace tensorflow {
39
// String constants "_class" and "loc:@". Judging by their names they are the
// attr name and value prefix used to express colocation groups -- NOTE(review):
// confirm the exact usage against the callers; nothing in this file reads them.
const char* const kColocationAttrName = "_class";
const char* const kColocationGroupPrefix = "loc:@";
42
AttrSlice()43 AttrSlice::AttrSlice() : ndef_(nullptr) {
44 static const AttrValueMap* const kEmptyAttrValueMap = new AttrValueMap;
45 attrs_ = kEmptyAttrValueMap;
46 }
47
AttrSlice(const NodeDef & node_def)48 AttrSlice::AttrSlice(const NodeDef& node_def)
49 : ndef_(&node_def), attrs_(&ndef_->attr()) {}
50
AttrSlice(const AttrValueMap * a)51 AttrSlice::AttrSlice(const AttrValueMap* a) : ndef_(nullptr), attrs_(a) {}
52
SummarizeAttrsHelper(AttrSlice attrs,StringPiece device)53 static string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device) {
54 string ret;
55
56 // We sort the attrs so the output is deterministic.
57 std::vector<string> attr_names;
58 attr_names.reserve(attrs.size());
59 for (const auto& attr : attrs) {
60 attr_names.push_back(attr.first);
61 }
62 std::sort(attr_names.begin(), attr_names.end());
63 bool first = true;
64 for (const string& attr_name : attr_names) {
65 if (!first) strings::StrAppend(&ret, ", ");
66 first = false;
67 strings::StrAppend(&ret, attr_name, "=",
68 SummarizeAttrValue(*attrs.Find(attr_name)));
69 }
70
71 // Consider the device to be a final attr with name "_device".
72 if (!device.empty()) {
73 if (!first) strings::StrAppend(&ret, ", ");
74 first = false;
75 strings::StrAppend(&ret, "_device=\"", device, "\"");
76 }
77 return ret;
78 }
79
SummarizeNode() const80 string AttrSlice::SummarizeNode() const {
81 return ndef_ ? SummarizeNodeDef(*ndef_)
82 : strings::StrCat(
83 "[", SummarizeAttrsHelper(*this, StringPiece()), "]");
84 }
85
SummarizeNode(const Node & node)86 string SummarizeNode(const Node& node) { return SummarizeNodeDef(node.def()); }
87
SummarizeNodeDef(const NodeDef & node_def)88 string SummarizeNodeDef(const NodeDef& node_def) {
89 string ret = strings::StrCat(errors::FormatNodeNameForError(node_def.name()),
90 " = ", node_def.op(), "[");
91 strings::StrAppend(&ret, SummarizeAttrsHelper(node_def, node_def.device()));
92 strings::StrAppend(&ret, "](");
93
94 // Output inputs, including control inputs, verbatim.
95 bool first = true;
96 for (const string& input : node_def.input()) {
97 if (!first) strings::StrAppend(&ret, ", ");
98 first = false;
99 strings::StrAppend(&ret, input);
100 }
101 strings::StrAppend(&ret, ")");
102 return ret;
103 }
104
SummarizeAttrs(const NodeDef & node_def)105 string SummarizeAttrs(const NodeDef& node_def) {
106 return SummarizeAttrsHelper(node_def, node_def.device());
107 }
108
FormatNodeForError(const NodeDebugInfo & debug_info)109 string FormatNodeForError(const NodeDebugInfo& debug_info) {
110 return debug_info.original_node_names.empty()
111 ? errors::FormatNodeNameForError(debug_info.name)
112 : errors::FormatNodeNamesForError(debug_info.original_node_names);
113 }
114
FormatNodeForError(const Node & node)115 string FormatNodeForError(const Node& node) {
116 return FormatNodeForError(NodeDebugInfo(node));
117 }
118
FormatNodeDefForError(const NodeDef & node_def)119 string FormatNodeDefForError(const NodeDef& node_def) {
120 return FormatNodeForError(NodeDebugInfo(node_def));
121 }
122
GetMergedOriginalNodeNames(const NodeDebugInfo & from,const NodeDebugInfo & to,std::set<string> * names)123 void GetMergedOriginalNodeNames(const NodeDebugInfo& from,
124 const NodeDebugInfo& to,
125 std::set<string>* names) {
126 if (!from.original_node_names.empty()) {
127 names->insert(from.original_node_names.begin(),
128 from.original_node_names.end());
129 } else {
130 names->insert(from.name);
131 }
132 names->insert(to.original_node_names.begin(), to.original_node_names.end());
133 }
134
MergeDebugInfo(const NodeDebugInfo & from,Node * to)135 void MergeDebugInfo(const NodeDebugInfo& from, Node* to) {
136 std::set<string> names;
137 GetMergedOriginalNodeNames(from, NodeDebugInfo(*to), &names);
138 to->set_original_node_names({names.begin(), names.end()});
139 }
140
MergeDebugInfo(const NodeDebugInfo & from,NodeDef * to)141 void MergeDebugInfo(const NodeDebugInfo& from, NodeDef* to) {
142 std::set<string> names;
143 GetMergedOriginalNodeNames(from, NodeDebugInfo(*to), &names);
144 to->mutable_experimental_debug_info()->clear_original_node_names();
145 if (!names.empty()) {
146 *to->mutable_experimental_debug_info()->mutable_original_node_names() = {
147 names.begin(), names.end()};
148 }
149 }
150
MergeDebugInfo(const NodeDef & from,NodeDef * to)151 void MergeDebugInfo(const NodeDef& from, NodeDef* to) {
152 MergeDebugInfo(NodeDebugInfo(from), to);
153 }
154
Find(StringPiece attr_name) const155 const AttrValue* AttrSlice::Find(StringPiece attr_name) const {
156 // Currently, the collection used for NodeDef::attr() (google::protobuf::Map)
157 // requires that the keys used for lookups have type 'const string&'. Because
158 // this method takes a StringPiece, it is necessary to allocate a temporary
159 // string, copy attr_name to it, and then use that temporary string for the
160 // lookup. This causes an excessive number of short-lived allocations, and for
161 // large graphs, this can be a significant cost.
162 //
163 // Because most nodes have a small number of attributes, a simple linear scan
164 // is generally more efficient than a hashed lookup. If google::protobuf::Map
165 // changes so that it supports efficient lookups using StringPiece instead of
166 // const string&, then this code could be changed to use attrs_->find() again.
167
168 for (const auto& attr : *attrs_) {
169 if (attr.first == attr_name) {
170 return &attr.second;
171 }
172 }
173 return nullptr;
174 }
175
Find(StringPiece attr_name,const AttrValue ** attr_value) const176 Status AttrSlice::Find(StringPiece attr_name,
177 const AttrValue** attr_value) const {
178 *attr_value = Find(attr_name);
179 if (*attr_value != nullptr) {
180 return Status::OK();
181 }
182 Status s = errors::NotFound("No attr named '", attr_name, "' in NodeDef:");
183 // Skip AttachDef for internal attrs since it is a little bit
184 // expensive and it is common for them to correctly not be included
185 // in a NodeDef.
186 if (!str_util::StartsWith(attr_name, "_") && ndef_ != nullptr) {
187 s = AttachDef(s, *ndef_);
188 }
189 return s;
190 }
191
EqualAttrs(AttrSlice other,Scratch * scratch) const192 bool AttrSlice::EqualAttrs(AttrSlice other, Scratch* scratch) const {
193 if (size() != other.size()) return false;
194
195 for (const auto& attr : *other.attrs_) {
196 auto iter = attrs_->find(attr.first);
197 if (iter == attrs_->end()) return false;
198 // TODO(irving): Comparing AttrValues by proto is slightly buggy, since
199 // TensorProto is a nonunique representation of Tensor. This bug will go
200 // away once AttrSlice switches over to NodeInfo.
201 iter->second.SerializeToString(&scratch->a);
202 attr.second.SerializeToString(&scratch->b);
203 if (scratch->a != scratch->b) return false;
204 }
205 return true;
206 }
207
208 // The ... is to allow the caller to inject some value validation code. Use
209 // just ; if no additional validation code is needed.
210 #define DEFINE_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...) \
211 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, \
212 TYPE* value) { \
213 const AttrValue* attr_value; \
214 TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); \
215 TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, ATTR_TYPE)); \
216 const auto& v = attr_value->FIELD(); \
217 __VA_ARGS__; \
218 *value = CAST; \
219 return Status::OK(); \
220 } \
221 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, \
222 std::vector<TYPE>* value) { \
223 const AttrValue* attr_value; \
224 TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value)); \
225 TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(" ATTR_TYPE ")")); \
226 for (const auto& v : attr_value->list().FIELD()) { \
227 __VA_ARGS__; \
228 value->APPEND_OP(CAST); \
229 } \
230 return Status::OK(); \
231 }
232
233 #define DEFINE_GET_ATTR_SIMPLE(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...) \
234 bool GetNodeAttrSimple(const AttrSlice& attrs, StringPiece attr_name, \
235 TYPE* value) { \
236 const AttrValue* attr_value = attrs.Find(attr_name); \
237 if (attr_value == nullptr) { \
238 return false; \
239 } \
240 Status s = AttrValueHasType(*attr_value, ATTR_TYPE); \
241 if (!s.ok()) { \
242 return false; \
243 } \
244 const auto& v = attr_value->FIELD(); \
245 __VA_ARGS__; \
246 *value = CAST; \
247 return true; \
248 } \
249 bool GetNodeAttrSimple(const AttrSlice& attrs, StringPiece attr_name, \
250 std::vector<TYPE>* value) { \
251 const AttrValue* attr_value = attrs.Find(attr_name); \
252 if (attr_value == nullptr) { \
253 return false; \
254 } \
255 Status s = AttrValueHasType(*attr_value, "list(" ATTR_TYPE ")"); \
256 if (!s.ok()) { \
257 return false; \
258 } \
259 for (const auto& v : attr_value->list().FIELD()) { \
260 __VA_ARGS__; \
261 value->APPEND_OP(CAST); \
262 } \
263 return true; \
264 }
265
266 DEFINE_GET_ATTR(string, s, "string", emplace_back, v, ;)
267 DEFINE_GET_ATTR_SIMPLE(string, s, "string", emplace_back, v, ;)
268 DEFINE_GET_ATTR(int64, i, "int", emplace_back, v, ;)
269 DEFINE_GET_ATTR(int32, i, "int", emplace_back, static_cast<int32>(v),
270 if (static_cast<int64>(static_cast<int32>(v)) != v) {
271 return errors::InvalidArgument("Attr ", attr_name,
272 " has value ", v,
273 " out of range for an int32");
274 })
275 DEFINE_GET_ATTR(float, f, "float", emplace_back, v, ;)
276 // std::vector<bool> specialization does not have emplace_back until
277 // c++14, so we have to use push_back (see
278 // http://en.cppreference.com/w/cpp/container/vector/emplace_back)
279 DEFINE_GET_ATTR(bool, b, "bool", push_back, v, ;)
280 DEFINE_GET_ATTR(DataType, type, "type", emplace_back, static_cast<DataType>(v),
281 ;)
282 DEFINE_GET_ATTR(TensorShapeProto, shape, "shape", emplace_back, v, ;)
283 DEFINE_GET_ATTR(TensorShape, shape, "shape", emplace_back, TensorShape(v),
284 TF_RETURN_IF_ERROR(TensorShape::IsValidShape(v));)
285 DEFINE_GET_ATTR(PartialTensorShape, shape, "shape", emplace_back,
286 PartialTensorShape(v),
287 TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(v));)
288 DEFINE_GET_ATTR(Tensor, tensor, "tensor", emplace_back, t, Tensor t;
289 if (!t.FromProto(v)) {
290 return errors::InvalidArgument(
291 "Attr ", attr_name, " has value ",
292 ProtoShortDebugString(v),
293 " that can't be converted to a Tensor");
294 })
295 DEFINE_GET_ATTR(NameAttrList, func, "func", emplace_back, v, ;);
296 #undef DEFINE_GET_ATTR
297
HasNodeAttr(const NodeDef & node_def,StringPiece attr_name)298 bool HasNodeAttr(const NodeDef& node_def, StringPiece attr_name) {
299 return node_def.attr().find(string(attr_name)) != node_def.attr().end();
300 }
301
302 static const string& kEmptyString = *new string();
303
GetNodeAttrString(const AttrSlice & attrs,StringPiece attr_name)304 const string& GetNodeAttrString(const AttrSlice& attrs, StringPiece attr_name) {
305 const AttrValue* attr_value = attrs.Find(attr_name);
306 if (attr_value == nullptr) {
307 return kEmptyString;
308 }
309 Status s = AttrValueHasType(*attr_value, "string");
310 if (!s.ok()) {
311 return kEmptyString;
312 }
313 return attr_value->s();
314 }
315
GetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,DataTypeVector * value)316 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
317 DataTypeVector* value) {
318 const AttrValue* attr_value;
319 TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
320 TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(type)"));
321 for (const auto& v : attr_value->list().type()) {
322 value->push_back(static_cast<DataType>(v));
323 }
324 return Status::OK();
325 }
326
GetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,const TensorProto ** value)327 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
328 const TensorProto** value) {
329 const AttrValue* attr_value;
330 TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
331 TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "tensor"));
332 *value = &attr_value->tensor();
333 return Status::OK();
334 }
335
GetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,const NameAttrList ** value)336 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
337 const NameAttrList** value) {
338 const AttrValue* attr_value;
339 TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
340 TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "func"));
341 *value = &attr_value->func();
342 return Status::OK();
343 }
344
345 namespace { // Helper for InOutTypesForNode().
346
AddArgToSig(const NodeDef & node_def,const OpDef::ArgDef & arg_def,DataTypeVector * sig)347 Status AddArgToSig(const NodeDef& node_def, const OpDef::ArgDef& arg_def,
348 DataTypeVector* sig) {
349 const int original_size = sig->size();
350 if (!arg_def.number_attr().empty()) {
351 // Same type repeated "repeats" times.
352 int32 repeats = -1;
353 TF_RETURN_IF_ERROR(GetNodeAttr(node_def, arg_def.number_attr(), &repeats));
354 if (repeats < 0) {
355 return errors::InvalidArgument("Value for number_attr() ", repeats,
356 " < 0");
357 }
358
359 if (!arg_def.type_attr().empty()) {
360 DataType dtype;
361 TF_RETURN_IF_ERROR(GetNodeAttr(node_def, arg_def.type_attr(), &dtype));
362 for (int i = 0; i < repeats; ++i) {
363 sig->push_back(dtype);
364 }
365 } else if (arg_def.type() != DT_INVALID) {
366 for (int i = 0; i < repeats; ++i) {
367 sig->push_back(arg_def.type());
368 }
369 } else {
370 return errors::InvalidArgument("Missing type or type_attr field in ",
371 ProtoShortDebugString(arg_def));
372 }
373 } else if (!arg_def.type_attr().empty()) {
374 const AttrValue* attr_value;
375 TF_RETURN_IF_ERROR(
376 AttrSlice(node_def).Find(arg_def.type_attr(), &attr_value));
377 sig->push_back(attr_value->type());
378 } else if (!arg_def.type_list_attr().empty()) {
379 const AttrValue* attr_value;
380 TF_RETURN_IF_ERROR(
381 AttrSlice(node_def).Find(arg_def.type_list_attr(), &attr_value));
382 for (int dtype : attr_value->list().type()) {
383 sig->push_back(static_cast<DataType>(dtype));
384 }
385 } else if (arg_def.type() != DT_INVALID) {
386 sig->push_back(arg_def.type());
387 } else {
388 return errors::InvalidArgument("No type fields in ",
389 ProtoShortDebugString(arg_def));
390 }
391 if (arg_def.is_ref()) {
392 // For all types that were added by this function call, make them refs.
393 for (size_t i = original_size; i < sig->size(); ++i) {
394 (*sig)[i] = MakeRefType((*sig)[i]);
395 }
396 }
397 return Status::OK();
398 }
399
400 } // namespace
401
InputTypeForNode(const NodeDef & node_def,const OpDef & op_def,int input_port,DataType * input_type)402 Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
403 int input_port, DataType* input_type) {
404 DataTypeVector input_types;
405 for (const auto& arg : op_def.input_arg()) {
406 TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, &input_types));
407 if (input_types.size() > input_port) {
408 const DataType dtype = input_types[input_port];
409 *input_type = dtype;
410 return Status::OK();
411 }
412 }
413 return errors::InvalidArgument("Input ", input_port, " not found for node ",
414 node_def.name());
415 }
416
InputTypesForNode(const NodeDef & node_def,const OpDef & op_def,DataTypeVector * inputs)417 Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
418 DataTypeVector* inputs) {
419 for (const auto& arg : op_def.input_arg()) {
420 TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs));
421 }
422 return Status::OK();
423 }
424
OutputTypeForNode(const NodeDef & node_def,const OpDef & op_def,int output_port,DataType * output_type)425 Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
426 int output_port, DataType* output_type) {
427 DataTypeVector output_types;
428 for (const auto& arg : op_def.output_arg()) {
429 TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, &output_types));
430 if (output_types.size() > output_port) {
431 const DataType dtype = output_types[output_port];
432 *output_type = dtype;
433 return Status::OK();
434 }
435 }
436 return errors::InvalidArgument("Output ", output_port, " not found for node ",
437 node_def.name());
438 }
439
OutputTypesForNode(const NodeDef & node_def,const OpDef & op_def,DataTypeVector * outputs)440 Status OutputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
441 DataTypeVector* outputs) {
442 for (const auto& arg : op_def.output_arg()) {
443 TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, outputs));
444 }
445 return Status::OK();
446 }
447
InOutTypesForNode(const NodeDef & node_def,const OpDef & op_def,DataTypeVector * inputs,DataTypeVector * outputs)448 Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
449 DataTypeVector* inputs, DataTypeVector* outputs) {
450 TF_RETURN_IF_ERROR(InputTypesForNode(node_def, op_def, inputs));
451 return OutputTypesForNode(node_def, op_def, outputs);
452 }
453
NumOutputsForNode(const NodeDef & node_def,const OpDef & op_def,int * num_outputs)454 Status NumOutputsForNode(const NodeDef& node_def, const OpDef& op_def,
455 int* num_outputs) {
456 DataTypeVector outputs;
457 TF_RETURN_IF_ERROR(OutputTypesForNode(node_def, op_def, &outputs));
458 *num_outputs = outputs.size();
459 return Status::OK();
460 }
461
ValidateNodeDef(const NodeDef & node_def,const OpDef & op_def)462 Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) {
463 if (node_def.op() != op_def.name()) {
464 return errors::InvalidArgument(
465 "NodeDef op '", node_def.op(), "' does not match ",
466 SummarizeOpDef(op_def), "; NodeDef: ", FormatNodeDefForError(node_def));
467 }
468
469 bool seen_control = false;
470 size_t num_inputs = 0;
471 // TODO(josh11b): Unify the input field validation.
472 for (const string& input : node_def.input()) {
473 if (str_util::StartsWith(input, "^")) {
474 seen_control = true;
475 if (input.find(':') != string::npos) {
476 return errors::InvalidArgument("Control input '", input,
477 "' must not have ':' in NodeDef: ",
478 FormatNodeDefForError(node_def));
479 }
480 } else if (seen_control) {
481 return errors::InvalidArgument("Non-control input '", input,
482 "' after control input in NodeDef: ",
483 FormatNodeDefForError(node_def));
484 } else {
485 ++num_inputs;
486 }
487 }
488
489 std::unordered_map<string, const OpDef::AttrDef*> op_attrs;
490 for (const auto& attr : op_def.attr()) {
491 if (!gtl::InsertIfNotPresent(&op_attrs, attr.name(), &attr)) {
492 return errors::InvalidArgument("OpDef has duplicate attr name '",
493 attr.name(),
494 "': ", SummarizeOpDef(op_def));
495 }
496 }
497 for (const auto& attr : node_def.attr()) {
498 // Allow internal optional attributes with names starting with "_".
499 if (str_util::StartsWith(attr.first, "_")) {
500 continue;
501 }
502 auto iter = op_attrs.find(attr.first);
503 if (iter == op_attrs.end()) {
504 // A common cause of this error is that TensorFlow has made a
505 // backwards-compatible change to the NodeDef (e.g., adding a
506 // new attr with a default value), but the binary consuming the
507 // NodeDef does not know about the new attribute; the solution
508 // in these cases is to ensure that the binary consuming the
509 // NodeDef is built with a version of TensorFlow no earlier than
510 // the binary producing it.
511 return errors::InvalidArgument(
512 "NodeDef mentions attr '", attr.first, "' not in ",
513 SummarizeOpDef(op_def),
514 "; NodeDef: ", FormatNodeDefForError(node_def),
515 ". (Check whether your GraphDef-interpreting binary is up to date "
516 "with your GraphDef-generating binary.).");
517 }
518 // If attr value is placeholder, do not check it.
519 if (attr.second.placeholder().empty()) {
520 TF_RETURN_WITH_CONTEXT_IF_ERROR(
521 ValidateAttrValue(attr.second, *iter->second),
522 "; NodeDef: ", FormatNodeDefForError(node_def), "; ",
523 SummarizeOpDef(op_def));
524 }
525 // Keep track of which attr names have (not) been found in the NodeDef.
526 op_attrs.erase(iter);
527 }
528
529 // Were all attrs in the OpDef found in the NodeDef?
530 if (!op_attrs.empty()) {
531 string attrs;
532 for (const auto& attr_pair : op_attrs) {
533 if (!attrs.empty()) strings::StrAppend(&attrs, "', '");
534 strings::StrAppend(&attrs, attr_pair.first);
535 }
536 return errors::InvalidArgument(
537 "NodeDef missing attr", op_attrs.size() == 1 ? " '" : "s '", attrs,
538 "' from ", SummarizeOpDef(op_def),
539 "; NodeDef: ", FormatNodeDefForError(node_def));
540 }
541
542 // Validate the number of inputs.
543 DataTypeVector inputs, outputs;
544 TF_RETURN_IF_ERROR(InOutTypesForNode(node_def, op_def, &inputs, &outputs));
545
546 if (num_inputs != inputs.size()) {
547 return errors::InvalidArgument(
548 "NodeDef expected inputs '", DataTypeVectorString(inputs),
549 "' do not match ", num_inputs, " inputs specified; ",
550 SummarizeOpDef(op_def), "; NodeDef: ", FormatNodeDefForError(node_def));
551 }
552
553 return Status::OK();
554 }
555
556 namespace { // Helpers for NameRangesForNode()
557
ComputeArgRange(const NodeDef & node_def,const OpDef::ArgDef & arg_def,const OpDef & op_def,int * num)558 Status ComputeArgRange(const NodeDef& node_def, const OpDef::ArgDef& arg_def,
559 const OpDef& op_def, int* num) {
560 if (!arg_def.number_attr().empty()) {
561 // Same type repeated "num" times.
562 return GetNodeAttr(node_def, arg_def.number_attr(), num);
563 } else if (!arg_def.type_list_attr().empty()) {
564 const AttrValue* attr_value;
565 TF_RETURN_IF_ERROR(
566 AttrSlice(node_def).Find(arg_def.type_list_attr(), &attr_value));
567 *num = attr_value->list().type_size();
568 } else if (!arg_def.type_attr().empty() || arg_def.type() != DT_INVALID) {
569 *num = 1;
570 } else {
571 return errors::InvalidArgument(
572 "Argument '", arg_def.name(),
573 "' incorrectly specified in op definition: ", SummarizeOpDef(op_def));
574 }
575 return Status::OK();
576 }
577
NameRangesHelper(const NodeDef & node_def,const protobuf::RepeatedPtrField<OpDef::ArgDef> & args,const OpDef & op_def,NameRangeMap * result)578 Status NameRangesHelper(const NodeDef& node_def,
579 const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
580 const OpDef& op_def, NameRangeMap* result) {
581 int start = 0;
582 int num;
583 for (const auto& arg : args) {
584 TF_RETURN_IF_ERROR(ComputeArgRange(node_def, arg, op_def, &num));
585 (*result)[arg.name()] = std::make_pair(start, start + num);
586 start += num;
587 }
588 return Status::OK();
589 }
590
591 } // namespace
592
NameRangesForNode(const NodeDef & node_def,const OpDef & op_def,NameRangeMap * inputs,NameRangeMap * outputs)593 Status NameRangesForNode(const NodeDef& node_def, const OpDef& op_def,
594 NameRangeMap* inputs, NameRangeMap* outputs) {
595 if (inputs != nullptr) {
596 TF_RETURN_IF_ERROR(
597 NameRangesHelper(node_def, op_def.input_arg(), op_def, inputs));
598 }
599 if (outputs != nullptr) {
600 return NameRangesHelper(node_def, op_def.output_arg(), op_def, outputs);
601 }
602 return Status::OK();
603 }
604
NameRangesForNode(const Node & node,const OpDef & op_def,NameRangeMap * inputs,NameRangeMap * outputs)605 Status NameRangesForNode(const Node& node, const OpDef& op_def,
606 NameRangeMap* inputs, NameRangeMap* outputs) {
607 return NameRangesForNode(node.def(), op_def, inputs, outputs);
608 }
609
AddDefaultsToNodeDef(const OpDef & op_def,NodeDef * node_def)610 void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def) {
611 for (const auto& attr_def : op_def.attr()) {
612 AttrSlice attrs(*node_def);
613 if (attr_def.has_default_value() && !attrs.Find(attr_def.name())) {
614 AddNodeAttr(attr_def.name(), attr_def.default_value(), node_def);
615 }
616 }
617 }
618
619 namespace {
620
621 using ::tensorflow::strings::Scanner;
622
IsValidOpName(StringPiece sp)623 bool IsValidOpName(StringPiece sp) {
624 return Scanner(sp)
625 .One(Scanner::LETTER_DIGIT_DOT)
626 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
627 .Eos()
628 .GetResult();
629 }
630
IsValidDataInputName(StringPiece sp)631 bool IsValidDataInputName(StringPiece sp) {
632 // Data inputs are op_name, op_name:0, or op_name:12345.
633 Scanner scan(sp);
634 scan.One(Scanner::LETTER_DIGIT_DOT)
635 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
636 if (scan.Peek() == ':') {
637 scan.OneLiteral(":");
638 if (scan.Peek() == '0') {
639 scan.OneLiteral("0"); // :0
640 } else {
641 scan.Many(Scanner::DIGIT); // :[1-9][0-9]*
642 }
643 }
644 scan.Eos();
645
646 return scan.GetResult();
647 }
648
IsValidControlInputName(StringPiece sp)649 bool IsValidControlInputName(StringPiece sp) {
650 return Scanner(sp)
651 .OneLiteral("^")
652 .One(Scanner::LETTER_DIGIT_DOT)
653 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
654 .Eos()
655 .GetResult();
656 }
657
658 } // namespace
659
ValidateOpInput(const string & input_name,bool * is_control_input)660 Status ValidateOpInput(const string& input_name, bool* is_control_input) {
661 *is_control_input = false;
662 if (IsValidDataInputName(input_name)) {
663 return Status::OK();
664 } else if (IsValidControlInputName(input_name)) {
665 *is_control_input = true;
666 return Status::OK();
667 } else {
668 return errors::InvalidArgument("Illegal op input name '", input_name, "'");
669 }
670 }
671
ValidateOpName(const string & op_name)672 Status ValidateOpName(const string& op_name) {
673 if (IsValidOpName(op_name)) {
674 return Status::OK();
675 } else {
676 return errors::InvalidArgument("Illegal op name '", op_name, "'");
677 }
678 }
679
ValidateExternalNodeDefSyntax(const NodeDef & node_def)680 Status ValidateExternalNodeDefSyntax(const NodeDef& node_def) {
681 Status s = ValidateOpName(node_def.name());
682 if (!s.ok()) {
683 return AttachDef(s, node_def);
684 }
685 bool in_control_inputs = false;
686 for (const string& input_name : node_def.input()) {
687 bool is_control_input;
688 s = ValidateOpInput(input_name, &is_control_input);
689 if (!s.ok()) {
690 return AttachDef(s, node_def);
691 }
692
693 if (in_control_inputs && !is_control_input) {
694 return AttachDef(errors::InvalidArgument(
695 "All control inputs must follow all data inputs"),
696 node_def);
697 }
698 in_control_inputs = is_control_input;
699 }
700 return Status::OK();
701 }
702
// Returns a copy of `status` whose message has " [[<node>]]" appended.
// <node> is normally the formatted node reference. If
// `allow_multiple_formatted_node` is false and the message already contains a
// "{{node " marker, the plain node name is used instead so that only one
// formatted node appears in the message.
Status AttachDef(const Status& status, const NodeDef& node_def,
                 bool allow_multiple_formatted_node) {
  const bool use_plain_name =
      !allow_multiple_formatted_node &&
      status.error_message().find("{{node ") != string::npos;
  const string node_error =
      use_plain_name ? node_def.name() : FormatNodeDefForError(node_def);

  Status ret = status;
  errors::AppendToMessage(&ret, strings::StrCat(" [[", node_error, "]]"));
  return ret;
}
716
// Convenience overload operating on the NodeDef behind `node`.
Status AttachDef(const Status& status, const Node& node,
                 bool allow_multiple_formatted_node) {
  const NodeDef& def = node.def();
  return AttachDef(status, def, allow_multiple_formatted_node);
}
721
AddNodeAttr(StringPiece name,const AttrValue & value,NodeDef * node_def)722 void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def) {
723 node_def->mutable_attr()->insert(
724 AttrValueMap::value_type(string(name), value));
725 }
726
// ADD_NODE_ATTR(T) defines an AddNodeAttr() overload taking a value of type T:
// the value is converted to an AttrValue with SetAttrValue() and then handed
// to the AttrValue overload of AddNodeAttr(). One overload is instantiated
// below for each supported scalar, proto and list type; the macro is
// undefined again right after the list.
#define ADD_NODE_ATTR(T)                                           \
  void AddNodeAttr(StringPiece name, T value, NodeDef* node_def) { \
    AttrValue attr_value;                                          \
    SetAttrValue(value, &attr_value);                              \
    AddNodeAttr(name, attr_value, node_def);                       \
  }
ADD_NODE_ATTR(StringPiece)
ADD_NODE_ATTR(const char*)
ADD_NODE_ATTR(int32)
ADD_NODE_ATTR(int64)
ADD_NODE_ATTR(float)
ADD_NODE_ATTR(double)
ADD_NODE_ATTR(bool)
ADD_NODE_ATTR(DataType)
ADD_NODE_ATTR(const PartialTensorShape&)
ADD_NODE_ATTR(const Tensor&)
ADD_NODE_ATTR(const TensorProto&)
ADD_NODE_ATTR(const NameAttrList&)
ADD_NODE_ATTR(gtl::ArraySlice<StringPiece>)
ADD_NODE_ATTR(gtl::ArraySlice<const char*>)
ADD_NODE_ATTR(gtl::ArraySlice<string>)
ADD_NODE_ATTR(gtl::ArraySlice<int32>)
ADD_NODE_ATTR(gtl::ArraySlice<int64>)
ADD_NODE_ATTR(gtl::ArraySlice<float>)
ADD_NODE_ATTR(gtl::ArraySlice<bool>)
ADD_NODE_ATTR(const std::vector<bool>&)
ADD_NODE_ATTR(gtl::ArraySlice<DataType>)
ADD_NODE_ATTR(gtl::ArraySlice<TensorShape>)
ADD_NODE_ATTR(gtl::ArraySlice<PartialTensorShape>)
ADD_NODE_ATTR(gtl::ArraySlice<TensorShapeProto>)
ADD_NODE_ATTR(gtl::ArraySlice<Tensor>)
ADD_NODE_ATTR(gtl::ArraySlice<NameAttrList>)
#undef ADD_NODE_ATTR
760
761 void AddAttr(StringPiece name, const AttrValue& value, AttrValueMap* map) {
762 map->insert(AttrValueMap::value_type(string(name), value));
763 }
764
// ADD_ATTR(T) defines an AddAttr() overload taking a value of type T: the
// value is converted to an AttrValue with SetAttrValue() and then handed to
// the AttrValue overload of AddAttr(). Only the bool overload is instantiated
// here; the macro is undefined again right after.
#define ADD_ATTR(T)                                            \
  void AddAttr(StringPiece name, T value, AttrValueMap* map) { \
    AttrValue attr_value;                                      \
    SetAttrValue(value, &attr_value);                          \
    AddAttr(name, attr_value, map);                            \
  }
ADD_ATTR(bool)
#undef ADD_ATTR
773
774 Status AddPrefixAndSuffixToNode(StringPiece prefix, StringPiece suffix,
775 NodeDef* node_def) {
776 node_def->set_name(strings::StrCat(prefix, node_def->name(), suffix));
777 if (node_def->op() == "Enter" || node_def->op() == "RefEnter") {
778 string frame_name;
779 TF_RETURN_IF_ERROR(GetNodeAttr(*node_def, "frame_name", &frame_name));
780 AttrValue& attr = (*node_def->mutable_attr())["frame_name"];
781 frame_name = strings::StrCat(prefix, frame_name, suffix);
782 attr.set_s(frame_name);
783 }
784 return Status::OK();
785 }
786
787 } // namespace tensorflow
788