1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/node_def_util.h"
17
18 #include <algorithm>
19 #include <unordered_map>
20 #include <vector>
21
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/str_join.h"
24 #include "tensorflow/core/framework/attr_value.pb.h"
25 #include "tensorflow/core/framework/attr_value_util.h"
26 #include "tensorflow/core/framework/node_def.pb.h"
27 #include "tensorflow/core/framework/op_def.pb.h"
28 #include "tensorflow/core/framework/op_def_util.h"
29 #include "tensorflow/core/framework/tensor.h"
30 #include "tensorflow/core/framework/tensor.pb.h"
31 #include "tensorflow/core/framework/tensor_shape.h"
32 #include "tensorflow/core/framework/tensor_shape.pb.h"
33 #include "tensorflow/core/framework/types.h"
34 #include "tensorflow/core/framework/types.pb.h"
35 #include "tensorflow/core/lib/gtl/map_util.h"
36 #include "tensorflow/core/platform/errors.h"
37 #include "tensorflow/core/platform/scanner.h"
38 #include "tensorflow/core/platform/status.h"
39 #include "tensorflow/core/platform/strcat.h"
40 #include "tensorflow/core/platform/stringpiece.h"
41 #include "tensorflow/core/platform/types.h"
42
43 namespace tensorflow {
44
// Name of the node attr under which colocation information is stored.
const char* const kColocationAttrName = "_class";
// Prefix that marks a colocation-group entry (per its name); presumably each
// such entry is stored as "loc:@<group>" -- TODO confirm against the header.
const char* const kColocationGroupPrefix = "loc:@";
47
AttrSlice()48 AttrSlice::AttrSlice() : ndef_(nullptr) {
49 static const AttrValueMap* const kEmptyAttrValueMap = new AttrValueMap;
50 attrs_ = kEmptyAttrValueMap;
51 }
52
AttrSlice(const NodeDef & node_def)53 AttrSlice::AttrSlice(const NodeDef& node_def)
54 : ndef_(&node_def), attrs_(&ndef_->attr()) {}
55
AttrSlice(const AttrValueMap * a)56 AttrSlice::AttrSlice(const AttrValueMap* a) : ndef_(nullptr), attrs_(a) {}
57
SummarizeAttrsHelper(AttrSlice attrs,StringPiece device)58 string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device) {
59 string ret;
60
61 // We sort the attrs so the output is deterministic.
62 std::vector<string> attr_names;
63 attr_names.reserve(attrs.size());
64 for (const auto& attr : attrs) {
65 attr_names.push_back(attr.first);
66 }
67 std::sort(attr_names.begin(), attr_names.end());
68 bool first = true;
69 for (const string& attr_name : attr_names) {
70 if (!first) strings::StrAppend(&ret, ", ");
71 first = false;
72 strings::StrAppend(&ret, attr_name, "=",
73 SummarizeAttrValue(*attrs.Find(attr_name)));
74 }
75
76 // Consider the device to be a final attr with name "_device".
77 if (!device.empty()) {
78 if (!first) strings::StrAppend(&ret, ", ");
79 first = false;
80 strings::StrAppend(&ret, "_device=\"", device, "\"");
81 }
82 return ret;
83 }
84
SummarizeNode() const85 string AttrSlice::SummarizeNode() const {
86 return ndef_ ? SummarizeNodeDef(*ndef_)
87 : strings::StrCat(
88 "[", SummarizeAttrsHelper(*this, StringPiece()), "]");
89 }
90
DebugString() const91 string AttrSlice::DebugString() const {
92 std::vector<string> attr_key_vals;
93 attr_key_vals.reserve(attrs_->size());
94 for (const auto& it : *this) {
95 const string& name = it.first;
96 const AttrValue& attr_value = it.second;
97 attr_key_vals.push_back(
98 absl::StrCat(name, "=", SummarizeAttrValue(attr_value)));
99 }
100 return absl::StrJoin(attr_key_vals, ", ");
101 }
102
SummarizeNodeDef(const NodeDef & node_def)103 string SummarizeNodeDef(const NodeDef& node_def) {
104 string ret = strings::StrCat(errors::FormatNodeNameForError(node_def.name()),
105 " = ", node_def.op(), "[");
106 strings::StrAppend(&ret, SummarizeAttrsHelper(node_def, node_def.device()));
107 strings::StrAppend(&ret, "](");
108
109 // Output inputs, including control inputs, verbatim.
110 bool first = true;
111 for (const string& input : node_def.input()) {
112 if (!first) strings::StrAppend(&ret, ", ");
113 first = false;
114 strings::StrAppend(&ret, input);
115 }
116 strings::StrAppend(&ret, ")");
117 return ret;
118 }
119
SummarizeAttrs(const NodeDef & node_def)120 string SummarizeAttrs(const NodeDef& node_def) {
121 return SummarizeAttrsHelper(node_def, node_def.device());
122 }
123
FormatNodeDefForError(StringPiece node_name,bool has_experimental_debug_info,const NodeDef_ExperimentalDebugInfo & experimental_debug_info)124 string FormatNodeDefForError(
125 StringPiece node_name, bool has_experimental_debug_info,
126 const NodeDef_ExperimentalDebugInfo& experimental_debug_info) {
127 return !has_experimental_debug_info ||
128 experimental_debug_info.original_node_names().empty()
129 ? errors::FormatNodeNameForError(string(node_name))
130 : errors::FormatNodeNamesForError(
131 experimental_debug_info.original_node_names());
132 }
133
FormatNodeDefForError(const NodeDef & node_def)134 string FormatNodeDefForError(const NodeDef& node_def) {
135 return FormatNodeDefForError(node_def.name(),
136 node_def.has_experimental_debug_info(),
137 node_def.experimental_debug_info());
138 }
139
Find(StringPiece attr_name) const140 const AttrValue* AttrSlice::Find(StringPiece attr_name) const {
141 // Currently, the collection used for NodeDef::attr() (google::protobuf::Map)
142 // requires that the keys used for lookups have type 'const string&'. Because
143 // this method takes a StringPiece, it is necessary to allocate a temporary
144 // string, copy attr_name to it, and then use that temporary string for the
145 // lookup. This causes an excessive number of short-lived allocations, and for
146 // large graphs, this can be a significant cost.
147 //
148 // Because most nodes have a small number of attributes, a simple linear scan
149 // is generally more efficient than a hashed lookup. If google::protobuf::Map
150 // changes so that it supports efficient lookups using StringPiece instead of
151 // const string&, then this code could be changed to use attrs_->find() again.
152
153 for (const auto& attr : *attrs_) {
154 if (attr.first == attr_name) {
155 return &attr.second;
156 }
157 }
158 return nullptr;
159 }
160
Find(StringPiece attr_name,const AttrValue ** attr_value) const161 Status AttrSlice::Find(StringPiece attr_name,
162 const AttrValue** attr_value) const {
163 *attr_value = Find(attr_name);
164 if (*attr_value != nullptr) {
165 return Status::OK();
166 }
167 Status s = errors::NotFound("No attr named '", attr_name, "' in NodeDef:");
168 // Skip AttachDef for internal attrs since it is a little bit
169 // expensive and it is common for them to correctly not be included
170 // in a NodeDef.
171 if (!absl::StartsWith(attr_name, "_") && ndef_ != nullptr) {
172 s = AttachDef(s, *ndef_);
173 }
174 return s;
175 }
176
EqualAttrs(AttrSlice other,Scratch * scratch) const177 bool AttrSlice::EqualAttrs(AttrSlice other, Scratch* scratch) const {
178 if (size() != other.size()) return false;
179
180 for (const auto& attr : *other.attrs_) {
181 auto iter = attrs_->find(attr.first);
182 if (iter == attrs_->end()) return false;
183 // TODO(irving): Comparing AttrValues by proto is slightly buggy, since
184 // TensorProto is a nonunique representation of Tensor. This bug will go
185 // away once AttrSlice switches over to NodeInfo.
186 iter->second.SerializeToString(&scratch->a);
187 attr.second.SerializeToString(&scratch->b);
188 if (scratch->a != scratch->b) return false;
189 }
190 return true;
191 }
192
// Defines two GetNodeAttr() overloads for attr values of C++ type TYPE:
//   * a scalar overload that stores the converted value in *value, and
//   * a list overload that appends every converted element to *value
//     (elements already in the vector are left in place).
// Both propagate the error from AttrSlice::Find() when the attr is missing
// and from AttrValueHasType() when its type is not ATTR_TYPE (scalar) or
// "list(ATTR_TYPE)" (list).
//
// Macro parameters:
//   TYPE      - C++ type of the output value.
//   FIELD     - accessor read from the AttrValue (scalar overload) and from
//               its list() (list overload).
//   ATTR_TYPE - string literal naming the expected attr type.
//   APPEND_OP - std::vector member function used to append in the list
//               overload.
//   CAST      - expression producing the value to store; it may refer to `v`,
//               the raw field value, and to names declared by the injected
//               validation code.
//
// The ... is to allow the caller to inject some value validation code.  Use
// just ; if no additional validation code is needed.
#define DEFINE_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...)         \
  Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,           \
                     TYPE* value) {                                           \
    const AttrValue* attr_value;                                              \
    TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));                   \
    TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, ATTR_TYPE));             \
    const auto& v = attr_value->FIELD();                                      \
    __VA_ARGS__;                                                              \
    *value = CAST;                                                            \
    return Status::OK();                                                      \
  }                                                                           \
  Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,           \
                     std::vector<TYPE>* value) {                              \
    const AttrValue* attr_value;                                              \
    TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));                   \
    TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(" ATTR_TYPE ")")); \
    value->reserve(attr_value->list().FIELD().size());                        \
    for (const auto& v : attr_value->list().FIELD()) {                        \
      __VA_ARGS__;                                                            \
      value->APPEND_OP(CAST);                                                 \
    }                                                                         \
    return Status::OK();                                                      \
  }
