• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/core/framework/node_def_util.h"

#include <algorithm>
#include <atomic>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/scanner.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/types.h"
42 
43 namespace tensorflow {
44 
// Name of the NodeDef attr ("_class") used to record colocation information.
// NOTE(review): meaning inferred from the name; its consumers are not in this
// file — confirm.
const char* const kColocationAttrName = "_class";
// Prefix ("loc:@") marking a colocation-group entry. NOTE(review): presumably
// entries of kColocationAttrName carry this prefix — confirm against callers.
const char* const kColocationGroupPrefix = "loc:@";
47 
AttrSlice()48 AttrSlice::AttrSlice() : ndef_(nullptr) {
49   static const AttrValueMap* const kEmptyAttrValueMap = new AttrValueMap;
50   attrs_ = kEmptyAttrValueMap;
51 }
52 
AttrSlice(const NodeDef & node_def)53 AttrSlice::AttrSlice(const NodeDef& node_def)
54     : ndef_(&node_def), attrs_(&ndef_->attr()) {}
55 
AttrSlice(const AttrValueMap * a)56 AttrSlice::AttrSlice(const AttrValueMap* a) : ndef_(nullptr), attrs_(a) {}
57 
SummarizeAttrsHelper(AttrSlice attrs,StringPiece device)58 string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device) {
59   string ret;
60 
61   // We sort the attrs so the output is deterministic.
62   std::vector<string> attr_names;
63   attr_names.reserve(attrs.size());
64   for (const auto& attr : attrs) {
65     attr_names.push_back(attr.first);
66   }
67   std::sort(attr_names.begin(), attr_names.end());
68   bool first = true;
69   for (const string& attr_name : attr_names) {
70     if (!first) strings::StrAppend(&ret, ", ");
71     first = false;
72     strings::StrAppend(&ret, attr_name, "=",
73                        SummarizeAttrValue(*attrs.Find(attr_name)));
74   }
75 
76   // Consider the device to be a final attr with name "_device".
77   if (!device.empty()) {
78     if (!first) strings::StrAppend(&ret, ", ");
79     first = false;
80     strings::StrAppend(&ret, "_device=\"", device, "\"");
81   }
82   return ret;
83 }
84 
SummarizeNode() const85 string AttrSlice::SummarizeNode() const {
86   return ndef_ ? SummarizeNodeDef(*ndef_)
87                : strings::StrCat(
88                      "[", SummarizeAttrsHelper(*this, StringPiece()), "]");
89 }
90 
DebugString() const91 string AttrSlice::DebugString() const {
92   std::vector<string> attr_key_vals;
93   attr_key_vals.reserve(attrs_->size());
94   for (const auto& it : *this) {
95     const string& name = it.first;
96     const AttrValue& attr_value = it.second;
97     attr_key_vals.push_back(
98         absl::StrCat(name, "=", SummarizeAttrValue(attr_value)));
99   }
100   return absl::StrJoin(attr_key_vals, ", ");
101 }
102 
SummarizeNodeDef(const NodeDef & node_def)103 string SummarizeNodeDef(const NodeDef& node_def) {
104   string ret = strings::StrCat(errors::FormatNodeNameForError(node_def.name()),
105                                " = ", node_def.op(), "[");
106   strings::StrAppend(&ret, SummarizeAttrsHelper(node_def, node_def.device()));
107   strings::StrAppend(&ret, "](");
108 
109   // Output inputs, including control inputs, verbatim.
110   bool first = true;
111   for (const string& input : node_def.input()) {
112     if (!first) strings::StrAppend(&ret, ", ");
113     first = false;
114     strings::StrAppend(&ret, input);
115   }
116   strings::StrAppend(&ret, ")");
117   return ret;
118 }
119 
SummarizeAttrs(const NodeDef & node_def)120 string SummarizeAttrs(const NodeDef& node_def) {
121   return SummarizeAttrsHelper(node_def, node_def.device());
122 }
123 
FormatNodeDefForError(StringPiece node_name,bool has_experimental_debug_info,const NodeDef_ExperimentalDebugInfo & experimental_debug_info)124 string FormatNodeDefForError(
125     StringPiece node_name, bool has_experimental_debug_info,
126     const NodeDef_ExperimentalDebugInfo& experimental_debug_info) {
127   return !has_experimental_debug_info ||
128                  experimental_debug_info.original_node_names().empty()
129              ? errors::FormatNodeNameForError(string(node_name))
130              : errors::FormatNodeNamesForError(
131                    experimental_debug_info.original_node_names());
132 }
133 
FormatNodeDefForError(const NodeDef & node_def)134 string FormatNodeDefForError(const NodeDef& node_def) {
135   return FormatNodeDefForError(node_def.name(),
136                                node_def.has_experimental_debug_info(),
137                                node_def.experimental_debug_info());
138 }
139 
Find(StringPiece attr_name) const140 const AttrValue* AttrSlice::Find(StringPiece attr_name) const {
141   // Currently, the collection used for NodeDef::attr() (google::protobuf::Map)
142   // requires that the keys used for lookups have type 'const string&'. Because
143   // this method takes a StringPiece, it is necessary to allocate a temporary
144   // string, copy attr_name to it, and then use that temporary string for the
145   // lookup. This causes an excessive number of short-lived allocations, and for
146   // large graphs, this can be a significant cost.
147   //
148   // Because most nodes have a small number of attributes, a simple linear scan
149   // is generally more efficient than a hashed lookup.  If google::protobuf::Map
150   // changes so that it supports efficient lookups using StringPiece instead of
151   // const string&, then this code could be changed to use attrs_->find() again.
152 
153   for (const auto& attr : *attrs_) {
154     if (attr.first == attr_name) {
155       return &attr.second;
156     }
157   }
158   return nullptr;
159 }
160 
Find(StringPiece attr_name,const AttrValue ** attr_value) const161 Status AttrSlice::Find(StringPiece attr_name,
162                        const AttrValue** attr_value) const {
163   *attr_value = Find(attr_name);
164   if (*attr_value != nullptr) {
165     return Status::OK();
166   }
167   Status s = errors::NotFound("No attr named '", attr_name, "' in NodeDef:");
168   // Skip AttachDef for internal attrs since it is a little bit
169   // expensive and it is common for them to correctly not be included
170   // in a NodeDef.
171   if (!absl::StartsWith(attr_name, "_") && ndef_ != nullptr) {
172     s = AttachDef(s, *ndef_);
173   }
174   return s;
175 }
176 
EqualAttrs(AttrSlice other,Scratch * scratch) const177 bool AttrSlice::EqualAttrs(AttrSlice other, Scratch* scratch) const {
178   if (size() != other.size()) return false;
179 
180   for (const auto& attr : *other.attrs_) {
181     auto iter = attrs_->find(attr.first);
182     if (iter == attrs_->end()) return false;
183     // TODO(irving): Comparing AttrValues by proto is slightly buggy, since
184     // TensorProto is a nonunique representation of Tensor.  This bug will go
185     // away once AttrSlice switches over to NodeInfo.
186     iter->second.SerializeToString(&scratch->a);
187     attr.second.SerializeToString(&scratch->b);
188     if (scratch->a != scratch->b) return false;
189   }
190   return true;
191 }
192 
// DEFINE_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...) defines two
// GetNodeAttr overloads:
//   * GetNodeAttr(attrs, attr_name, TYPE* value): finds the attr, requires its
//     type to be ATTR_TYPE, binds `v` to attr_value->FIELD(), runs the code in
//     `...`, then stores CAST into *value.
//   * GetNodeAttr(attrs, attr_name, std::vector<TYPE>* value): requires type
//     "list(ATTR_TYPE)" and, for each element `v` of list().FIELD(), runs the
//     code in `...` and then calls value->APPEND_OP(CAST). Existing contents
//     of *value are kept (nothing clears it).
// Lookup or type errors from AttrSlice::Find / AttrValueHasType are returned.
// The `...` lets the caller inject value validation code, which may refer to
// `v` and `attr_name` and may return an error Status. Use just `;` if no
// additional validation code is needed.
#define DEFINE_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...)         \
  Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,           \
                     TYPE* value) {                                           \
    const AttrValue* attr_value;                                              \
    TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));                   \
    TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, ATTR_TYPE));             \
    const auto& v = attr_value->FIELD();                                      \
    __VA_ARGS__;                                                              \
    *value = CAST;                                                            \
    return Status::OK();                                                      \
  }                                                                           \
  Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,           \
                     std::vector<TYPE>* value) {                              \
    const AttrValue* attr_value;                                              \
    TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));                   \
    TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(" ATTR_TYPE ")")); \
    value->reserve(attr_value->list().FIELD().size());                        \
    for (const auto& v : attr_value->list().FIELD()) {                        \
      __VA_ARGS__;                                                            \
      value->APPEND_OP(CAST);                                                 \
    }                                                                         \
    return Status::OK();                                                      \
  }
