• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/node_def_util.h"
17 
18 #include <algorithm>
19 #include <unordered_map>
20 #include <vector>
21 
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/str_join.h"
24 #include "tensorflow/core/framework/attr_value.pb.h"
25 #include "tensorflow/core/framework/attr_value_util.h"
26 #include "tensorflow/core/framework/node_def.pb.h"
27 #include "tensorflow/core/framework/op_def.pb.h"
28 #include "tensorflow/core/framework/op_def_util.h"
29 #include "tensorflow/core/framework/tensor.h"
30 #include "tensorflow/core/framework/tensor.pb.h"
31 #include "tensorflow/core/framework/tensor_shape.h"
32 #include "tensorflow/core/framework/tensor_shape.pb.h"
33 #include "tensorflow/core/framework/types.h"
34 #include "tensorflow/core/framework/types.pb.h"
35 #include "tensorflow/core/lib/gtl/map_util.h"
36 #include "tensorflow/core/platform/errors.h"
37 #include "tensorflow/core/platform/scanner.h"
38 #include "tensorflow/core/platform/status.h"
39 #include "tensorflow/core/platform/strcat.h"
40 #include "tensorflow/core/platform/stringpiece.h"
41 #include "tensorflow/core/platform/types.h"
42 
43 namespace tensorflow {
44 
// Attr name used to record colocation constraints on a node, and the prefix
// carried by each colocation-group entry. NOTE(review): these roles are
// inferred from the constant names; confirm against users of the constants.
const char* const kColocationAttrName = "_class";
const char* const kColocationGroupPrefix = "loc:@";
// For TPU distributed rewrite, TPU args are collected and "staged" on the local
// host using an IdentityN TF op. Some args may result from a remote source.
// When all arg tensors are available, the TPUExecute op can be invoked. See
// DistributedTPURewritePass for more details.
const char* const kTpuExecuteStagingOp = "IdentityN";
const char* const kTpuExecuteStagingNodeName = "_variable_copy";
52 const char* const kTpuExecuteStagingNodeName = "_variable_copy";
53 
AttrSlice()54 AttrSlice::AttrSlice() : ndef_(nullptr) {
55   static const AttrValueMap* const kEmptyAttrValueMap = new AttrValueMap;
56   attrs_ = kEmptyAttrValueMap;
57 }
58 
AttrSlice(const NodeDef & node_def)59 AttrSlice::AttrSlice(const NodeDef& node_def)
60     : ndef_(&node_def), attrs_(&ndef_->attr()) {}
61 
AttrSlice(const AttrValueMap * a)62 AttrSlice::AttrSlice(const AttrValueMap* a) : ndef_(nullptr), attrs_(a) {}
63 
SummarizeAttrsHelper(AttrSlice attrs,StringPiece device)64 string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device) {
65   string ret;
66 
67   // We sort the attrs so the output is deterministic.
68   std::vector<string> attr_names;
69   attr_names.reserve(attrs.size());
70   for (const auto& attr : attrs) {
71     attr_names.push_back(attr.first);
72   }
73   std::sort(attr_names.begin(), attr_names.end());
74   bool first = true;
75   for (const string& attr_name : attr_names) {
76     if (!first) strings::StrAppend(&ret, ", ");
77     first = false;
78     strings::StrAppend(&ret, attr_name, "=",
79                        SummarizeAttrValue(*attrs.Find(attr_name)));
80   }
81 
82   // Consider the device to be a final attr with name "_device".
83   if (!device.empty()) {
84     if (!first) strings::StrAppend(&ret, ", ");
85     first = false;
86     strings::StrAppend(&ret, "_device=\"", device, "\"");
87   }
88   return ret;
89 }
90 
SummarizeNode() const91 string AttrSlice::SummarizeNode() const {
92   return ndef_ ? SummarizeNodeDef(*ndef_)
93                : strings::StrCat(
94                      "[", SummarizeAttrsHelper(*this, StringPiece()), "]");
95 }
96 
DebugString() const97 string AttrSlice::DebugString() const {
98   std::vector<string> attr_key_vals;
99   attr_key_vals.reserve(attrs_->size());
100   for (const auto& it : *this) {
101     const string& name = it.first;
102     const AttrValue& attr_value = it.second;
103     attr_key_vals.push_back(
104         absl::StrCat(name, "=", SummarizeAttrValue(attr_value)));
105   }
106   return absl::StrJoin(attr_key_vals, ", ");
107 }
108 
SummarizeNodeDef(const NodeDef & node_def,int max_inputs_in_summary)109 string SummarizeNodeDef(const NodeDef& node_def, int max_inputs_in_summary) {
110   string ret = strings::StrCat(errors::FormatNodeNameForError(node_def.name()),
111                                " = ", node_def.op(), "[");
112   strings::StrAppend(&ret, SummarizeAttrsHelper(node_def, node_def.device()));
113   strings::StrAppend(&ret, "](");
114 
115   // Output inputs, including control inputs, verbatim.
116   bool first = true;
117   for (const string& input : node_def.input()) {
118     if (!first) strings::StrAppend(&ret, ", ");
119     first = false;
120     if (max_inputs_in_summary-- == 0) {
121       strings::StrAppend(&ret, "...");
122       break;
123     }
124     strings::StrAppend(&ret, input);
125   }
126   strings::StrAppend(&ret, ")");
127   return ret;
128 }
129 
SummarizeAttrs(const NodeDef & node_def)130 string SummarizeAttrs(const NodeDef& node_def) {
131   return SummarizeAttrsHelper(node_def, node_def.device());
132 }
133 
FormatNodeDefForError(StringPiece node_name,bool has_experimental_debug_info,const NodeDef_ExperimentalDebugInfo & experimental_debug_info)134 string FormatNodeDefForError(
135     StringPiece node_name, bool has_experimental_debug_info,
136     const NodeDef_ExperimentalDebugInfo& experimental_debug_info) {
137   return !has_experimental_debug_info ||
138                  experimental_debug_info.original_node_names().empty()
139              ? errors::FormatNodeNameForError(string(node_name))
140              : errors::FormatNodeNamesForError(
141                    experimental_debug_info.original_node_names());
142 }
143 
FormatNodeDefForError(const NodeDef & node_def)144 string FormatNodeDefForError(const NodeDef& node_def) {
145   return FormatNodeDefForError(node_def.name(),
146                                node_def.has_experimental_debug_info(),
147                                node_def.experimental_debug_info());
148 }
149 
Find(StringPiece attr_name) const150 const AttrValue* AttrSlice::Find(StringPiece attr_name) const {
151   // Currently, the collection used for NodeDef::attr() (google::protobuf::Map)
152   // requires that the keys used for lookups have type 'const string&'. Because
153   // this method takes a StringPiece, it is necessary to allocate a temporary
154   // string, copy attr_name to it, and then use that temporary string for the
155   // lookup. This causes an excessive number of short-lived allocations, and for
156   // large graphs, this can be a significant cost.
157   //
158   // Because most nodes have a small number of attributes, a simple linear scan
159   // is generally more efficient than a hashed lookup.  If google::protobuf::Map
160   // changes so that it supports efficient lookups using StringPiece instead of
161   // const string&, then this code could be changed to use attrs_->find() again.
162 
163   for (const auto& attr : *attrs_) {
164     if (attr.first == attr_name) {
165       return &attr.second;
166     }
167   }
168   return nullptr;
169 }
170 
FindByString(const string & attr_name) const171 const AttrValue* AttrSlice::FindByString(const string& attr_name) const {
172   auto iter = attrs_->find(attr_name);
173   if (iter != attrs_->end()) {
174     return &iter->second;
175   } else {
176     return nullptr;
177   }
178 }
179 
Find(StringPiece attr_name,const AttrValue ** attr_value) const180 Status AttrSlice::Find(StringPiece attr_name,
181                        const AttrValue** attr_value) const {
182   *attr_value = Find(attr_name);
183   if (*attr_value != nullptr) {
184     return Status::OK();
185   }
186   Status s = errors::NotFound("No attr named '", attr_name, "' in NodeDef:");
187   // Skip AttachDef for internal attrs since it is a little bit
188   // expensive and it is common for them to correctly not be included
189   // in a NodeDef.
190   if (!absl::StartsWith(attr_name, "_") && ndef_ != nullptr) {
191     s = AttachDef(s, *ndef_);
192   }
193   return s;
194 }
195 
EqualAttrs(AttrSlice other,Scratch * scratch) const196 bool AttrSlice::EqualAttrs(AttrSlice other, Scratch* scratch) const {
197   if (size() != other.size()) return false;
198 
199   for (const auto& attr : *other.attrs_) {
200     auto iter = attrs_->find(attr.first);
201     if (iter == attrs_->end()) return false;
202     // TODO(irving): Comparing AttrValues by proto is slightly buggy, since
203     // TensorProto is a nonunique representation of Tensor.  This bug will go
204     // away once AttrSlice switches over to NodeInfo.
205     iter->second.SerializeToString(&scratch->a);
206     attr.second.SerializeToString(&scratch->b);
207     if (scratch->a != scratch->b) return false;
208   }
209   return true;
210 }
211 
// DEFINE_GET_ATTR defines a pair of GetNodeAttr() overloads: one for a single
// TYPE value and one that appends a "list(ATTR_TYPE)" attr to a
// std::vector<TYPE>.
//   TYPE      - C++ type handed back to the caller.
//   FIELD     - AttrValue accessor holding the proto value; the list overload
//               reads attr_value->list().FIELD().
//   ATTR_TYPE - attr type string checked with AttrValueHasType().
//   APPEND_OP - std::vector member used to add each converted element.
//   CAST      - expression converting the proto value `v` into a TYPE.
//   ...       - lets the caller inject value validation code, run just before
//               CAST; it may use `v` and `attr_name` and may return early.
//               Use just ; if no additional validation code is needed.
// The list overloads append to *value without clearing it. If validation
// fails partway through, the elements converted so far stay in *value.
// reserve() accounts for elements already in *value, so a pre-filled vector
// does not under-reserve.
#define DEFINE_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...)         \
  Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,           \
                     TYPE* value) {                                           \
    const AttrValue* attr_value;                                              \
    TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));                   \
    TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, ATTR_TYPE));             \
    const auto& v = attr_value->FIELD();                                      \
    __VA_ARGS__;                                                              \
    *value = CAST;                                                            \
    return Status::OK();                                                      \
  }                                                                           \
  Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,           \
                     std::vector<TYPE>* value) {                              \
    const AttrValue* attr_value;                                              \
    TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));                   \
    TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(" ATTR_TYPE ")")); \
    value->reserve(value->size() + attr_value->list().FIELD().size());        \
    for (const auto& v : attr_value->list().FIELD()) {                        \
      __VA_ARGS__;                                                            \
      value->APPEND_OP(CAST);                                                 \
    }                                                                         \
    return Status::OK();                                                      \
  }