218
// Bool-returning counterpart of DEFINE_GET_ATTR: defines scalar and list
// TryGetNodeAttr() overloads for TYPE. Each returns false (without building an
// error message for the caller) when the attr is missing or its type does not
// match ATTR_TYPE / "list(ATTR_TYPE)", and true after storing (scalar) or
// appending (list) the converted value(s). The macro parameters have the same
// meaning as for DEFINE_GET_ATTR; because the generated functions return bool,
// injected validation code must `return false;` rather than return a Status.
#define DEFINE_TRY_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...) \
  bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,      \
                      TYPE* value) {                                      \
    const AttrValue* attr_value = attrs.Find(attr_name);                  \
    if (attr_value == nullptr) {                                          \
      return false;                                                       \
    }                                                                     \
    Status s = AttrValueHasType(*attr_value, ATTR_TYPE);                  \
    if (!s.ok()) {                                                        \
      return false;                                                       \
    }                                                                     \
    const auto& v = attr_value->FIELD();                                  \
    __VA_ARGS__;                                                          \
    *value = CAST;                                                        \
    return true;                                                          \
  }                                                                       \
  bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,      \
                      std::vector<TYPE>* value) {                         \
    const AttrValue* attr_value = attrs.Find(attr_name);                  \
    if (attr_value == nullptr) {                                          \
      return false;                                                       \
    }                                                                     \
    Status s = AttrValueHasType(*attr_value, "list(" ATTR_TYPE ")");      \
    if (!s.ok()) {                                                        \
      return false;                                                       \
    }                                                                     \
    value->reserve(attr_value->list().FIELD().size());                    \
    for (const auto& v : attr_value->list().FIELD()) {                    \
      __VA_ARGS__;                                                        \
      value->APPEND_OP(CAST);                                             \
    }                                                                     \
    return true;                                                          \
  }