218 
// DEFINE_TRY_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...) mirrors
// DEFINE_GET_ATTR, but the generated TryGetNodeAttr overloads return bool:
// false when the attr is missing or does not have type ATTR_TYPE
// ("list(ATTR_TYPE)" for the std::vector overload), true after storing (or
// appending) CAST. Because the functions return bool, validation code passed
// via `...` rejects a value with `return false;`. The vector overload keeps
// any existing contents of *value.
#define DEFINE_TRY_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...) \
  bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,      \
                      TYPE* value) {                                      \
    const AttrValue* attr_value = attrs.Find(attr_name);                  \
    if (attr_value == nullptr) {                                          \
      return false;                                                       \
    }                                                                     \
    Status s = AttrValueHasType(*attr_value, ATTR_TYPE);                  \
    if (!s.ok()) {                                                        \
      return false;                                                       \
    }                                                                     \
    const auto& v = attr_value->FIELD();                                  \
    __VA_ARGS__;                                                          \
    *value = CAST;                                                        \
    return true;                                                          \
  }                                                                       \
  bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,      \
                      std::vector<TYPE>* value) {                         \
    const AttrValue* attr_value = attrs.Find(attr_name);                  \
    if (attr_value == nullptr) {                                          \
      return false;                                                       \
    }                                                                     \
    Status s = AttrValueHasType(*attr_value, "list(" ATTR_TYPE ")");      \
    if (!s.ok()) {                                                        \
      return false;                                                       \
    }                                                                     \
    value->reserve(attr_value->list().FIELD().size());                    \
    for (const auto& v : attr_value->list().FIELD()) {                    \
      __VA_ARGS__;                                                        \
      value->APPEND_OP(CAST);                                             \
    }                                                                     \
    return true;                                                          \
  }