// DEFINE_TRY_GET_ATTR is like DEFINE_GET_ATTR, but defines TryGetNodeAttr()
// overloads that return false when the attr is missing or has the wrong type.
// The injected validation code signals failure with `return false`.
#define DEFINE_TRY_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...) \
  bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,      \
                      TYPE* value) {                                      \
    const AttrValue* attr_value = attrs.Find(attr_name);                  \
    if (attr_value == nullptr) {                                          \
      return false;                                                       \
    }                                                                     \
    Status s = AttrValueHasType(*attr_value, ATTR_TYPE);                  \
    if (!s.ok()) {                                                        \
      return false;                                                       \
    }                                                                     \
    const auto& v = attr_value->FIELD();                                  \
    __VA_ARGS__;                                                          \
    *value = CAST;                                                        \
    return true;                                                          \
  }                                                                       \
  bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,      \
                      std::vector<TYPE>* value) {                         \
    const AttrValue* attr_value = attrs.Find(attr_name);                  \
    if (attr_value == nullptr) {                                          \
      return false;                                                       \
    }                                                                     \
    Status s = AttrValueHasType(*attr_value, "list(" ATTR_TYPE ")");      \
    if (!s.ok()) {                                                        \
      return false;                                                       \
    }                                                                     \
    value->reserve(value->size() + attr_value->list().FIELD().size());    \
    for (const auto& v : attr_value->list().FIELD()) {                    \
      __VA_ARGS__;                                                        \
      value->APPEND_OP(CAST);                                             \
    }                                                                     \
    return true;                                                          \
  }