252 DEFINE_GET_ATTR(tstring, s, "string", emplace_back, v, ;)
253 DEFINE_TRY_GET_ATTR(tstring, s, "string", emplace_back, v, ;)
254 DEFINE_GET_ATTR(string, s, "string", emplace_back, v, ;)
255 DEFINE_TRY_GET_ATTR(string, s, "string", emplace_back, v, ;)
256 DEFINE_GET_ATTR(int64, i, "int", emplace_back, v, ;)
257 DEFINE_TRY_GET_ATTR(int64, i, "int", emplace_back, v, ;)
258 DEFINE_GET_ATTR(
259 int32, i, "int", emplace_back, static_cast<int32>(v),
260 if (static_cast<int64>(static_cast<int32>(v)) != v) {
261 return errors::InvalidArgument("Attr ", attr_name, " has value ", v,
262 " out of range for an int32");
263 })
264 DEFINE_TRY_GET_ATTR(
265 int32, i, "int", emplace_back, static_cast<int32>(v),
266 if (static_cast<int64>(static_cast<int32>(v)) != v) {
267 static int log_counter = 0;
268 if (log_counter < 10) {
269 log_counter++;
270 LOG(WARNING) << "Attr " << attr_name << " has value " << v
271 << " out of range for an int32";
272 }
273 return false;
274 })
275 DEFINE_GET_ATTR(float, f, "float", emplace_back, v, ;)
276 DEFINE_TRY_GET_ATTR(float, f, "float", emplace_back, v, ;)
277 // std::vector<bool> specialization does not have emplace_back until
278 // c++14, so we have to use push_back (see
279 // http://en.cppreference.com/w/cpp/container/vector/emplace_back)
280 DEFINE_GET_ATTR(bool, b, "bool", push_back, v, ;)
281 DEFINE_TRY_GET_ATTR(bool, b, "bool", push_back, v, ;)
282 DEFINE_GET_ATTR(DataType, type, "type", emplace_back, static_cast<DataType>(v),
283 ;)
284 DEFINE_TRY_GET_ATTR(DataType, type, "type", emplace_back,
285 static_cast<DataType>(v),
286 ;)
287 DEFINE_GET_ATTR(TensorShapeProto, shape, "shape", emplace_back, v, ;)
288 DEFINE_GET_ATTR(TensorShape, shape, "shape", emplace_back, TensorShape(v),
289 TF_RETURN_IF_ERROR(TensorShape::IsValidShape(v));)
290 DEFINE_TRY_GET_ATTR(
291 TensorShape, shape, "shape", emplace_back, TensorShape(v),
292 if (!TensorShape::IsValidShape(v).ok()) {
293 static int log_counter = 0;
294 if (log_counter < 10) {
295 log_counter++;
296 LOG(WARNING) << "Attr " << attr_name << " has invalid shape value "
297 << v.DebugString();
298 }
299 return false;
300 })
301 DEFINE_GET_ATTR(PartialTensorShape, shape, "shape", emplace_back,
302 PartialTensorShape(v),
303 TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(v));)
304 DEFINE_GET_ATTR(
305 Tensor, tensor, "tensor", emplace_back, t, Tensor t; if (!t.FromProto(v)) {
306 return errors::InvalidArgument("Attr ", attr_name, " has value ",
307 v.ShortDebugString(),
308 " that can't be converted to a Tensor");
309 })
310 DEFINE_GET_ATTR(NameAttrList, func, "func", emplace_back, v, ;);
311 #undef DEFINE_GET_ATTR
312
HasNodeAttr(const NodeDef & node_def,StringPiece attr_name)313 bool HasNodeAttr(const NodeDef& node_def, StringPiece attr_name) {
314 return node_def.attr().find(string(attr_name)) != node_def.attr().end();
315 }
316
317 static const string& kEmptyString = *new string();
318
GetNodeAttrString(const AttrSlice & attrs,StringPiece attr_name)319 const string& GetNodeAttrString(const AttrSlice& attrs, StringPiece attr_name) {
320 const AttrValue* attr_value = attrs.Find(attr_name);
321 if (attr_value == nullptr) {
322 return kEmptyString;
323 }
324 Status s = AttrValueHasType(*attr_value, "string");
325 if (!s.ok()) {
326 return kEmptyString;
327 }
328 return attr_value->s();
329 }
330
TryGetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,std::vector<const string * > * value)331 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
332 std::vector<const string*>* value) {
333 const AttrValue* attr_value = attrs.Find(attr_name);
334 if (attr_value == nullptr) {
335 return false;
336 }
337 Status s = AttrValueHasType(*attr_value, "list(string)");
338 if (!s.ok()) {
339 return false;
340 }
341 value->reserve(attr_value->list().s().size());
342 for (const auto& v : attr_value->list().s()) {
343 value->push_back(&v);
344 }
345 return true;
346 }
347
TryGetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,std::vector<const TensorShapeProto * > * value)348 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
349 std::vector<const TensorShapeProto*>* value) {
350 const AttrValue* attr_value = attrs.Find(attr_name);
351 if (attr_value == nullptr) {
352 return false;
353 }
354 Status s = AttrValueHasType(*attr_value, "list(shape)");
355 if (!s.ok()) {
356 return false;
357 }
358 value->reserve(attr_value->list().shape().size());
359 for (const auto& v : attr_value->list().shape()) {
360 value->push_back(&v);
361 }
362 return true;
363 }
364
GetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,DataTypeVector * value)365 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
366 DataTypeVector* value) {
367 const AttrValue* attr_value;
368 TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
369 TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(type)"));
370 for (const auto& v : attr_value->list().type()) {
371 value->push_back(static_cast<DataType>(v));
372 }
373 return Status::OK();
374 }
375
GetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,const TensorProto ** value)376 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
377 const TensorProto** value) {
378 const AttrValue* attr_value;
379 TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
380 TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "tensor"));
381 *value = &attr_value->tensor();
382 return Status::OK();
383 }
384
TryGetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,const TensorProto ** value)385 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
386 const TensorProto** value) {
387 const AttrValue* attr_value = attrs.Find(attr_name);
388 if (attr_value == nullptr) {
389 return false;
390 }
391 Status s = AttrValueHasType(*attr_value, "tensor");
392 if (!s.ok()) {
393 return false;
394 }
395 *value = &attr_value->tensor();
396 return true;
397 }
398
GetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,const NameAttrList ** value)399 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
400 const NameAttrList** value) {
401 const AttrValue* attr_value;
402 TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
403 TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "func"));
404 *value = &attr_value->func();
405 return Status::OK();
406 }
407
TryGetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,const NameAttrList ** value)408 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
409 const NameAttrList** value) {
410 const AttrValue* attr_value = attrs.Find(attr_name);
411 if (attr_value == nullptr) {
412 return false;
413 }
414 Status s = AttrValueHasType(*attr_value, "func");
415 if (!s.ok()) {
416 return false;
417 }
418 *value = &attr_value->func();
419 return true;
420 }
421
422 namespace { // Helper for InOutTypesForNode().
423
424 template <class NodeDefOrAttrSlice>
AddArgToSig(const NodeDefOrAttrSlice & node_or_attrs,const OpDef::ArgDef & arg_def,DataTypeVector * sig)425 Status AddArgToSig(const NodeDefOrAttrSlice& node_or_attrs,
426 const OpDef::ArgDef& arg_def, DataTypeVector* sig) {
427 const int original_size = sig->size();
428 if (!arg_def.number_attr().empty()) {
429 // Same type repeated "repeats" times.
430 int32 repeats = -1;
431 TF_RETURN_IF_ERROR(
432 GetNodeAttr(node_or_attrs, arg_def.number_attr(), &repeats));
433 if (repeats < 0) {
434 return errors::InvalidArgument("Value for number_attr() ", repeats,
435 " < 0");
436 }
437
438 if (!arg_def.type_attr().empty()) {
439 DataType dtype;
440 TF_RETURN_IF_ERROR(
441 GetNodeAttr(node_or_attrs, arg_def.type_attr(), &dtype));
442 for (int i = 0; i < repeats; ++i) {
443 sig->push_back(dtype);
444 }
445 } else if (arg_def.type() != DT_INVALID) {
446 for (int i = 0; i < repeats; ++i) {
447 sig->push_back(arg_def.type());
448 }
449 } else {
450 return errors::InvalidArgument("Missing type or type_attr field in ",
451 arg_def.ShortDebugString());
452 }
453 } else if (!arg_def.type_attr().empty()) {
454 const AttrValue* attr_value;
455 TF_RETURN_IF_ERROR(
456 AttrSlice(node_or_attrs).Find(arg_def.type_attr(), &attr_value));
457 sig->push_back(attr_value->type());
458 } else if (!arg_def.type_list_attr().empty()) {
459 const AttrValue* attr_value;
460 TF_RETURN_IF_ERROR(
461 AttrSlice(node_or_attrs).Find(arg_def.type_list_attr(), &attr_value));
462 for (int dtype : attr_value->list().type()) {
463 sig->push_back(static_cast<DataType>(dtype));
464 }
465 } else if (arg_def.type() != DT_INVALID) {
466 sig->push_back(arg_def.type());
467 } else {
468 return errors::InvalidArgument("No type fields in ",
469 arg_def.ShortDebugString());
470 }
471 if (arg_def.is_ref()) {
472 // For all types that were added by this function call, make them refs.
473 for (size_t i = original_size; i < sig->size(); ++i) {
474 (*sig)[i] = MakeRefType((*sig)[i]);
475 }
476 }
477 return Status::OK();
478 }
479
480 } // namespace
481
InputTypeForNode(const NodeDef & node_def,const OpDef & op_def,int input_port,DataType * input_type)482 Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
483 int input_port, DataType* input_type) {
484 DataTypeVector input_types;
485 for (const auto& arg : op_def.input_arg()) {
486 TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, &input_types));
487 if (input_types.size() > input_port) {
488 const DataType dtype = input_types[input_port];
489 *input_type = dtype;
490 return Status::OK();
491 }
492 }
493 return errors::InvalidArgument("Input ", input_port, " not found for node ",
494 node_def.name());
495 }
496
InputTypesForNode(const NodeDef & node_def,const OpDef & op_def,DataTypeVector * inputs)497 Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
498 DataTypeVector* inputs) {
499 for (const auto& arg : op_def.input_arg()) {
500 TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs));
501 }
502 return Status::OK();
503 }
504
OutputTypeForNode(const NodeDef & node_def,const OpDef & op_def,int output_port,DataType * output_type)505 Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
506 int output_port, DataType* output_type) {
507 DataTypeVector output_types;
508 for (const auto& arg : op_def.output_arg()) {
509 TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, &output_types));
510 if (output_types.size() > output_port) {
511 const DataType dtype = output_types[output_port];
512 *output_type = dtype;
513 return Status::OK();
514 }
515 }
516 return errors::InvalidArgument("Output ", output_port, " not found for node ",
517 node_def.name());
518 }
519
OutputTypesForNode(const NodeDef & node_def,const OpDef & op_def,DataTypeVector * outputs)520 Status OutputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
521 DataTypeVector* outputs) {
522 for (const auto& arg : op_def.output_arg()) {
523 TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, outputs));
524 }
525 return Status::OK();
526 }
527
OutputTypesForNode(const AttrSlice & attrs,const OpDef & op_def,DataTypeVector * outputs)528 Status OutputTypesForNode(const AttrSlice& attrs, const OpDef& op_def,
529 DataTypeVector* outputs) {
530 for (const auto& arg : op_def.output_arg()) {
531 TF_RETURN_IF_ERROR(AddArgToSig(attrs, arg, outputs));
532 }
533 return Status::OK();
534 }
535
InOutTypesForNode(const NodeDef & node_def,const OpDef & op_def,DataTypeVector * inputs,DataTypeVector * outputs)536 Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
537 DataTypeVector* inputs, DataTypeVector* outputs) {
538 TF_RETURN_IF_ERROR(InputTypesForNode(node_def, op_def, inputs));
539 return OutputTypesForNode(node_def, op_def, outputs);
540 }
541
NumOutputsForNode(const NodeDef & node_def,const OpDef & op_def,int * num_outputs)542 Status NumOutputsForNode(const NodeDef& node_def, const OpDef& op_def,
543 int* num_outputs) {
544 DataTypeVector outputs;
545 TF_RETURN_IF_ERROR(OutputTypesForNode(node_def, op_def, &outputs));
546 *num_outputs = outputs.size();
547 return Status::OK();
548 }
549
ValidateNodeDef(const NodeDef & node_def,const OpDef & op_def)550 Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) {
551 if (node_def.op() != op_def.name()) {
552 return errors::InvalidArgument(
553 "NodeDef op '", node_def.op(), "' does not match ",
554 SummarizeOpDef(op_def), "; NodeDef: ", FormatNodeDefForError(node_def));
555 }
556
557 bool seen_control = false;
558 size_t num_inputs = 0;
559 // TODO(josh11b): Unify the input field validation.
560 for (const string& input : node_def.input()) {
561 if (absl::StartsWith(input, "^")) {
562 seen_control = true;
563 if (input.find(':') != string::npos) {
564 return errors::InvalidArgument("Control input '", input,
565 "' must not have ':' in NodeDef: ",
566 FormatNodeDefForError(node_def));
567 }
568 } else if (seen_control) {
569 return errors::InvalidArgument("Non-control input '", input,
570 "' after control input in NodeDef: ",
571 FormatNodeDefForError(node_def));
572 } else {
573 ++num_inputs;
574 }
575 }
576
577 std::unordered_map<string, const OpDef::AttrDef*> op_attrs;
578 for (const auto& attr : op_def.attr()) {
579 if (!gtl::InsertIfNotPresent(&op_attrs, attr.name(), &attr)) {
580 return errors::InvalidArgument("OpDef has duplicate attr name '",
581 attr.name(),
582 "': ", SummarizeOpDef(op_def));
583 }
584 }
585 for (const auto& attr : node_def.attr()) {
586 // Allow internal optional attributes with names starting with "_".
587 if (absl::StartsWith(attr.first, "_")) {
588 continue;
589 }
590 auto iter = op_attrs.find(attr.first);
591 if (iter == op_attrs.end()) {
592 // A common cause of this error is that TensorFlow has made a
593 // backwards-compatible change to the NodeDef (e.g., adding a
594 // new attr with a default value), but the binary consuming the
595 // NodeDef does not know about the new attribute; the solution
596 // in these cases is to ensure that the binary consuming the
597 // NodeDef is built with a version of TensorFlow no earlier than
598 // the binary producing it.
599 return errors::InvalidArgument(
600 "NodeDef mentions attr '", attr.first, "' not in ",
601 SummarizeOpDef(op_def),
602 "; NodeDef: ", FormatNodeDefForError(node_def),
603 ". (Check whether your GraphDef-interpreting binary is up to date "
604 "with your GraphDef-generating binary.).");
605 }
606 // If attr value is placeholder, do not check it.
607 if (attr.second.placeholder().empty()) {
608 TF_RETURN_WITH_CONTEXT_IF_ERROR(
609 ValidateAttrValue(attr.second, *iter->second),
610 "; NodeDef: ", FormatNodeDefForError(node_def), "; ",
611 SummarizeOpDef(op_def));
612 }
613 // Keep track of which attr names have (not) been found in the NodeDef.
614 op_attrs.erase(iter);
615 }
616
617 // Were all attrs in the OpDef found in the NodeDef?
618 if (!op_attrs.empty()) {
619 string attrs;
620 for (const auto& attr_pair : op_attrs) {
621 if (!attrs.empty()) strings::StrAppend(&attrs, "', '");
622 strings::StrAppend(&attrs, attr_pair.first);
623 }
624 return errors::InvalidArgument(
625 "NodeDef missing attr", op_attrs.size() == 1 ? " '" : "s '", attrs,
626 "' from ", SummarizeOpDef(op_def),
627 "; NodeDef: ", FormatNodeDefForError(node_def));
628 }
629
630 // Validate the number of inputs.
631 DataTypeVector inputs, outputs;
632 TF_RETURN_IF_ERROR(InOutTypesForNode(node_def, op_def, &inputs, &outputs));
633
634 if (num_inputs != inputs.size()) {
635 return errors::InvalidArgument(
636 "NodeDef expected inputs '", DataTypeVectorString(inputs),
637 "' do not match ", num_inputs, " inputs specified; ",
638 SummarizeOpDef(op_def), "; NodeDef: ", FormatNodeDefForError(node_def));
639 }
640
641 return Status::OK();
642 }
643
644 namespace { // Helpers for NameRangesForNode()
645
ComputeArgRange(const AttrSlice & attrs,const OpDef::ArgDef & arg_def,const OpDef & op_def,int * num)646 Status ComputeArgRange(const AttrSlice& attrs, const OpDef::ArgDef& arg_def,
647 const OpDef& op_def, int* num) {
648 if (!arg_def.number_attr().empty()) {
649 // Same type repeated "num" times.
650 return GetNodeAttr(attrs, arg_def.number_attr(), num);
651 } else if (!arg_def.type_list_attr().empty()) {
652 const AttrValue* attr_value;
653 TF_RETURN_IF_ERROR(attrs.Find(arg_def.type_list_attr(), &attr_value));
654 *num = attr_value->list().type_size();
655 } else if (!arg_def.type_attr().empty() || arg_def.type() != DT_INVALID) {
656 *num = 1;
657 } else {
658 return errors::InvalidArgument(
659 "Argument '", arg_def.name(),
660 "' incorrectly specified in op definition: ", SummarizeOpDef(op_def));
661 }
662 return Status::OK();
663 }
664
NameRangesHelper(const AttrSlice & attrs,const protobuf::RepeatedPtrField<OpDef::ArgDef> & args,const OpDef & op_def,NameRangeMap * result)665 Status NameRangesHelper(const AttrSlice& attrs,
666 const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
667 const OpDef& op_def, NameRangeMap* result) {
668 int start = 0;
669 int num;
670 for (const auto& arg : args) {
671 TF_RETURN_IF_ERROR(ComputeArgRange(attrs, arg, op_def, &num));
672 (*result)[arg.name()] = std::make_pair(start, start + num);
673 start += num;
674 }
675 return Status::OK();
676 }
677
678 } // namespace
679
NameRangesForNode(const AttrSlice & attrs,const OpDef & op_def,NameRangeMap * inputs,NameRangeMap * outputs)680 Status NameRangesForNode(const AttrSlice& attrs, const OpDef& op_def,
681 NameRangeMap* inputs, NameRangeMap* outputs) {
682 if (inputs != nullptr) {
683 TF_RETURN_IF_ERROR(
684 NameRangesHelper(attrs, op_def.input_arg(), op_def, inputs));
685 }
686 if (outputs != nullptr) {
687 return NameRangesHelper(attrs, op_def.output_arg(), op_def, outputs);
688 }
689 return Status::OK();
690 }
691
AddDefaultsToNodeDef(const OpDef & op_def,NodeDef * node_def)692 void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def) {
693 for (const auto& attr_def : op_def.attr()) {
694 AttrSlice attrs(*node_def);
695 if (attr_def.has_default_value() && !attrs.Find(attr_def.name())) {
696 AddNodeAttr(attr_def.name(), attr_def.default_value(), node_def);
697 }
698 }
699 }
700
701 namespace {
702
703 using ::tensorflow::tstring;
704 using ::tensorflow::strings::Scanner;
705
IsValidNodeName(StringPiece sp)706 bool IsValidNodeName(StringPiece sp) {
707 Scanner scanner(sp);
708 scanner.One(Scanner::LETTER_DIGIT_DOT)
709 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
710
711 while (true) {
712 if (!scanner.GetResult()) // Some error in previous iteration.
713 return false;
714 if (scanner.empty()) // No error, but nothing left, good.
715 return true;
716
717 // Absorb another name/namespace, starting with a '>'
718 scanner.One(Scanner::RANGLE)
719 .One(Scanner::LETTER_DIGIT_DOT)
720 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
721 }
722 }
723
IsValidDataInputName(StringPiece sp)724 bool IsValidDataInputName(StringPiece sp) {
725 // Data inputs are op_name, op_name:0, or op_name:12345.
726 Scanner scan(sp);
727 scan.One(Scanner::LETTER_DIGIT_DOT)
728 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
729
730 while (true) {
731 if (!scan.GetResult()) // Some error in previous iteration.
732 return false;
733 if (scan.empty()) // No error, but nothing left, good.
734 return true;
735
736 if (scan.Peek() == ':') { // Absorb identifier after the colon
737 scan.OneLiteral(":");
738 if (scan.Peek() == '0') {
739 scan.OneLiteral("0"); // :0
740 } else {
741 scan.Many(Scanner::DIGIT); // :[1-9][0-9]*
742 }
743 } else {
744 // Absorb another name/namespace, starting with a '>'
745 scan.One(Scanner::RANGLE)
746 .One(Scanner::LETTER_DIGIT_DOT)
747 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
748 }
749 }
750 }
751
IsValidControlInputName(StringPiece sp)752 bool IsValidControlInputName(StringPiece sp) {
753 Scanner scan(sp);
754 scan.OneLiteral("^")
755 .One(Scanner::LETTER_DIGIT_DOT)
756 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
757
758 while (true) {
759 if (!scan.GetResult()) // Some error in previous iteration.
760 return false;
761 if (scan.empty()) // No error, but nothing left, good.
762 return true;
763
764 // Absorb another name/namespace, starting with a '>'
765 scan.One(Scanner::RANGLE)
766 .One(Scanner::LETTER_DIGIT_DOT)
767 .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
768 }
769 }
770
771 } // namespace
772
ValidateOpInput(const string & input_name,bool * is_control_input)773 Status ValidateOpInput(const string& input_name, bool* is_control_input) {
774 *is_control_input = false;
775 if (IsValidDataInputName(input_name)) {
776 return Status::OK();
777 } else if (IsValidControlInputName(input_name)) {
778 *is_control_input = true;
779 return Status::OK();
780 } else {
781 return errors::InvalidArgument("Illegal op input name '", input_name, "'");
782 }
783 }
784
ValidateNodeName(const string & node_name)785 Status ValidateNodeName(const string& node_name) {
786 if (IsValidNodeName(node_name)) {
787 return Status::OK();
788 } else {
789 return errors::InvalidArgument("Illegal op name '", node_name, "'");
790 }
791 }
792
ValidateExternalNodeDefSyntax(const NodeDef & node_def)793 Status ValidateExternalNodeDefSyntax(const NodeDef& node_def) {
794 Status s = ValidateNodeName(node_def.name());
795 if (!s.ok()) {
796 return AttachDef(s, node_def);
797 }
798 bool in_control_inputs = false;
799 for (const string& input_name : node_def.input()) {
800 bool is_control_input;
801 s = ValidateOpInput(input_name, &is_control_input);
802 if (!s.ok()) {
803 return AttachDef(s, node_def);
804 }
805
806 if (in_control_inputs && !is_control_input) {
807 return AttachDef(errors::InvalidArgument(
808 "All control inputs must follow all data inputs"),
809 node_def);
810 }
811 in_control_inputs = is_control_input;
812 }
813 return Status::OK();
814 }
815
// Returns a copy of `status` whose message has " [[<node>]]" appended through
// errors::AppendToMessage(). <node> is the formatted node reference, except
// when `allow_multiple_formatted_node` is false and the message already
// contains "{{node ": then just the plain node name is used.
Status AttachDef(const Status& status, const NodeDef& node_def,
                 bool allow_multiple_formatted_node) {
  const bool already_formatted =
      status.error_message().find("{{node ") != string::npos;
  const string node_error =
      (!allow_multiple_formatted_node && already_formatted)
          ? node_def.name()
          : FormatNodeDefForError(node_def);

  Status annotated = status;
  errors::AppendToMessage(&annotated,
                          strings::StrCat(" [[", node_error, "]]"));
  return annotated;
}
829
AddNodeAttr(StringPiece name,const AttrValue & value,NodeDef * node_def)830 void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def) {
831 node_def->mutable_attr()->insert(
832 AttrValueMap::value_type(string(name), value));
833 }
834
AddNodeAttr(StringPiece name,AttrValue && value,NodeDef * node_def)835 void AddNodeAttr(StringPiece name, AttrValue&& value, NodeDef* node_def) {
836 (*node_def->mutable_attr())[string(name)] = std::move(value);
837 }
838
// Defines an AddNodeAttr() overload taking a value of C++ type T: the value is
// converted into an AttrValue with SetAttrValue() and then added through the
// (StringPiece, const AttrValue&, NodeDef*) overload. One overload is
// instantiated below for each supported scalar and list type.
#define ADD_NODE_ATTR(T)                                           \
  void AddNodeAttr(StringPiece name, T value, NodeDef* node_def) { \
    AttrValue attr_value;                                          \
    SetAttrValue(value, &attr_value);                              \
    AddNodeAttr(name, attr_value, node_def);                       \
  }