252 DEFINE_GET_ATTR(tstring, s, "string", emplace_back, v, ;)
253 DEFINE_TRY_GET_ATTR(tstring, s, "string", emplace_back, v, ;)
254 DEFINE_GET_ATTR(string, s, "string", emplace_back, v, ;)
255 DEFINE_TRY_GET_ATTR(string, s, "string", emplace_back, v, ;)
256 DEFINE_GET_ATTR(int64, i, "int", emplace_back, v, ;)
257 DEFINE_TRY_GET_ATTR(int64, i, "int", emplace_back, v, ;)
258 DEFINE_GET_ATTR(
259     int32, i, "int", emplace_back, static_cast<int32>(v),
260     if (static_cast<int64>(static_cast<int32>(v)) != v) {
261       return errors::InvalidArgument("Attr ", attr_name, " has value ", v,
262                                      " out of range for an int32");
263     })
264 DEFINE_TRY_GET_ATTR(
265     int32, i, "int", emplace_back, static_cast<int32>(v),
266     if (static_cast<int64>(static_cast<int32>(v)) != v) {
267       static int log_counter = 0;
268       if (log_counter < 10) {
269         log_counter++;
270         LOG(WARNING) << "Attr " << attr_name << " has value " << v
271                      << " out of range for an int32";
272       }
273       return false;
274     })
275 DEFINE_GET_ATTR(float, f, "float", emplace_back, v, ;)
276 DEFINE_TRY_GET_ATTR(float, f, "float", emplace_back, v, ;)
277 // std::vector<bool> specialization does not have emplace_back until
278 // c++14, so we have to use push_back (see
279 // http://en.cppreference.com/w/cpp/container/vector/emplace_back)
280 DEFINE_GET_ATTR(bool, b, "bool", push_back, v, ;)
281 DEFINE_TRY_GET_ATTR(bool, b, "bool", push_back, v, ;)
282 DEFINE_GET_ATTR(DataType, type, "type", emplace_back, static_cast<DataType>(v),
283                 ;)
284 DEFINE_TRY_GET_ATTR(DataType, type, "type", emplace_back,
285                     static_cast<DataType>(v),
286                     ;)
287 DEFINE_GET_ATTR(TensorShapeProto, shape, "shape", emplace_back, v, ;)
288 DEFINE_GET_ATTR(TensorShape, shape, "shape", emplace_back, TensorShape(v),
289                 TF_RETURN_IF_ERROR(TensorShape::IsValidShape(v));)
290 DEFINE_TRY_GET_ATTR(
291     TensorShape, shape, "shape", emplace_back, TensorShape(v),
292     if (!TensorShape::IsValidShape(v).ok()) {
293       static int log_counter = 0;
294       if (log_counter < 10) {
295         log_counter++;
296         LOG(WARNING) << "Attr " << attr_name << " has invalid shape value "
297                      << v.DebugString();
298       }
299       return false;
300     })
301 DEFINE_GET_ATTR(PartialTensorShape, shape, "shape", emplace_back,
302                 PartialTensorShape(v),
303                 TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(v));)
304 DEFINE_GET_ATTR(
305     Tensor, tensor, "tensor", emplace_back, t, Tensor t; if (!t.FromProto(v)) {
306       return errors::InvalidArgument("Attr ", attr_name, " has value ",
307                                      v.ShortDebugString(),
308                                      " that can't be converted to a Tensor");
309     })
310 DEFINE_GET_ATTR(NameAttrList, func, "func", emplace_back, v, ;);
311 #undef DEFINE_GET_ATTR
312 
HasNodeAttr(const NodeDef & node_def,StringPiece attr_name)313 bool HasNodeAttr(const NodeDef& node_def, StringPiece attr_name) {
314   return node_def.attr().find(string(attr_name)) != node_def.attr().end();
315 }
316 
317 static const string& kEmptyString = *new string();
318 
GetNodeAttrString(const AttrSlice & attrs,StringPiece attr_name)319 const string& GetNodeAttrString(const AttrSlice& attrs, StringPiece attr_name) {
320   const AttrValue* attr_value = attrs.Find(attr_name);
321   if (attr_value == nullptr) {
322     return kEmptyString;
323   }
324   Status s = AttrValueHasType(*attr_value, "string");
325   if (!s.ok()) {
326     return kEmptyString;
327   }
328   return attr_value->s();
329 }
330 
TryGetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,std::vector<const string * > * value)331 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
332                     std::vector<const string*>* value) {
333   const AttrValue* attr_value = attrs.Find(attr_name);
334   if (attr_value == nullptr) {
335     return false;
336   }
337   Status s = AttrValueHasType(*attr_value, "list(string)");
338   if (!s.ok()) {
339     return false;
340   }
341   value->reserve(attr_value->list().s().size());
342   for (const auto& v : attr_value->list().s()) {
343     value->push_back(&v);
344   }
345   return true;
346 }
347 
TryGetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,std::vector<const TensorShapeProto * > * value)348 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
349                     std::vector<const TensorShapeProto*>* value) {
350   const AttrValue* attr_value = attrs.Find(attr_name);
351   if (attr_value == nullptr) {
352     return false;
353   }
354   Status s = AttrValueHasType(*attr_value, "list(shape)");
355   if (!s.ok()) {
356     return false;
357   }
358   value->reserve(attr_value->list().shape().size());
359   for (const auto& v : attr_value->list().shape()) {
360     value->push_back(&v);
361   }
362   return true;
363 }
364 
GetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,DataTypeVector * value)365 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
366                    DataTypeVector* value) {
367   const AttrValue* attr_value;
368   TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
369   TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(type)"));
370   for (const auto& v : attr_value->list().type()) {
371     value->push_back(static_cast<DataType>(v));
372   }
373   return Status::OK();
374 }
375 
GetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,const TensorProto ** value)376 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
377                    const TensorProto** value) {
378   const AttrValue* attr_value;
379   TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
380   TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "tensor"));
381   *value = &attr_value->tensor();
382   return Status::OK();
383 }
384 
TryGetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,const TensorProto ** value)385 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
386                     const TensorProto** value) {
387   const AttrValue* attr_value = attrs.Find(attr_name);
388   if (attr_value == nullptr) {
389     return false;
390   }
391   Status s = AttrValueHasType(*attr_value, "tensor");
392   if (!s.ok()) {
393     return false;
394   }
395   *value = &attr_value->tensor();
396   return true;
397 }
398 
GetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,const NameAttrList ** value)399 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
400                    const NameAttrList** value) {
401   const AttrValue* attr_value;
402   TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
403   TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "func"));
404   *value = &attr_value->func();
405   return Status::OK();
406 }
407 
TryGetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,const NameAttrList ** value)408 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
409                     const NameAttrList** value) {
410   const AttrValue* attr_value = attrs.Find(attr_name);
411   if (attr_value == nullptr) {
412     return false;
413   }
414   Status s = AttrValueHasType(*attr_value, "func");
415   if (!s.ok()) {
416     return false;
417   }
418   *value = &attr_value->func();
419   return true;
420 }
421 
422 namespace {  // Helper for InOutTypesForNode().
423 
424 template <class NodeDefOrAttrSlice>
AddArgToSig(const NodeDefOrAttrSlice & node_or_attrs,const OpDef::ArgDef & arg_def,DataTypeVector * sig)425 Status AddArgToSig(const NodeDefOrAttrSlice& node_or_attrs,
426                    const OpDef::ArgDef& arg_def, DataTypeVector* sig) {
427   const int original_size = sig->size();
428   if (!arg_def.number_attr().empty()) {
429     // Same type repeated "repeats" times.
430     int32 repeats = -1;
431     TF_RETURN_IF_ERROR(
432         GetNodeAttr(node_or_attrs, arg_def.number_attr(), &repeats));
433     if (repeats < 0) {
434       return errors::InvalidArgument("Value for number_attr() ", repeats,
435                                      " < 0");
436     }
437 
438     if (!arg_def.type_attr().empty()) {
439       DataType dtype;
440       TF_RETURN_IF_ERROR(
441           GetNodeAttr(node_or_attrs, arg_def.type_attr(), &dtype));
442       for (int i = 0; i < repeats; ++i) {
443         sig->push_back(dtype);
444       }
445     } else if (arg_def.type() != DT_INVALID) {
446       for (int i = 0; i < repeats; ++i) {
447         sig->push_back(arg_def.type());
448       }
449     } else {
450       return errors::InvalidArgument("Missing type or type_attr field in ",
451                                      arg_def.ShortDebugString());
452     }
453   } else if (!arg_def.type_attr().empty()) {
454     const AttrValue* attr_value;
455     TF_RETURN_IF_ERROR(
456         AttrSlice(node_or_attrs).Find(arg_def.type_attr(), &attr_value));
457     sig->push_back(attr_value->type());
458   } else if (!arg_def.type_list_attr().empty()) {
459     const AttrValue* attr_value;
460     TF_RETURN_IF_ERROR(
461         AttrSlice(node_or_attrs).Find(arg_def.type_list_attr(), &attr_value));
462     for (int dtype : attr_value->list().type()) {
463       sig->push_back(static_cast<DataType>(dtype));
464     }
465   } else if (arg_def.type() != DT_INVALID) {
466     sig->push_back(arg_def.type());
467   } else {
468     return errors::InvalidArgument("No type fields in ",
469                                    arg_def.ShortDebugString());
470   }
471   if (arg_def.is_ref()) {
472     // For all types that were added by this function call, make them refs.
473     for (size_t i = original_size; i < sig->size(); ++i) {
474       (*sig)[i] = MakeRefType((*sig)[i]);
475     }
476   }
477   return Status::OK();
478 }
479 
480 }  // namespace
481 
InputTypeForNode(const NodeDef & node_def,const OpDef & op_def,int input_port,DataType * input_type)482 Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
483                         int input_port, DataType* input_type) {
484   DataTypeVector input_types;
485   for (const auto& arg : op_def.input_arg()) {
486     TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, &input_types));
487     if (input_types.size() > input_port) {
488       const DataType dtype = input_types[input_port];
489       *input_type = dtype;
490       return Status::OK();
491     }
492   }
493   return errors::InvalidArgument("Input ", input_port, " not found for node ",
494                                  node_def.name());
495 }
496 
InputTypesForNode(const NodeDef & node_def,const OpDef & op_def,DataTypeVector * inputs)497 Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
498                          DataTypeVector* inputs) {
499   for (const auto& arg : op_def.input_arg()) {
500     TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs));
501   }
502   return Status::OK();
503 }
504 
OutputTypeForNode(const NodeDef & node_def,const OpDef & op_def,int output_port,DataType * output_type)505 Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
506                          int output_port, DataType* output_type) {
507   DataTypeVector output_types;
508   for (const auto& arg : op_def.output_arg()) {
509     TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, &output_types));
510     if (output_types.size() > output_port) {
511       const DataType dtype = output_types[output_port];
512       *output_type = dtype;
513       return Status::OK();
514     }
515   }
516   return errors::InvalidArgument("Output ", output_port, " not found for node ",
517                                  node_def.name());
518 }
519 
OutputTypesForNode(const NodeDef & node_def,const OpDef & op_def,DataTypeVector * outputs)520 Status OutputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
521                           DataTypeVector* outputs) {
522   for (const auto& arg : op_def.output_arg()) {
523     TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, outputs));
524   }
525   return Status::OK();
526 }
527 
OutputTypesForNode(const AttrSlice & attrs,const OpDef & op_def,DataTypeVector * outputs)528 Status OutputTypesForNode(const AttrSlice& attrs, const OpDef& op_def,
529                           DataTypeVector* outputs) {
530   for (const auto& arg : op_def.output_arg()) {
531     TF_RETURN_IF_ERROR(AddArgToSig(attrs, arg, outputs));
532   }
533   return Status::OK();
534 }
535 
InOutTypesForNode(const NodeDef & node_def,const OpDef & op_def,DataTypeVector * inputs,DataTypeVector * outputs)536 Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
537                          DataTypeVector* inputs, DataTypeVector* outputs) {
538   TF_RETURN_IF_ERROR(InputTypesForNode(node_def, op_def, inputs));
539   return OutputTypesForNode(node_def, op_def, outputs);
540 }
541 
NumOutputsForNode(const NodeDef & node_def,const OpDef & op_def,int * num_outputs)542 Status NumOutputsForNode(const NodeDef& node_def, const OpDef& op_def,
543                          int* num_outputs) {
544   DataTypeVector outputs;
545   TF_RETURN_IF_ERROR(OutputTypesForNode(node_def, op_def, &outputs));
546   *num_outputs = outputs.size();
547   return Status::OK();
548 }
549 
ValidateNodeDef(const NodeDef & node_def,const OpDef & op_def)550 Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) {
551   if (node_def.op() != op_def.name()) {
552     return errors::InvalidArgument(
553         "NodeDef op '", node_def.op(), "' does not match ",
554         SummarizeOpDef(op_def), "; NodeDef: ", FormatNodeDefForError(node_def));
555   }
556 
557   bool seen_control = false;
558   size_t num_inputs = 0;
559   // TODO(josh11b): Unify the input field validation.
560   for (const string& input : node_def.input()) {
561     if (absl::StartsWith(input, "^")) {
562       seen_control = true;
563       if (input.find(':') != string::npos) {
564         return errors::InvalidArgument("Control input '", input,
565                                        "' must not have ':' in NodeDef: ",
566                                        FormatNodeDefForError(node_def));
567       }
568     } else if (seen_control) {
569       return errors::InvalidArgument("Non-control input '", input,
570                                      "' after control input in NodeDef: ",
571                                      FormatNodeDefForError(node_def));
572     } else {
573       ++num_inputs;
574     }
575   }
576 
577   std::unordered_map<string, const OpDef::AttrDef*> op_attrs;
578   for (const auto& attr : op_def.attr()) {
579     if (!gtl::InsertIfNotPresent(&op_attrs, attr.name(), &attr)) {
580       return errors::InvalidArgument("OpDef has duplicate attr name '",
581                                      attr.name(),
582                                      "': ", SummarizeOpDef(op_def));
583     }
584   }
585   for (const auto& attr : node_def.attr()) {
586     // Allow internal optional attributes with names starting with "_".
587     if (absl::StartsWith(attr.first, "_")) {
588       continue;
589     }
590     auto iter = op_attrs.find(attr.first);
591     if (iter == op_attrs.end()) {
592       // A common cause of this error is that TensorFlow has made a
593       // backwards-compatible change to the NodeDef (e.g., adding a
594       // new attr with a default value), but the binary consuming the
595       // NodeDef does not know about the new attribute; the solution
596       // in these cases is to ensure that the binary consuming the
597       // NodeDef is built with a version of TensorFlow no earlier than
598       // the binary producing it.
599       return errors::InvalidArgument(
600           "NodeDef mentions attr '", attr.first, "' not in ",
601           SummarizeOpDef(op_def),
602           "; NodeDef: ", FormatNodeDefForError(node_def),
603           ". (Check whether your GraphDef-interpreting binary is up to date "
604           "with your GraphDef-generating binary.).");
605     }
606     // If attr value is placeholder, do not check it.
607     if (attr.second.placeholder().empty()) {
608       TF_RETURN_WITH_CONTEXT_IF_ERROR(
609           ValidateAttrValue(attr.second, *iter->second),
610           "; NodeDef: ", FormatNodeDefForError(node_def), "; ",
611           SummarizeOpDef(op_def));
612     }
613     // Keep track of which attr names have (not) been found in the NodeDef.
614     op_attrs.erase(iter);
615   }
616 
617   // Were all attrs in the OpDef found in the NodeDef?
618   if (!op_attrs.empty()) {
619     string attrs;
620     for (const auto& attr_pair : op_attrs) {
621       if (!attrs.empty()) strings::StrAppend(&attrs, "', '");
622       strings::StrAppend(&attrs, attr_pair.first);
623     }
624     return errors::InvalidArgument(
625         "NodeDef missing attr", op_attrs.size() == 1 ? " '" : "s '", attrs,
626         "' from ", SummarizeOpDef(op_def),
627         "; NodeDef: ", FormatNodeDefForError(node_def));
628   }
629 
630   // Validate the number of inputs.
631   DataTypeVector inputs, outputs;
632   TF_RETURN_IF_ERROR(InOutTypesForNode(node_def, op_def, &inputs, &outputs));
633 
634   if (num_inputs != inputs.size()) {
635     return errors::InvalidArgument(
636         "NodeDef expected inputs '", DataTypeVectorString(inputs),
637         "' do not match ", num_inputs, " inputs specified; ",
638         SummarizeOpDef(op_def), "; NodeDef: ", FormatNodeDefForError(node_def));
639   }
640 
641   return Status::OK();
642 }
643 
644 namespace {  // Helpers for NameRangesForNode()
645 
ComputeArgRange(const AttrSlice & attrs,const OpDef::ArgDef & arg_def,const OpDef & op_def,int * num)646 Status ComputeArgRange(const AttrSlice& attrs, const OpDef::ArgDef& arg_def,
647                        const OpDef& op_def, int* num) {
648   if (!arg_def.number_attr().empty()) {
649     // Same type repeated "num" times.
650     return GetNodeAttr(attrs, arg_def.number_attr(), num);
651   } else if (!arg_def.type_list_attr().empty()) {
652     const AttrValue* attr_value;
653     TF_RETURN_IF_ERROR(attrs.Find(arg_def.type_list_attr(), &attr_value));
654     *num = attr_value->list().type_size();
655   } else if (!arg_def.type_attr().empty() || arg_def.type() != DT_INVALID) {
656     *num = 1;
657   } else {
658     return errors::InvalidArgument(
659         "Argument '", arg_def.name(),
660         "' incorrectly specified in op definition: ", SummarizeOpDef(op_def));
661   }
662   return Status::OK();
663 }
664 
NameRangesHelper(const AttrSlice & attrs,const protobuf::RepeatedPtrField<OpDef::ArgDef> & args,const OpDef & op_def,NameRangeMap * result)665 Status NameRangesHelper(const AttrSlice& attrs,
666                         const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
667                         const OpDef& op_def, NameRangeMap* result) {
668   int start = 0;
669   int num;
670   for (const auto& arg : args) {
671     TF_RETURN_IF_ERROR(ComputeArgRange(attrs, arg, op_def, &num));
672     (*result)[arg.name()] = std::make_pair(start, start + num);
673     start += num;
674   }
675   return Status::OK();
676 }
677 
678 }  // namespace
679 
NameRangesForNode(const AttrSlice & attrs,const OpDef & op_def,NameRangeMap * inputs,NameRangeMap * outputs)680 Status NameRangesForNode(const AttrSlice& attrs, const OpDef& op_def,
681                          NameRangeMap* inputs, NameRangeMap* outputs) {
682   if (inputs != nullptr) {
683     TF_RETURN_IF_ERROR(
684         NameRangesHelper(attrs, op_def.input_arg(), op_def, inputs));
685   }
686   if (outputs != nullptr) {
687     return NameRangesHelper(attrs, op_def.output_arg(), op_def, outputs);
688   }
689   return Status::OK();
690 }
691 
AddDefaultsToNodeDef(const OpDef & op_def,NodeDef * node_def)692 void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def) {
693   for (const auto& attr_def : op_def.attr()) {
694     AttrSlice attrs(*node_def);
695     if (attr_def.has_default_value() && !attrs.Find(attr_def.name())) {
696       AddNodeAttr(attr_def.name(), attr_def.default_value(), node_def);
697     }
698   }
699 }
700 
701 namespace {
702 
703 using ::tensorflow::tstring;
704 using ::tensorflow::strings::Scanner;
705 
IsValidNodeName(StringPiece sp)706 bool IsValidNodeName(StringPiece sp) {
707   Scanner scanner(sp);
708   scanner.One(Scanner::LETTER_DIGIT_DOT)
709       .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
710 
711   while (true) {
712     if (!scanner.GetResult())  // Some error in previous iteration.
713       return false;
714     if (scanner.empty())  // No error, but nothing left, good.
715       return true;
716 
717     // Absorb another name/namespace, starting with a '>'
718     scanner.One(Scanner::RANGLE)
719         .One(Scanner::LETTER_DIGIT_DOT)
720         .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
721   }
722 }
723 
IsValidDataInputName(StringPiece sp)724 bool IsValidDataInputName(StringPiece sp) {
725   // Data inputs are op_name, op_name:0, or op_name:12345.
726   Scanner scan(sp);
727   scan.One(Scanner::LETTER_DIGIT_DOT)
728       .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
729 
730   while (true) {
731     if (!scan.GetResult())  // Some error in previous iteration.
732       return false;
733     if (scan.empty())  // No error, but nothing left, good.
734       return true;
735 
736     if (scan.Peek() == ':') {  // Absorb identifier after the colon
737       scan.OneLiteral(":");
738       if (scan.Peek() == '0') {
739         scan.OneLiteral("0");  // :0
740       } else {
741         scan.Many(Scanner::DIGIT);  // :[1-9][0-9]*
742       }
743     } else {
744       // Absorb another name/namespace, starting with a '>'
745       scan.One(Scanner::RANGLE)
746           .One(Scanner::LETTER_DIGIT_DOT)
747           .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
748     }
749   }
750 }
751 
IsValidControlInputName(StringPiece sp)752 bool IsValidControlInputName(StringPiece sp) {
753   Scanner scan(sp);
754   scan.OneLiteral("^")
755       .One(Scanner::LETTER_DIGIT_DOT)
756       .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
757 
758   while (true) {
759     if (!scan.GetResult())  // Some error in previous iteration.
760       return false;
761     if (scan.empty())  // No error, but nothing left, good.
762       return true;
763 
764     // Absorb another name/namespace, starting with a '>'
765     scan.One(Scanner::RANGLE)
766         .One(Scanner::LETTER_DIGIT_DOT)
767         .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
768   }
769 }
770 
771 }  // namespace
772 
ValidateOpInput(const string & input_name,bool * is_control_input)773 Status ValidateOpInput(const string& input_name, bool* is_control_input) {
774   *is_control_input = false;
775   if (IsValidDataInputName(input_name)) {
776     return Status::OK();
777   } else if (IsValidControlInputName(input_name)) {
778     *is_control_input = true;
779     return Status::OK();
780   } else {
781     return errors::InvalidArgument("Illegal op input name '", input_name, "'");
782   }
783 }
784 
ValidateNodeName(const string & node_name)785 Status ValidateNodeName(const string& node_name) {
786   if (IsValidNodeName(node_name)) {
787     return Status::OK();
788   } else {
789     return errors::InvalidArgument("Illegal op name '", node_name, "'");
790   }
791 }
792 
ValidateExternalNodeDefSyntax(const NodeDef & node_def)793 Status ValidateExternalNodeDefSyntax(const NodeDef& node_def) {
794   Status s = ValidateNodeName(node_def.name());
795   if (!s.ok()) {
796     return AttachDef(s, node_def);
797   }
798   bool in_control_inputs = false;
799   for (const string& input_name : node_def.input()) {
800     bool is_control_input;
801     s = ValidateOpInput(input_name, &is_control_input);
802     if (!s.ok()) {
803       return AttachDef(s, node_def);
804     }
805 
806     if (in_control_inputs && !is_control_input) {
807       return AttachDef(errors::InvalidArgument(
808                            "All control inputs must follow all data inputs"),
809                        node_def);
810     }
811     in_control_inputs = is_control_input;
812   }
813   return Status::OK();
814 }
815 
// Returns a copy of `status` whose message has " [[<node>]]" appended.
// <node> is normally FormatNodeDefForError(node_def). If the message already
// contains a formatted node ("{{node ") and `allow_multiple_formatted_node`
// is false, only the plain node name is appended, so the message does not
// carry a second formatted-node marker.
Status AttachDef(const Status& status, const NodeDef& node_def,
                 bool allow_multiple_formatted_node) {
  const bool message_has_formatted_node =
      status.error_message().find("{{node ") != string::npos;
  const string node_error =
      (message_has_formatted_node && !allow_multiple_formatted_node)
          ? node_def.name()
          : FormatNodeDefForError(node_def);

  Status annotated = status;
  errors::AppendToMessage(&annotated, strings::StrCat(" [[", node_error, "]]"));
  return annotated;
}
829 
AddNodeAttr(StringPiece name,const AttrValue & value,NodeDef * node_def)830 void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def) {
831   node_def->mutable_attr()->insert(
832       AttrValueMap::value_type(string(name), value));
833 }
834 
AddNodeAttr(StringPiece name,AttrValue && value,NodeDef * node_def)835 void AddNodeAttr(StringPiece name, AttrValue&& value, NodeDef* node_def) {
836   (*node_def->mutable_attr())[string(name)] = std::move(value);
837 }
838 
839 #define ADD_NODE_ATTR(T)                                           \
840   void AddNodeAttr(StringPiece name, T value, NodeDef* node_def) { \
841     AttrValue attr_value;                                          \
842     SetAttrValue(value, &attr_value);                              \
843     AddNodeAttr(name, attr_value, node_def);                       \
844   }
845 ADD_NODE_ATTR(StringPiece)
ADD_NODE_ATTR(const char *)846 ADD_NODE_ATTR(const char*)
847 ADD_NODE_ATTR(int32)
848 ADD_NODE_ATTR(int64)
849 ADD_NODE_ATTR(float)
850 ADD_NODE_ATTR(double)
851 ADD_NODE_ATTR(bool)
852 ADD_NODE_ATTR(DataType)
853 ADD_NODE_ATTR(const PartialTensorShape&)
854 ADD_NODE_ATTR(const Tensor&)
855 ADD_NODE_ATTR(const TensorProto&)
856 ADD_NODE_ATTR(const NameAttrList&)
857 ADD_NODE_ATTR(gtl::ArraySlice<StringPiece>)
858 ADD_NODE_ATTR(gtl::ArraySlice<const char*>)
859 ADD_NODE_ATTR(gtl::ArraySlice<string>)
860 ADD_NODE_ATTR(gtl::ArraySlice<int32>)
861 ADD_NODE_ATTR(gtl::ArraySlice<int64>)
862 ADD_NODE_ATTR(gtl::ArraySlice<float>)
863 ADD_NODE_ATTR(gtl::ArraySlice<bool>)
864 ADD_NODE_ATTR(const std::vector<bool>&)
865 ADD_NODE_ATTR(gtl::ArraySlice<DataType>)
866 ADD_NODE_ATTR(gtl::ArraySlice<TensorShape>)
867 ADD_NODE_ATTR(gtl::ArraySlice<PartialTensorShape>)
868 ADD_NODE_ATTR(gtl::ArraySlice<TensorShapeProto>)
869 ADD_NODE_ATTR(gtl::ArraySlice<Tensor>)
870 ADD_NODE_ATTR(gtl::ArraySlice<NameAttrList>)
871 #undef ADD_NODE_ATTR
872 
873 void AddAttr(StringPiece name, const AttrValue& value, AttrValueMap* map) {
874   map->insert(AttrValueMap::value_type(string(name), value));
875 }
876 
877 #define ADD_ATTR(T)                                            \
878   void AddAttr(StringPiece name, T value, AttrValueMap* map) { \
879     AttrValue attr_value;                                      \
880     SetAttrValue(value, &attr_value);                          \
881     AddAttr(name, attr_value, map);                            \
882   }
ADD_ATTR(bool)883 ADD_ATTR(bool)
884 #undef ADD_ATTR
885 
886 Status AddPrefixAndSuffixToNode(StringPiece prefix, StringPiece suffix,
887                                 NodeDef* node_def, bool uniquify_frame_name) {
888   node_def->set_name(strings::StrCat(prefix, node_def->name(), suffix));
889 
890   // Update frame name to avoid multiple LoopCond nodes in one frame.
891   if (uniquify_frame_name &&
892       (node_def->op() == "Enter" || node_def->op() == "RefEnter")) {
893     string frame_name;
894     TF_RETURN_IF_ERROR(GetNodeAttr(*node_def, "frame_name", &frame_name));
895     AttrValue& attr = (*node_def->mutable_attr())["frame_name"];
896     frame_name = strings::StrCat(prefix, frame_name, suffix);
897     attr.set_s(frame_name);
898   }
899 
900   // Update colocation constraints.
901   constexpr char kClassAttr[] = "_class";
902   auto class_attr = node_def->mutable_attr()->find(kClassAttr);
903   if (class_attr != node_def->mutable_attr()->end()) {
904     AttrValue new_value;
905     new_value.mutable_list()->add_s(
906         strings::StrCat(prefix, class_attr->second.s()));
907     node_def->mutable_attr()->erase(kClassAttr);
908     node_def->mutable_attr()->insert({kClassAttr, new_value});
909   }
910 
911   return Status::OK();
912 }
913 
914 }  // namespace tensorflow
915