271 DEFINE_GET_ATTR(tstring, s, "string", emplace_back, v, ;)
272 DEFINE_TRY_GET_ATTR(tstring, s, "string", emplace_back, v, ;)
273 DEFINE_GET_ATTR(string, s, "string", emplace_back, v, ;)
274 DEFINE_TRY_GET_ATTR(string, s, "string", emplace_back, v, ;)
275 DEFINE_GET_ATTR(int64, i, "int", emplace_back, v, ;)
276 DEFINE_TRY_GET_ATTR(int64, i, "int", emplace_back, v, ;)
277 DEFINE_GET_ATTR(
278     int32, i, "int", emplace_back, static_cast<int32>(v),
279     if (static_cast<int64>(static_cast<int32>(v)) != v) {
280       return errors::InvalidArgument("Attr ", attr_name, " has value ", v,
281                                      " out of range for an int32");
282     })
283 DEFINE_TRY_GET_ATTR(
284     int32, i, "int", emplace_back, static_cast<int32>(v),
285     if (static_cast<int64>(static_cast<int32>(v)) != v) {
286       static int log_counter = 0;
287       if (log_counter < 10) {
288         log_counter++;
289         LOG(WARNING) << "Attr " << attr_name << " has value " << v
290                      << " out of range for an int32";
291       }
292       return false;
293     })
294 DEFINE_GET_ATTR(float, f, "float", emplace_back, v, ;)
295 DEFINE_TRY_GET_ATTR(float, f, "float", emplace_back, v, ;)
296 // std::vector<bool> specialization does not have emplace_back until
297 // c++14, so we have to use push_back (see
298 // http://en.cppreference.com/w/cpp/container/vector/emplace_back)
299 DEFINE_GET_ATTR(bool, b, "bool", push_back, v, ;)
300 DEFINE_TRY_GET_ATTR(bool, b, "bool", push_back, v, ;)
301 DEFINE_GET_ATTR(DataType, type, "type", emplace_back, static_cast<DataType>(v),
302                 ;)
303 DEFINE_TRY_GET_ATTR(DataType, type, "type", emplace_back,
304                     static_cast<DataType>(v),
305                     ;)
306 DEFINE_GET_ATTR(TensorShapeProto, shape, "shape", emplace_back, v, ;)
307 DEFINE_GET_ATTR(TensorShape, shape, "shape", emplace_back, TensorShape(v),
308                 TF_RETURN_IF_ERROR(TensorShape::IsValidShape(v));)
309 DEFINE_TRY_GET_ATTR(
310     TensorShape, shape, "shape", emplace_back, TensorShape(v),
311     if (!TensorShape::IsValidShape(v).ok()) {
312       static int log_counter = 0;
313       if (log_counter < 10) {
314         log_counter++;
315         LOG(WARNING) << "Attr " << attr_name << " has invalid shape value "
316                      << v.DebugString();
317       }
318       return false;
319     })
320 DEFINE_GET_ATTR(PartialTensorShape, shape, "shape", emplace_back,
321                 PartialTensorShape(v),
322                 TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(v));)
323 DEFINE_GET_ATTR(
324     Tensor, tensor, "tensor", emplace_back, t, Tensor t; if (!t.FromProto(v)) {
325       return errors::InvalidArgument("Attr ", attr_name, " has value ",
326                                      v.ShortDebugString(),
327                                      " that can't be converted to a Tensor");
328     })
329 DEFINE_GET_ATTR(NameAttrList, func, "func", emplace_back, v, ;);
330 #undef DEFINE_GET_ATTR
331 
HasNodeAttr(const NodeDef & node_def,StringPiece attr_name)332 bool HasNodeAttr(const NodeDef& node_def, StringPiece attr_name) {
333   return node_def.attr().find(string(attr_name)) != node_def.attr().end();
334 }
335 
336 static const string& kEmptyString = *new string();
337 
GetNodeAttrString(const AttrSlice & attrs,StringPiece attr_name)338 const string& GetNodeAttrString(const AttrSlice& attrs, StringPiece attr_name) {
339   const AttrValue* attr_value = attrs.Find(attr_name);
340   if (attr_value == nullptr) {
341     return kEmptyString;
342   }
343   Status s = AttrValueHasType(*attr_value, "string");
344   if (!s.ok()) {
345     return kEmptyString;
346   }
347   return attr_value->s();
348 }
349 
TryGetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,std::vector<const string * > * value)350 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
351                     std::vector<const string*>* value) {
352   const AttrValue* attr_value = attrs.Find(attr_name);
353   if (attr_value == nullptr) {
354     return false;
355   }
356   Status s = AttrValueHasType(*attr_value, "list(string)");
357   if (!s.ok()) {
358     return false;
359   }
360   value->reserve(attr_value->list().s().size());
361   for (const auto& v : attr_value->list().s()) {
362     value->push_back(&v);
363   }
364   return true;
365 }
366 
TryGetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,std::vector<const TensorShapeProto * > * value)367 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
368                     std::vector<const TensorShapeProto*>* value) {
369   const AttrValue* attr_value = attrs.Find(attr_name);
370   if (attr_value == nullptr) {
371     return false;
372   }
373   Status s = AttrValueHasType(*attr_value, "list(shape)");
374   if (!s.ok()) {
375     return false;
376   }
377   value->reserve(attr_value->list().shape().size());
378   for (const auto& v : attr_value->list().shape()) {
379     value->push_back(&v);
380   }
381   return true;
382 }
383 
GetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,DataTypeVector * value)384 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
385                    DataTypeVector* value) {
386   const AttrValue* attr_value;
387   TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
388   TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(type)"));
389   for (const auto& v : attr_value->list().type()) {
390     value->push_back(static_cast<DataType>(v));
391   }
392   return Status::OK();
393 }
394 
GetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,const TensorProto ** value)395 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
396                    const TensorProto** value) {
397   const AttrValue* attr_value;
398   TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
399   TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "tensor"));
400   *value = &attr_value->tensor();
401   return Status::OK();
402 }
403 
TryGetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,const TensorProto ** value)404 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
405                     const TensorProto** value) {
406   const AttrValue* attr_value = attrs.Find(attr_name);
407   if (attr_value == nullptr) {
408     return false;
409   }
410   Status s = AttrValueHasType(*attr_value, "tensor");
411   if (!s.ok()) {
412     return false;
413   }
414   *value = &attr_value->tensor();
415   return true;
416 }
417 
GetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,const NameAttrList ** value)418 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
419                    const NameAttrList** value) {
420   const AttrValue* attr_value;
421   TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
422   TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "func"));
423   *value = &attr_value->func();
424   return Status::OK();
425 }
426 
TryGetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,const NameAttrList ** value)427 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
428                     const NameAttrList** value) {
429   const AttrValue* attr_value = attrs.Find(attr_name);
430   if (attr_value == nullptr) {
431     return false;
432   }
433   Status s = AttrValueHasType(*attr_value, "func");
434   if (!s.ok()) {
435     return false;
436   }
437   *value = &attr_value->func();
438   return true;
439 }
440 
GetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,Padding * value)441 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
442                    Padding* value) {
443   string str_value;
444   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, attr_name, &str_value));
445   return GetPaddingFromString(str_value, value);
446 }
447 
448 namespace {  // Helper for InOutTypesForNode().
449 
450 template <class NodeDefOrAttrSlice>
AddArgToSig(const NodeDefOrAttrSlice & node_or_attrs,const OpDef::ArgDef & arg_def,DataTypeVector * sig)451 Status AddArgToSig(const NodeDefOrAttrSlice& node_or_attrs,
452                    const OpDef::ArgDef& arg_def, DataTypeVector* sig) {
453   const int original_size = sig->size();
454   if (!arg_def.number_attr().empty()) {
455     // Same type repeated "repeats" times.
456     int64_t repeats = -1;
457     TF_RETURN_IF_ERROR(
458         GetNodeAttr(node_or_attrs, arg_def.number_attr(), &repeats));
459     // We can't handle outputs that are larger than int32 sizes.
460     if (static_cast<int64>(static_cast<int32>(repeats)) != repeats) {
461       return errors::InvalidArgument("Number of outputs is too big: ", repeats);
462     }
463     if (repeats < 0) {
464       return errors::InvalidArgument("Value for number_attr() ", repeats,
465                                      " < 0");
466     }
467 
468     if (!arg_def.type_attr().empty()) {
469       DataType dtype;
470       TF_RETURN_IF_ERROR(
471           GetNodeAttr(node_or_attrs, arg_def.type_attr(), &dtype));
472       for (int i = 0; i < repeats; ++i) {
473         sig->push_back(dtype);
474       }
475     } else if (arg_def.type() != DT_INVALID) {
476       for (int i = 0; i < repeats; ++i) {
477         sig->push_back(arg_def.type());
478       }
479     } else {
480       return errors::InvalidArgument("Missing type or type_attr field in ",
481                                      arg_def.ShortDebugString());
482     }
483   } else if (!arg_def.type_attr().empty()) {
484     const AttrValue* attr_value;
485     TF_RETURN_IF_ERROR(
486         AttrSlice(node_or_attrs).Find(arg_def.type_attr(), &attr_value));
487     sig->push_back(attr_value->type());
488   } else if (!arg_def.type_list_attr().empty()) {
489     const AttrValue* attr_value;
490     TF_RETURN_IF_ERROR(
491         AttrSlice(node_or_attrs).Find(arg_def.type_list_attr(), &attr_value));
492     for (int dtype : attr_value->list().type()) {
493       sig->push_back(static_cast<DataType>(dtype));
494     }
495   } else if (arg_def.type() != DT_INVALID) {
496     sig->push_back(arg_def.type());
497   } else {
498     return errors::InvalidArgument("No type fields in ",
499                                    arg_def.ShortDebugString());
500   }
501   if (arg_def.is_ref()) {
502     // For all types that were added by this function call, make them refs.
503     for (size_t i = original_size; i < sig->size(); ++i) {
504       if (IsRefType((*sig)[i])) {
505         return errors::InvalidArgument(
506             "Requested reference to a reference type: ",
507             arg_def.ShortDebugString());
508       }
509       (*sig)[i] = MakeRefType((*sig)[i]);
510     }
511   }
512   return Status::OK();
513 }
514 
515 }  // namespace
516 
InputTypeForNode(const NodeDef & node_def,const OpDef & op_def,int input_port,DataType * input_type)517 Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
518                         int input_port, DataType* input_type) {
519   DataTypeVector input_types;
520   for (const auto& arg : op_def.input_arg()) {
521     TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, &input_types));
522     int input_types_size = input_types.size();
523     if (input_types_size > input_port) {
524       const DataType dtype = input_types[input_port];
525       *input_type = dtype;
526       return Status::OK();
527     }
528   }
529   return errors::InvalidArgument("Input ", input_port, " not found for node ",
530                                  node_def.name());
531 }
532 
InputTypesForNode(const NodeDef & node_def,const OpDef & op_def,DataTypeVector * inputs)533 Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
534                          DataTypeVector* inputs) {
535   for (const auto& arg : op_def.input_arg()) {
536     TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs));
537   }
538   return Status::OK();
539 }
540 
OutputTypeForNode(const NodeDef & node_def,const OpDef & op_def,int output_port,DataType * output_type)541 Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
542                          int output_port, DataType* output_type) {
543   DataTypeVector output_types;
544   for (const auto& arg : op_def.output_arg()) {
545     TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, &output_types));
546     int output_types_size = output_types.size();
547     if (output_types_size > output_port) {
548       const DataType dtype = output_types[output_port];
549       *output_type = dtype;
550       return Status::OK();
551     }
552   }
553   return errors::InvalidArgument("Output ", output_port, " not found for node ",
554                                  node_def.name());
555 }
556 
OutputTypesForNode(const NodeDef & node_def,const OpDef & op_def,DataTypeVector * outputs)557 Status OutputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
558                           DataTypeVector* outputs) {
559   for (const auto& arg : op_def.output_arg()) {
560     TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, outputs));
561   }
562   return Status::OK();
563 }
564 
OutputTypesForNode(const AttrSlice & attrs,const OpDef & op_def,DataTypeVector * outputs)565 Status OutputTypesForNode(const AttrSlice& attrs, const OpDef& op_def,
566                           DataTypeVector* outputs) {
567   for (const auto& arg : op_def.output_arg()) {
568     TF_RETURN_IF_ERROR(AddArgToSig(attrs, arg, outputs));
569   }
570   return Status::OK();
571 }
572 
InOutTypesForNode(const NodeDef & node_def,const OpDef & op_def,DataTypeVector * inputs,DataTypeVector * outputs)573 Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
574                          DataTypeVector* inputs, DataTypeVector* outputs) {
575   TF_RETURN_IF_ERROR(InputTypesForNode(node_def, op_def, inputs));
576   return OutputTypesForNode(node_def, op_def, outputs);
577 }
578 
NumOutputsForNode(const NodeDef & node_def,const OpDef & op_def,int * num_outputs)579 Status NumOutputsForNode(const NodeDef& node_def, const OpDef& op_def,
580                          int* num_outputs) {
581   DataTypeVector outputs;
582   TF_RETURN_IF_ERROR(OutputTypesForNode(node_def, op_def, &outputs));
583   *num_outputs = outputs.size();
584   return Status::OK();
585 }
586 
ValidateNodeDef(const NodeDef & node_def,const OpDef & op_def)587 Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) {
588   if (node_def.op() != op_def.name()) {
589     return errors::InvalidArgument(
590         "NodeDef op '", node_def.op(), "' does not match ",
591         SummarizeOpDef(op_def), "; NodeDef: ", FormatNodeDefForError(node_def));
592   }
593 
594   bool seen_control = false;
595   size_t num_inputs = 0;
596   // TODO(josh11b): Unify the input field validation.
597   for (const string& input : node_def.input()) {
598     if (absl::StartsWith(input, "^")) {
599       seen_control = true;
600       if (input.find(':') != string::npos) {
601         return errors::InvalidArgument("Control input '", input,
602                                        "' must not have ':' in NodeDef: ",
603                                        FormatNodeDefForError(node_def));
604       }
605     } else if (seen_control) {
606       return errors::InvalidArgument("Non-control input '", input,
607                                      "' after control input in NodeDef: ",
608                                      FormatNodeDefForError(node_def));
609     } else {
610       ++num_inputs;
611     }
612   }
613 
614   std::unordered_map<string, const OpDef::AttrDef*> op_attrs;
615   for (const auto& attr : op_def.attr()) {
616     if (!gtl::InsertIfNotPresent(&op_attrs, attr.name(), &attr)) {
617       return errors::InvalidArgument("OpDef has duplicate attr name '",
618                                      attr.name(),
619                                      "': ", SummarizeOpDef(op_def));
620     }
621   }
622   for (const auto& attr : node_def.attr()) {
623     // Allow internal optional attributes with names starting with "_".
624     if (absl::StartsWith(attr.first, "_")) {
625       continue;
626     }
627     auto iter = op_attrs.find(attr.first);
628     if (iter == op_attrs.end()) {
629       LOG_EVERY_N_SEC(ERROR, 5)
630           << "NodeDef mentions attribute " << attr.first
631           << " which is not in the op definition: " << SummarizeOpDef(op_def)
632           << " This may be expected if your graph generating binary is newer "
633           << " than this binary. Unknown attributes will be ignored."
634           << " NodeDef: " << FormatNodeDefForError(node_def);
635       continue;
636     }
637 
638     // If attr value is placeholder, do not check it.
639     if (attr.second.placeholder().empty()) {
640       TF_RETURN_WITH_CONTEXT_IF_ERROR(
641           ValidateAttrValue(attr.second, *iter->second),
642           "; NodeDef: ", FormatNodeDefForError(node_def), "; ",
643           SummarizeOpDef(op_def));
644     }
645 
646     // Keep track of which attr names have (not) been found in the NodeDef.
647     op_attrs.erase(iter);
648   }
649 
650   // Were all attrs in the OpDef found in the NodeDef?
651   if (!op_attrs.empty()) {
652     string attrs;
653     for (const auto& attr_pair : op_attrs) {
654       if (!attrs.empty()) strings::StrAppend(&attrs, "', '");
655       strings::StrAppend(&attrs, attr_pair.first);
656     }
657     return errors::InvalidArgument(
658         "NodeDef missing attr", op_attrs.size() == 1 ? " '" : "s '", attrs,
659         "' from ", SummarizeOpDef(op_def),
660         "; NodeDef: ", FormatNodeDefForError(node_def));
661   }
662 
663   // Validate the number of inputs.
664   DataTypeVector inputs, outputs;
665   TF_RETURN_IF_ERROR(InOutTypesForNode(node_def, op_def, &inputs, &outputs));
666 
667   if (num_inputs != inputs.size()) {
668     return errors::InvalidArgument(
669         "NodeDef expected inputs '", DataTypeVectorString(inputs),
670         "' do not match ", num_inputs, " inputs specified; ",
671         SummarizeOpDef(op_def), "; NodeDef: ", FormatNodeDefForError(node_def));
672   }
673 
674   return Status::OK();
675 }
676 
677 namespace {  // Helpers for NameRangesForNode()
678 
ComputeArgRange(const AttrSlice & attrs,const OpDef::ArgDef & arg_def,const OpDef & op_def,int * num)679 Status ComputeArgRange(const AttrSlice& attrs, const OpDef::ArgDef& arg_def,
680                        const OpDef& op_def, int* num) {
681   if (!arg_def.number_attr().empty()) {
682     // Same type repeated "num" times.
683     return GetNodeAttr(attrs, arg_def.number_attr(), num);
684   } else if (!arg_def.type_list_attr().empty()) {
685     const AttrValue* attr_value;
686     TF_RETURN_IF_ERROR(attrs.Find(arg_def.type_list_attr(), &attr_value));
687     *num = attr_value->list().type_size();
688   } else if (!arg_def.type_attr().empty() || arg_def.type() != DT_INVALID) {
689     *num = 1;
690   } else {
691     return errors::InvalidArgument(
692         "Argument '", arg_def.name(),
693         "' incorrectly specified in op definition: ", SummarizeOpDef(op_def));
694   }
695   return Status::OK();
696 }
697 
NameRangesHelper(const AttrSlice & attrs,const protobuf::RepeatedPtrField<OpDef::ArgDef> & args,const OpDef & op_def,NameRangeMap * result)698 Status NameRangesHelper(const AttrSlice& attrs,
699                         const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
700                         const OpDef& op_def, NameRangeMap* result) {
701   int start = 0;
702   int num;
703   for (const auto& arg : args) {
704     TF_RETURN_IF_ERROR(ComputeArgRange(attrs, arg, op_def, &num));
705     (*result)[arg.name()] = std::make_pair(start, start + num);
706     start += num;
707   }
708   return Status::OK();
709 }
710 
711 }  // namespace
712 
NameRangesForNode(const AttrSlice & attrs,const OpDef & op_def,NameRangeMap * inputs,NameRangeMap * outputs)713 Status NameRangesForNode(const AttrSlice& attrs, const OpDef& op_def,
714                          NameRangeMap* inputs, NameRangeMap* outputs) {
715   if (inputs != nullptr) {
716     TF_RETURN_IF_ERROR(
717         NameRangesHelper(attrs, op_def.input_arg(), op_def, inputs));
718   }
719   if (outputs != nullptr) {
720     return NameRangesHelper(attrs, op_def.output_arg(), op_def, outputs);
721   }
722   return Status::OK();
723 }
724 
AddDefaultsToNodeDef(const OpDef & op_def,NodeDef * node_def)725 void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def) {
726   for (const auto& attr_def : op_def.attr()) {
727     AttrSlice attrs(*node_def);
728     if (attr_def.has_default_value() && !attrs.Find(attr_def.name())) {
729       AddNodeAttr(attr_def.name(), attr_def.default_value(), node_def);
730     }
731   }
732 }
733 
734 namespace {
735 
736 using ::tensorflow::tstring;
737 using ::tensorflow::strings::Scanner;
738 
IsValidNodeName(StringPiece sp)739 bool IsValidNodeName(StringPiece sp) {
740   Scanner scanner(sp);
741   scanner.One(Scanner::LETTER_DIGIT_DOT)
742       .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
743 
744   while (true) {
745     if (!scanner.GetResult())  // Some error in previous iteration.
746       return false;
747     if (scanner.empty())  // No error, but nothing left, good.
748       return true;
749 
750     // Absorb another name/namespace, starting with a '>'
751     scanner.One(Scanner::RANGLE)
752         .One(Scanner::LETTER_DIGIT_DOT)
753         .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
754   }
755 }
756 
IsValidDataInputName(StringPiece sp)757 bool IsValidDataInputName(StringPiece sp) {
758   // Data inputs are op_name, op_name:0, or op_name:12345.
759   Scanner scan(sp);
760   scan.One(Scanner::LETTER_DIGIT_DOT)
761       .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
762 
763   while (true) {
764     if (!scan.GetResult())  // Some error in previous iteration.
765       return false;
766     if (scan.empty())  // No error, but nothing left, good.
767       return true;
768 
769     if (scan.Peek() == ':') {  // Absorb identifier after the colon
770       scan.OneLiteral(":");
771       if (scan.Peek() == '0') {
772         scan.OneLiteral("0");  // :0
773       } else {
774         scan.Many(Scanner::DIGIT);  // :[1-9][0-9]*
775       }
776     } else {
777       // Absorb another name/namespace, starting with a '>'
778       scan.One(Scanner::RANGLE)
779           .One(Scanner::LETTER_DIGIT_DOT)
780           .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
781     }
782   }
783 }
784 
IsValidControlInputName(StringPiece sp)785 bool IsValidControlInputName(StringPiece sp) {
786   Scanner scan(sp);
787   scan.OneLiteral("^")
788       .One(Scanner::LETTER_DIGIT_DOT)
789       .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
790 
791   while (true) {
792     if (!scan.GetResult())  // Some error in previous iteration.
793       return false;
794     if (scan.empty())  // No error, but nothing left, good.
795       return true;
796 
797     // Absorb another name/namespace, starting with a '>'
798     scan.One(Scanner::RANGLE)
799         .One(Scanner::LETTER_DIGIT_DOT)
800         .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
801   }
802 }
803 
804 const StringPiece kColocationGroupPrefixStringPiece(kColocationGroupPrefix);
805 
806 }  // namespace
807 
ValidateOpInput(const string & input_name,bool * is_control_input)808 Status ValidateOpInput(const string& input_name, bool* is_control_input) {
809   *is_control_input = false;
810   if (IsValidDataInputName(input_name)) {
811     return Status::OK();
812   } else if (IsValidControlInputName(input_name)) {
813     *is_control_input = true;
814     return Status::OK();
815   } else {
816     return errors::InvalidArgument("Illegal op input name '", input_name, "'");
817   }
818 }
819 
ValidateNodeName(const string & node_name)820 Status ValidateNodeName(const string& node_name) {
821   if (IsValidNodeName(node_name)) {
822     return Status::OK();
823   } else {
824     return errors::InvalidArgument("Illegal op name '", node_name, "'");
825   }
826 }
827 
ValidateExternalNodeDefSyntax(const NodeDef & node_def)828 Status ValidateExternalNodeDefSyntax(const NodeDef& node_def) {
829   Status s = ValidateNodeName(node_def.name());
830   if (!s.ok()) {
831     return AttachDef(s, node_def);
832   }
833   bool in_control_inputs = false;
834   for (const string& input_name : node_def.input()) {
835     bool is_control_input;
836     s = ValidateOpInput(input_name, &is_control_input);
837     if (!s.ok()) {
838       return AttachDef(s, node_def);
839     }
840 
841     if (in_control_inputs && !is_control_input) {
842       return AttachDef(errors::InvalidArgument(
843                            "All control inputs must follow all data inputs"),
844                        node_def);
845     }
846     in_control_inputs = is_control_input;
847   }
848   return Status::OK();
849 }
850 
// Returns a copy of `status` with " [[<node>]]" appended to its message.
// <node> is normally FormatNodeDefForError(node_def). When the message
// already contains "{{node " and `allow_multiple_formatted_node` is false,
// <node> is just node_def.name(), so the formatted node is not repeated.
Status AttachDef(const Status& status, const NodeDef& node_def,
                 bool allow_multiple_formatted_node) {
  const bool use_bare_name =
      !allow_multiple_formatted_node &&
      status.error_message().find("{{node ") != string::npos;
  const string node_error =
      use_bare_name ? node_def.name() : FormatNodeDefForError(node_def);

  Status annotated = status;
  errors::AppendToMessage(&annotated,
                          strings::StrCat(" [[", node_error, "]]"));
  return annotated;
}
864 
AddNodeAttr(StringPiece name,const AttrValue & value,NodeDef * node_def)865 void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def) {
866   node_def->mutable_attr()->insert(
867       AttrValueMap::value_type(string(name), value));
868 }
869 
AddNodeAttr(StringPiece name,AttrValue && value,NodeDef * node_def)870 void AddNodeAttr(StringPiece name, AttrValue&& value, NodeDef* node_def) {
871   (*node_def->mutable_attr())[string(name)] = std::move(value);
872 }
873 
874 #define ADD_NODE_ATTR(T)                                           \
875   void AddNodeAttr(StringPiece name, T value, NodeDef* node_def) { \
876     AttrValue attr_value;                                          \
877     SetAttrValue(value, &attr_value);                              \
878     AddNodeAttr(name, attr_value, node_def);                       \
879   }
880 ADD_NODE_ATTR(StringPiece)
ADD_NODE_ATTR(const char *)881 ADD_NODE_ATTR(const char*)
882 ADD_NODE_ATTR(int32_t)
883 ADD_NODE_ATTR(int64_t)
884 ADD_NODE_ATTR(float)
885 ADD_NODE_ATTR(double)
886 ADD_NODE_ATTR(bool)
887 ADD_NODE_ATTR(DataType)
888 ADD_NODE_ATTR(const PartialTensorShape&)
889 ADD_NODE_ATTR(const Tensor&)
890 ADD_NODE_ATTR(const TensorProto&)
891 ADD_NODE_ATTR(const NameAttrList&)
892 ADD_NODE_ATTR(gtl::ArraySlice<StringPiece>)
893 ADD_NODE_ATTR(gtl::ArraySlice<const char*>)
894 ADD_NODE_ATTR(gtl::ArraySlice<string>)
895 ADD_NODE_ATTR(gtl::ArraySlice<int32>)
896 ADD_NODE_ATTR(gtl::ArraySlice<int64>)
897 ADD_NODE_ATTR(gtl::ArraySlice<float>)
898 ADD_NODE_ATTR(gtl::ArraySlice<bool>)
899 ADD_NODE_ATTR(const std::vector<bool>&)
900 ADD_NODE_ATTR(gtl::ArraySlice<DataType>)
901 ADD_NODE_ATTR(gtl::ArraySlice<TensorShape>)
902 ADD_NODE_ATTR(gtl::ArraySlice<PartialTensorShape>)
903 ADD_NODE_ATTR(gtl::ArraySlice<TensorShapeProto>)
904 ADD_NODE_ATTR(gtl::ArraySlice<Tensor>)
905 ADD_NODE_ATTR(gtl::ArraySlice<NameAttrList>)
906 #undef ADD_NODE_ATTR
907 
908 void AddAttr(StringPiece name, const AttrValue& value, AttrValueMap* map) {
909   map->insert(AttrValueMap::value_type(string(name), value));
910 }
911 
912 #define ADD_ATTR(T)                                            \
913   void AddAttr(StringPiece name, T value, AttrValueMap* map) { \
914     AttrValue attr_value;                                      \
915     SetAttrValue(value, &attr_value);                          \
916     AddAttr(name, attr_value, map);                            \
917   }
ADD_ATTR(bool)918 ADD_ATTR(bool)
919 #undef ADD_ATTR
920 
921 Status AddPrefixAndSuffixToNode(StringPiece prefix, StringPiece suffix,
922                                 NodeDef* node_def, bool uniquify_frame_name) {
923   node_def->set_name(strings::StrCat(prefix, node_def->name(), suffix));
924 
925   // Update frame name to avoid multiple LoopCond nodes in one frame.
926   if (uniquify_frame_name &&
927       (node_def->op() == "Enter" || node_def->op() == "RefEnter")) {
928     string frame_name;
929     TF_RETURN_IF_ERROR(GetNodeAttr(*node_def, "frame_name", &frame_name));
930     AttrValue& attr = (*node_def->mutable_attr())["frame_name"];
931     frame_name = strings::StrCat(prefix, frame_name, suffix);
932     attr.set_s(frame_name);
933   }
934 
935   return Status::OK();
936 }
937 
MaybeAddPrefixToColocationConstraints(const std::unordered_set<string> & match,StringPiece prefix,NodeDef * node_def)938 Status MaybeAddPrefixToColocationConstraints(
939     const std::unordered_set<string>& match, StringPiece prefix,
940     NodeDef* node_def) {
941   auto attr = node_def->mutable_attr()->find(kColocationAttrName);
942   if (attr == node_def->mutable_attr()->end()) {
943     return Status::OK();
944   }
945   auto constraints_list = attr->second.mutable_list();
946   auto constraints_size = constraints_list->s_size();
947   for (size_t i = 0; i < constraints_size; ++i) {
948     StringPiece original(constraints_list->s(i));
949     if (absl::ConsumePrefix(&original, kColocationGroupPrefixStringPiece)) {
950       if (match.find(string(original)) != match.end()) {
951         (*constraints_list->mutable_s(i)) =
952             strings::StrCat(kColocationGroupPrefix, prefix, original);
953       }
954     }
955   }
956   return Status::OK();
957 }
958 
MaybeUpdateColocationConstraintsWithMap(const std::map<absl::string_view,absl::string_view> & node_name_map,NodeDef * node_def)959 Status MaybeUpdateColocationConstraintsWithMap(
960     const std::map<absl::string_view, absl::string_view>& node_name_map,
961     NodeDef* node_def) {
962   auto attr = node_def->mutable_attr()->find(kColocationAttrName);
963   if (attr == node_def->mutable_attr()->end()) {
964     return Status::OK();
965   }
966   auto constraints_list = attr->second.mutable_list();
967   auto constraints_size = constraints_list->s_size();
968   for (size_t i = 0; i < constraints_size; ++i) {
969     StringPiece original(constraints_list->s(i));
970     if (absl::ConsumePrefix(&original, kColocationGroupPrefixStringPiece)) {
971       if (node_name_map.find(original) != node_name_map.end()) {
972         (*constraints_list->mutable_s(i)) =
973             strings::StrCat(kColocationGroupPrefix, node_name_map.at(original));
974       }
975     }
976   }
977   return Status::OK();
978 }
979 
980 }  // namespace tensorflow
981