ADD_NODE_ATTR(StringPiece)
ADD_NODE_ATTR(const char*)
ADD_NODE_ATTR(int32)
ADD_NODE_ATTR(int64)
ADD_NODE_ATTR(float)
ADD_NODE_ATTR(double)
ADD_NODE_ATTR(bool)
ADD_NODE_ATTR(DataType)
ADD_NODE_ATTR(const PartialTensorShape&)
ADD_NODE_ATTR(const Tensor&)
ADD_NODE_ATTR(const TensorProto&)
ADD_NODE_ATTR(const NameAttrList&)
ADD_NODE_ATTR(gtl::ArraySlice<StringPiece>)
ADD_NODE_ATTR(gtl::ArraySlice<const char*>)
ADD_NODE_ATTR(gtl::ArraySlice<string>)
ADD_NODE_ATTR(gtl::ArraySlice<int32>)
ADD_NODE_ATTR(gtl::ArraySlice<int64>)
ADD_NODE_ATTR(gtl::ArraySlice<float>)
ADD_NODE_ATTR(gtl::ArraySlice<bool>)
ADD_NODE_ATTR(const std::vector<bool>&)
ADD_NODE_ATTR(gtl::ArraySlice<DataType>)
ADD_NODE_ATTR(gtl::ArraySlice<TensorShape>)
ADD_NODE_ATTR(gtl::ArraySlice<PartialTensorShape>)
ADD_NODE_ATTR(gtl::ArraySlice<TensorShapeProto>)
ADD_NODE_ATTR(gtl::ArraySlice<Tensor>)
ADD_NODE_ATTR(gtl::ArraySlice<NameAttrList>)
#undef ADD_NODE_ATTR
872
873 void AddAttr(StringPiece name, const AttrValue& value, AttrValueMap* map) {
874 map->insert(AttrValueMap::value_type(string(name), value));
875 }
876
// Defines an AddAttr() overload taking a value of C++ type T: the value is
// converted into an AttrValue with SetAttrValue() and then added through the
// (StringPiece, const AttrValue&, AttrValueMap*) overload. Only the bool
// overload is instantiated here.
#define ADD_ATTR(T)                                            \
  void AddAttr(StringPiece name, T value, AttrValueMap* map) { \
    AttrValue attr_value;                                      \
    SetAttrValue(value, &attr_value);                          \
    AddAttr(name, attr_value, map);                            \
  }
ADD_ATTR(bool)
#undef ADD_ATTR
885
886 Status AddPrefixAndSuffixToNode(StringPiece prefix, StringPiece suffix,
887 NodeDef* node_def, bool uniquify_frame_name) {
888 node_def->set_name(strings::StrCat(prefix, node_def->name(), suffix));
889
890 // Update frame name to avoid multiple LoopCond nodes in one frame.
891 if (uniquify_frame_name &&
892 (node_def->op() == "Enter" || node_def->op() == "RefEnter")) {
893 string frame_name;
894 TF_RETURN_IF_ERROR(GetNodeAttr(*node_def, "frame_name", &frame_name));
895 AttrValue& attr = (*node_def->mutable_attr())["frame_name"];
896 frame_name = strings::StrCat(prefix, frame_name, suffix);
897 attr.set_s(frame_name);
898 }
899
900 // Update colocation constraints.
901 constexpr char kClassAttr[] = "_class";
902 auto class_attr = node_def->mutable_attr()->find(kClassAttr);
903 if (class_attr != node_def->mutable_attr()->end()) {
904 AttrValue new_value;
905 new_value.mutable_list()->add_s(
906 strings::StrCat(prefix, class_attr->second.s()));
907 node_def->mutable_attr()->erase(kClassAttr);
908 node_def->mutable_attr()->insert({kClassAttr, new_value});
909 }
910
911 return Status::OK();
912 }
913
914 } // namespace tensorflow
915