• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/node_def_util.h"
17 
18 #include <algorithm>
19 #include <unordered_map>
20 #include <vector>
21 
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/str_join.h"
24 #include "tensorflow/core/framework/attr_value.pb.h"
25 #include "tensorflow/core/framework/attr_value_util.h"
26 #include "tensorflow/core/framework/node_def.pb.h"
27 #include "tensorflow/core/framework/op_def.pb.h"
28 #include "tensorflow/core/framework/op_def_util.h"
29 #include "tensorflow/core/framework/tensor.h"
30 #include "tensorflow/core/framework/tensor.pb.h"
31 #include "tensorflow/core/framework/tensor_shape.h"
32 #include "tensorflow/core/framework/tensor_shape.pb.h"
33 #include "tensorflow/core/framework/types.h"
34 #include "tensorflow/core/framework/types.pb.h"
35 #include "tensorflow/core/lib/gtl/map_util.h"
36 #include "tensorflow/core/platform/errors.h"
37 #include "tensorflow/core/platform/scanner.h"
38 #include "tensorflow/core/platform/status.h"
39 #include "tensorflow/core/platform/strcat.h"
40 #include "tensorflow/core/platform/stringpiece.h"
41 #include "tensorflow/core/platform/types.h"
42 
43 namespace tensorflow {
44 
// Attr name and group-name prefix used for colocation (per the constant
// names; the consumers are outside this chunk -- confirm against callers).
const char* const kColocationAttrName = "_class";
const char* const kColocationGroupPrefix = "loc:@";
// For TPU distributed rewrite, TPU args are collected and "staged" on the local
// host using an IdentityN TF op. Some args may result from a remote source.
// When all arg tensors are available, the TPUExecute op can be invoked. See
// DistributedTPURewritePass for more details.
const char* const kTpuExecuteStagingOp = "IdentityN";
const char* const kTpuExecuteStagingNodeName = "_variable_copy";
53 
AttrSlice()54 AttrSlice::AttrSlice() : ndef_(nullptr) {
55   static const AttrValueMap* const kEmptyAttrValueMap = new AttrValueMap;
56   attrs_ = kEmptyAttrValueMap;
57 }
58 
59 // Do not cache the map field reference because that may be invalidated on
60 // Clear.
AttrSlice(const NodeDef & node_def)61 AttrSlice::AttrSlice(const NodeDef& node_def)
62     : ndef_(&node_def), attrs_(nullptr) {}
63 
AttrSlice(const AttrValueMap * a)64 AttrSlice::AttrSlice(const AttrValueMap* a) : ndef_(nullptr), attrs_(a) {}
65 
SummarizeAttrsHelper(AttrSlice attrs,StringPiece device)66 string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device) {
67   string ret;
68 
69   // We sort the attrs so the output is deterministic.
70   std::vector<string> attr_names;
71   attr_names.reserve(attrs.size());
72   for (const auto& attr : attrs) {
73     attr_names.push_back(attr.first);
74   }
75   std::sort(attr_names.begin(), attr_names.end());
76   bool first = true;
77   for (const string& attr_name : attr_names) {
78     if (!first) strings::StrAppend(&ret, ", ");
79     first = false;
80     strings::StrAppend(&ret, attr_name, "=",
81                        SummarizeAttrValue(*attrs.Find(attr_name)));
82   }
83 
84   // Consider the device to be a final attr with name "_device".
85   if (!device.empty()) {
86     if (!first) strings::StrAppend(&ret, ", ");
87     first = false;
88     strings::StrAppend(&ret, "_device=\"", device, "\"");
89   }
90   return ret;
91 }
92 
SummarizeNode() const93 string AttrSlice::SummarizeNode() const {
94   return ndef_ ? SummarizeNodeDef(*ndef_)
95                : strings::StrCat(
96                      "[", SummarizeAttrsHelper(*this, StringPiece()), "]");
97 }
98 
DebugString() const99 string AttrSlice::DebugString() const {
100   std::vector<string> attr_key_vals;
101   attr_key_vals.reserve(attrs()->size());
102   for (const auto& it : *this) {
103     const string& name = it.first;
104     const AttrValue& attr_value = it.second;
105     attr_key_vals.push_back(
106         absl::StrCat(name, "=", SummarizeAttrValue(attr_value)));
107   }
108   return absl::StrJoin(attr_key_vals, ", ");
109 }
110 
SummarizeNodeDef(const NodeDef & node_def,int max_inputs_in_summary)111 string SummarizeNodeDef(const NodeDef& node_def, int max_inputs_in_summary) {
112   string ret = strings::StrCat(errors::FormatNodeNameForError(node_def.name()),
113                                " = ", node_def.op(), "[");
114   strings::StrAppend(&ret, SummarizeAttrsHelper(node_def, node_def.device()));
115   strings::StrAppend(&ret, "](");
116 
117   // Output inputs, including control inputs, verbatim.
118   bool first = true;
119   for (const string& input : node_def.input()) {
120     if (!first) strings::StrAppend(&ret, ", ");
121     first = false;
122     if (max_inputs_in_summary-- == 0) {
123       strings::StrAppend(&ret, "...");
124       break;
125     }
126     strings::StrAppend(&ret, input);
127   }
128   strings::StrAppend(&ret, ")");
129   return ret;
130 }
131 
SummarizeAttrs(const NodeDef & node_def)132 string SummarizeAttrs(const NodeDef& node_def) {
133   return SummarizeAttrsHelper(node_def, node_def.device());
134 }
135 
FormatNodeDefForError(StringPiece node_name,bool has_experimental_debug_info,const NodeDef_ExperimentalDebugInfo & experimental_debug_info)136 string FormatNodeDefForError(
137     StringPiece node_name, bool has_experimental_debug_info,
138     const NodeDef_ExperimentalDebugInfo& experimental_debug_info) {
139   return !has_experimental_debug_info ||
140                  experimental_debug_info.original_node_names().empty()
141              ? errors::FormatNodeNameForError(string(node_name))
142              : errors::FormatOriginalNodeLocationForError(
143                    experimental_debug_info.original_node_names(),
144                    experimental_debug_info.original_func_names());
145 }
146 
FormatNodeDefForError(const NodeDef & node_def)147 string FormatNodeDefForError(const NodeDef& node_def) {
148   return FormatNodeDefForError(node_def.name(),
149                                node_def.has_experimental_debug_info(),
150                                node_def.experimental_debug_info());
151 }
152 
Find(StringPiece attr_name) const153 const AttrValue* AttrSlice::Find(StringPiece attr_name) const {
154   // Currently, the collection used for NodeDef::attr() (google::protobuf::Map)
155   // requires that the keys used for lookups have type 'const string&'. Because
156   // this method takes a StringPiece, it is necessary to allocate a temporary
157   // string, copy attr_name to it, and then use that temporary string for the
158   // lookup. This causes an excessive number of short-lived allocations, and for
159   // large graphs, this can be a significant cost.
160   //
161   // Because most nodes have a small number of attributes, a simple linear scan
162   // is generally more efficient than a hashed lookup.  If google::protobuf::Map
163   // changes so that it supports efficient lookups using StringPiece instead of
164   // const string&, then this code could be changed to use attrs()->find()
165   // again.
166 
167   for (const auto& attr : *attrs()) {
168     if (attr.first == attr_name) {
169       return &attr.second;
170     }
171   }
172   return nullptr;
173 }
174 
FindByString(const string & attr_name) const175 const AttrValue* AttrSlice::FindByString(const string& attr_name) const {
176   auto iter = attrs()->find(attr_name);
177   if (iter != attrs()->end()) {
178     return &iter->second;
179   } else {
180     return nullptr;
181   }
182 }
183 
CheckFind(StringPiece attr_name,const AttrValue * attr_value) const184 Status AttrSlice::CheckFind(StringPiece attr_name,
185                             const AttrValue* attr_value) const {
186   if (attr_value != nullptr) {
187     return OkStatus();
188   }
189   Status s = errors::NotFound("No attr named '", attr_name, "' in NodeDef:");
190   // Skip AttachDef for internal attrs since it is a little bit
191   // expensive and it is common for them to correctly not be included
192   // in a NodeDef.
193   if (!absl::StartsWith(attr_name, "_") && ndef_ != nullptr) {
194     s = AttachDef(s, *ndef_);
195   }
196   return s;
197 }
198 
Find(StringPiece attr_name,const AttrValue ** attr_value) const199 Status AttrSlice::Find(StringPiece attr_name,
200                        const AttrValue** attr_value) const {
201   *attr_value = Find(attr_name);
202   return CheckFind(attr_name, *attr_value);
203 }
204 
FindByString(const string & attr_name,const AttrValue ** attr_value) const205 Status AttrSlice::FindByString(const string& attr_name,
206                                const AttrValue** attr_value) const {
207   *attr_value = FindByString(attr_name);
208   return CheckFind(attr_name, *attr_value);
209 }
210 
EqualAttrs(AttrSlice other,Scratch * scratch) const211 bool AttrSlice::EqualAttrs(AttrSlice other, Scratch* scratch) const {
212   if (size() != other.size()) return false;
213 
214   for (const auto& attr : *other.attrs()) {
215     auto iter = attrs()->find(attr.first);
216     if (iter == attrs()->end()) return false;
217     // TODO(irving): Comparing AttrValues by proto is slightly buggy, since
218     // TensorProto is a nonunique representation of Tensor.  This bug will go
219     // away once AttrSlice switches over to NodeInfo.
220     iter->second.SerializeToString(&scratch->a);
221     attr.second.SerializeToString(&scratch->b);
222     if (scratch->a != scratch->b) return false;
223   }
224   return true;
225 }
226 
227 // The ... is to allow the caller to inject some value validation code.  Use
228 // just ; if no additional validation code is needed.
229 #define DEFINE_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...)         \
230   Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,           \
231                      TYPE* value) {                                           \
232     const AttrValue* attr_value;                                              \
233     TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));                   \
234     TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, ATTR_TYPE));             \
235     const auto& v = attr_value->FIELD();                                      \
236     __VA_ARGS__;                                                              \
237     *value = CAST;                                                            \
238     return OkStatus();                                                        \
239   }                                                                           \
240   Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,           \
241                      std::vector<TYPE>* value) {                              \
242     const AttrValue* attr_value;                                              \
243     TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));                   \
244     TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(" ATTR_TYPE ")")); \
245     value->reserve(attr_value->list().FIELD().size());                        \
246     for (const auto& v : attr_value->list().FIELD()) {                        \
247       __VA_ARGS__;                                                            \
248       value->APPEND_OP(CAST);                                                 \
249     }                                                                         \
250     return OkStatus();                                                        \
251   }
252 
253 #define DEFINE_TRY_GET_ATTR(TYPE, FIELD, ATTR_TYPE, APPEND_OP, CAST, ...) \
254   bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,      \
255                       TYPE* value) {                                      \
256     const AttrValue* attr_value = attrs.Find(attr_name);                  \
257     if (attr_value == nullptr) {                                          \
258       return false;                                                       \
259     }                                                                     \
260     Status s = AttrValueHasType(*attr_value, ATTR_TYPE);                  \
261     if (!s.ok()) {                                                        \
262       return false;                                                       \
263     }                                                                     \
264     const auto& v = attr_value->FIELD();                                  \
265     __VA_ARGS__;                                                          \
266     *value = CAST;                                                        \
267     return true;                                                          \
268   }                                                                       \
269   bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,      \
270                       std::vector<TYPE>* value) {                         \
271     const AttrValue* attr_value = attrs.Find(attr_name);                  \
272     if (attr_value == nullptr) {                                          \
273       return false;                                                       \
274     }                                                                     \
275     Status s = AttrValueHasType(*attr_value, "list(" ATTR_TYPE ")");      \
276     if (!s.ok()) {                                                        \
277       return false;                                                       \
278     }                                                                     \
279     value->reserve(attr_value->list().FIELD().size());                    \
280     for (const auto& v : attr_value->list().FIELD()) {                    \
281       __VA_ARGS__;                                                        \
282       value->APPEND_OP(CAST);                                             \
283     }                                                                     \
284     return true;                                                          \
285   }
286 DEFINE_GET_ATTR(tstring, s, "string", emplace_back, v, ;)
287 DEFINE_TRY_GET_ATTR(tstring, s, "string", emplace_back, v, ;)
288 DEFINE_GET_ATTR(string, s, "string", emplace_back, v, ;)
289 DEFINE_TRY_GET_ATTR(string, s, "string", emplace_back, v, ;)
290 DEFINE_GET_ATTR(int64_t, i, "int", emplace_back, v, ;)
291 DEFINE_TRY_GET_ATTR(int64_t, i, "int", emplace_back, v, ;)
292 DEFINE_GET_ATTR(
293     int32, i, "int", emplace_back, static_cast<int32>(v),
294     if (static_cast<int64_t>(static_cast<int32>(v)) != v) {
295       return errors::InvalidArgument("Attr ", attr_name, " has value ", v,
296                                      " out of range for an int32");
297     })
298 DEFINE_TRY_GET_ATTR(
299     int32, i, "int", emplace_back, static_cast<int32>(v),
300     if (static_cast<int64_t>(static_cast<int32>(v)) != v) {
301       static int log_counter = 0;
302       if (log_counter < 10) {
303         log_counter++;
304         LOG(WARNING) << "Attr " << attr_name << " has value " << v
305                      << " out of range for an int32";
306       }
307       return false;
308     })
309 DEFINE_GET_ATTR(float, f, "float", emplace_back, v, ;)
310 DEFINE_TRY_GET_ATTR(float, f, "float", emplace_back, v, ;)
311 DEFINE_GET_ATTR(bool, b, "bool", emplace_back, v, ;)
312 DEFINE_TRY_GET_ATTR(bool, b, "bool", emplace_back, v, ;)
313 DEFINE_GET_ATTR(DataType, type, "type", emplace_back, static_cast<DataType>(v),
314                 ;)
315 DEFINE_TRY_GET_ATTR(DataType, type, "type", emplace_back,
316                     static_cast<DataType>(v),
317                     ;)
318 DEFINE_GET_ATTR(TensorShapeProto, shape, "shape", emplace_back, v, ;)
319 DEFINE_GET_ATTR(TensorShape, shape, "shape", emplace_back, TensorShape(v),
320                 TF_RETURN_IF_ERROR(TensorShape::IsValidShape(v));)
321 DEFINE_TRY_GET_ATTR(
322     TensorShape, shape, "shape", emplace_back, TensorShape(v),
323     if (!TensorShape::IsValidShape(v).ok()) {
324       static int log_counter = 0;
325       if (log_counter < 10) {
326         log_counter++;
327         LOG(WARNING) << "Attr " << attr_name << " has invalid shape value "
328                      << v.DebugString();
329       }
330       return false;
331     })
332 DEFINE_GET_ATTR(PartialTensorShape, shape, "shape", emplace_back,
333                 PartialTensorShape(v),
334                 TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(v));)
335 DEFINE_GET_ATTR(
336     Tensor, tensor, "tensor", emplace_back, t, Tensor t; if (!t.FromProto(v)) {
337       return errors::InvalidArgument("Attr ", attr_name, " has value ",
338                                      v.ShortDebugString(),
339                                      " that can't be converted to a Tensor");
340     })
341 DEFINE_GET_ATTR(NameAttrList, func, "func", emplace_back, v, ;);
342 #undef DEFINE_GET_ATTR
343 
HasNodeAttr(const NodeDef & node_def,StringPiece attr_name)344 bool HasNodeAttr(const NodeDef& node_def, StringPiece attr_name) {
345   return node_def.attr().find(string(attr_name)) != node_def.attr().end();
346 }
347 
348 static const string& kEmptyString = *new string();
349 
GetNodeAttrString(const AttrSlice & attrs,StringPiece attr_name)350 const string& GetNodeAttrString(const AttrSlice& attrs, StringPiece attr_name) {
351   const AttrValue* attr_value = attrs.Find(attr_name);
352   if (attr_value == nullptr) {
353     return kEmptyString;
354   }
355   Status s = AttrValueHasType(*attr_value, "string");
356   if (!s.ok()) {
357     return kEmptyString;
358   }
359   return attr_value->s();
360 }
361 
TryGetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,std::vector<const string * > * value)362 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
363                     std::vector<const string*>* value) {
364   const AttrValue* attr_value = attrs.Find(attr_name);
365   if (attr_value == nullptr) {
366     return false;
367   }
368   Status s = AttrValueHasType(*attr_value, "list(string)");
369   if (!s.ok()) {
370     return false;
371   }
372   value->reserve(attr_value->list().s().size());
373   for (const auto& v : attr_value->list().s()) {
374     value->push_back(&v);
375   }
376   return true;
377 }
378 
TryGetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,std::vector<const TensorShapeProto * > * value)379 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
380                     std::vector<const TensorShapeProto*>* value) {
381   const AttrValue* attr_value = attrs.Find(attr_name);
382   if (attr_value == nullptr) {
383     return false;
384   }
385   Status s = AttrValueHasType(*attr_value, "list(shape)");
386   if (!s.ok()) {
387     return false;
388   }
389   value->reserve(attr_value->list().shape().size());
390   for (const auto& v : attr_value->list().shape()) {
391     value->push_back(&v);
392   }
393   return true;
394 }
395 
GetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,DataTypeVector * value)396 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
397                    DataTypeVector* value) {
398   const AttrValue* attr_value;
399   TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
400   TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "list(type)"));
401   for (const auto& v : attr_value->list().type()) {
402     value->push_back(static_cast<DataType>(v));
403   }
404   return OkStatus();
405 }
406 
GetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,const TensorProto ** value)407 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
408                    const TensorProto** value) {
409   const AttrValue* attr_value;
410   TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
411   TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "tensor"));
412   *value = &attr_value->tensor();
413   return OkStatus();
414 }
415 
TryGetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,const TensorProto ** value)416 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
417                     const TensorProto** value) {
418   const AttrValue* attr_value = attrs.Find(attr_name);
419   if (attr_value == nullptr) {
420     return false;
421   }
422   Status s = AttrValueHasType(*attr_value, "tensor");
423   if (!s.ok()) {
424     return false;
425   }
426   *value = &attr_value->tensor();
427   return true;
428 }
429 
GetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,const NameAttrList ** value)430 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
431                    const NameAttrList** value) {
432   const AttrValue* attr_value;
433   TF_RETURN_IF_ERROR(attrs.Find(attr_name, &attr_value));
434   TF_RETURN_IF_ERROR(AttrValueHasType(*attr_value, "func"));
435   *value = &attr_value->func();
436   return OkStatus();
437 }
438 
TryGetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,const NameAttrList ** value)439 bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
440                     const NameAttrList** value) {
441   const AttrValue* attr_value = attrs.Find(attr_name);
442   if (attr_value == nullptr) {
443     return false;
444   }
445   Status s = AttrValueHasType(*attr_value, "func");
446   if (!s.ok()) {
447     return false;
448   }
449   *value = &attr_value->func();
450   return true;
451 }
452 
GetNodeAttr(const AttrSlice & attrs,StringPiece attr_name,Padding * value)453 Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
454                    Padding* value) {
455   string str_value;
456   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, attr_name, &str_value));
457   return GetPaddingFromString(str_value, value);
458 }
459 
460 namespace {  // Helper for InOutTypesForNode().
461 
462 template <class NodeDefOrAttrSlice>
AddArgToSig(const NodeDefOrAttrSlice & node_or_attrs,const OpDef::ArgDef & arg_def,DataTypeVector * sig)463 Status AddArgToSig(const NodeDefOrAttrSlice& node_or_attrs,
464                    const OpDef::ArgDef& arg_def, DataTypeVector* sig) {
465   const int original_size = sig->size();
466   if (!arg_def.number_attr().empty()) {
467     // Same type repeated "repeats" times.
468     int64_t repeats = -1;
469     TF_RETURN_IF_ERROR(
470         GetNodeAttr(node_or_attrs, arg_def.number_attr(), &repeats));
471     // We can't handle outputs that are larger than int32 sizes.
472     if (static_cast<int64_t>(static_cast<int32>(repeats)) != repeats) {
473       return errors::InvalidArgument("Number of outputs is too big: ", repeats);
474     }
475     if (repeats < 0) {
476       return errors::InvalidArgument("Value for number_attr() ", repeats,
477                                      " < 0");
478     }
479 
480     if (!arg_def.type_attr().empty()) {
481       DataType dtype;
482       TF_RETURN_IF_ERROR(
483           GetNodeAttr(node_or_attrs, arg_def.type_attr(), &dtype));
484       for (int i = 0; i < repeats; ++i) {
485         sig->push_back(dtype);
486       }
487     } else if (arg_def.type() != DT_INVALID) {
488       for (int i = 0; i < repeats; ++i) {
489         sig->push_back(arg_def.type());
490       }
491     } else {
492       return errors::InvalidArgument("Missing type or type_attr field in ",
493                                      arg_def.ShortDebugString());
494     }
495   } else if (!arg_def.type_attr().empty()) {
496     const AttrValue* attr_value;
497     TF_RETURN_IF_ERROR(AttrSlice(node_or_attrs)
498                            .FindByString(arg_def.type_attr(), &attr_value));
499     sig->push_back(attr_value->type());
500   } else if (!arg_def.type_list_attr().empty()) {
501     const AttrValue* attr_value;
502     TF_RETURN_IF_ERROR(
503         AttrSlice(node_or_attrs)
504             .FindByString(arg_def.type_list_attr(), &attr_value));
505     for (int dtype : attr_value->list().type()) {
506       sig->push_back(static_cast<DataType>(dtype));
507     }
508   } else if (arg_def.type() != DT_INVALID) {
509     sig->push_back(arg_def.type());
510   } else {
511     return errors::InvalidArgument("No type fields in ",
512                                    arg_def.ShortDebugString());
513   }
514   if (arg_def.is_ref()) {
515     // For all types that were added by this function call, make them refs.
516     for (size_t i = original_size; i < sig->size(); ++i) {
517       if (IsRefType((*sig)[i])) {
518         return errors::InvalidArgument(
519             "Requested reference to a reference type: ",
520             arg_def.ShortDebugString());
521       }
522       (*sig)[i] = MakeRefType((*sig)[i]);
523     }
524   }
525   return OkStatus();
526 }
527 
528 }  // namespace
529 
InputTypeForNode(const NodeDef & node_def,const OpDef & op_def,int input_port,DataType * input_type)530 Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
531                         int input_port, DataType* input_type) {
532   DataTypeVector input_types;
533   for (const auto& arg : op_def.input_arg()) {
534     TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, &input_types));
535     int input_types_size = input_types.size();
536     if (input_types_size > input_port) {
537       const DataType dtype = input_types[input_port];
538       *input_type = dtype;
539       return OkStatus();
540     }
541   }
542   return errors::InvalidArgument("Input ", input_port, " not found for node ",
543                                  node_def.name());
544 }
545 
InputTypesForNode(const NodeDef & node_def,const OpDef & op_def,DataTypeVector * inputs)546 Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
547                          DataTypeVector* inputs) {
548   for (const auto& arg : op_def.input_arg()) {
549     TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs));
550   }
551   return OkStatus();
552 }
553 
OutputTypeForNode(const NodeDef & node_def,const OpDef & op_def,int output_port,DataType * output_type)554 Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
555                          int output_port, DataType* output_type) {
556   DataTypeVector output_types;
557   for (const auto& arg : op_def.output_arg()) {
558     TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, &output_types));
559     int output_types_size = output_types.size();
560     if (output_types_size > output_port) {
561       const DataType dtype = output_types[output_port];
562       *output_type = dtype;
563       return OkStatus();
564     }
565   }
566   return errors::InvalidArgument("Output ", output_port, " not found for node ",
567                                  node_def.name());
568 }
569 
OutputTypesForNode(const NodeDef & node_def,const OpDef & op_def,DataTypeVector * outputs)570 Status OutputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
571                           DataTypeVector* outputs) {
572   for (const auto& arg : op_def.output_arg()) {
573     TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, outputs));
574   }
575   return OkStatus();
576 }
577 
OutputTypesForNode(const AttrSlice & attrs,const OpDef & op_def,DataTypeVector * outputs)578 Status OutputTypesForNode(const AttrSlice& attrs, const OpDef& op_def,
579                           DataTypeVector* outputs) {
580   for (const auto& arg : op_def.output_arg()) {
581     TF_RETURN_IF_ERROR(AddArgToSig(attrs, arg, outputs));
582   }
583   return OkStatus();
584 }
585 
InOutTypesForNode(const NodeDef & node_def,const OpDef & op_def,DataTypeVector * inputs,DataTypeVector * outputs)586 Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
587                          DataTypeVector* inputs, DataTypeVector* outputs) {
588   TF_RETURN_IF_ERROR(InputTypesForNode(node_def, op_def, inputs));
589   return OutputTypesForNode(node_def, op_def, outputs);
590 }
591 
NumOutputsForNode(const NodeDef & node_def,const OpDef & op_def,int * num_outputs)592 Status NumOutputsForNode(const NodeDef& node_def, const OpDef& op_def,
593                          int* num_outputs) {
594   DataTypeVector outputs;
595   TF_RETURN_IF_ERROR(OutputTypesForNode(node_def, op_def, &outputs));
596   *num_outputs = outputs.size();
597   return OkStatus();
598 }
599 
OpPortIdToArgId(const NodeDef & node,const protobuf::RepeatedPtrField<OpDef::ArgDef> & args,int port_id)600 int OpPortIdToArgId(const NodeDef& node,
601                     const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
602                     int port_id) {
603   for (int arg_id = 0; arg_id < args.size(); ++arg_id) {
604     if (port_id < 0) {
605       return -1;
606     } else if (port_id == 0) {
607       return arg_id;
608     }
609 
610     // Default is 1 port per arg.
611     int n = 1;
612 
613     const auto& arg = args.Get(arg_id);
614     if (!arg.number_attr().empty()) {
615       n = node.attr().at(arg.number_attr()).i();
616     } else if (!arg.type_list_attr().empty()) {
617       n = node.attr().at(arg.type_list_attr()).list().type_size();
618     }
619 
620     if (n < 0) {
621       // This should never happen.
622       DCHECK_GE(n, 0);
623       return -1;
624     } else if (port_id < n) {
625       return arg_id;
626     }
627     port_id -= n;
628   }
629 
630   return -1;
631 }
632 
ValidateNodeDef(const NodeDef & node_def,const OpDef & op_def)633 Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) {
634   if (node_def.op() != op_def.name()) {
635     return errors::InvalidArgument(
636         "NodeDef op '", node_def.op(), "' does not match ",
637         SummarizeOpDef(op_def), "; NodeDef: ", FormatNodeDefForError(node_def));
638   }
639 
640   bool seen_control = false;
641   size_t num_inputs = 0;
642   // TODO(josh11b): Unify the input field validation.
643   for (const string& input : node_def.input()) {
644     if (absl::StartsWith(input, "^")) {
645       seen_control = true;
646       if (input.find(':') != string::npos) {
647         return errors::InvalidArgument("Control input '", input,
648                                        "' must not have ':' in NodeDef: ",
649                                        FormatNodeDefForError(node_def));
650       }
651     } else if (seen_control) {
652       return errors::InvalidArgument("Non-control input '", input,
653                                      "' after control input in NodeDef: ",
654                                      FormatNodeDefForError(node_def));
655     } else {
656       ++num_inputs;
657     }
658   }
659 
660   std::unordered_map<string, const OpDef::AttrDef*> op_attrs;
661   for (const auto& attr : op_def.attr()) {
662     if (!gtl::InsertIfNotPresent(&op_attrs, attr.name(), &attr)) {
663       return errors::InvalidArgument("OpDef has duplicate attr name '",
664                                      attr.name(),
665                                      "': ", SummarizeOpDef(op_def));
666     }
667   }
668   for (const auto& attr : node_def.attr()) {
669     // Allow internal optional attributes with names starting with "_".
670     if (absl::StartsWith(attr.first, "_")) {
671       continue;
672     }
673     auto iter = op_attrs.find(attr.first);
674     if (iter == op_attrs.end()) {
675       LOG_EVERY_N_SEC(ERROR, 5)
676           << "NodeDef mentions attribute " << attr.first
677           << " which is not in the op definition: " << SummarizeOpDef(op_def)
678           << " This may be expected if your graph generating binary is newer "
679           << " than this binary. Unknown attributes will be ignored."
680           << " NodeDef: " << FormatNodeDefForError(node_def);
681       continue;
682     }
683 
684     // If attr value is placeholder, do not check it.
685     if (attr.second.placeholder().empty()) {
686       TF_RETURN_WITH_CONTEXT_IF_ERROR(
687           ValidateAttrValue(attr.second, *iter->second),
688           "; NodeDef: ", FormatNodeDefForError(node_def), "; ",
689           SummarizeOpDef(op_def));
690     }
691 
692     // Keep track of which attr names have (not) been found in the NodeDef.
693     op_attrs.erase(iter);
694   }
695 
696   // Were all attrs in the OpDef found in the NodeDef?
697   if (!op_attrs.empty()) {
698     string attrs;
699     for (const auto& attr_pair : op_attrs) {
700       if (!attrs.empty()) strings::StrAppend(&attrs, "', '");
701       strings::StrAppend(&attrs, attr_pair.first);
702     }
703     return errors::InvalidArgument(
704         "NodeDef missing attr", op_attrs.size() == 1 ? " '" : "s '", attrs,
705         "' from ", SummarizeOpDef(op_def),
706         "; NodeDef: ", FormatNodeDefForError(node_def));
707   }
708 
709   // Validate the number of inputs.
710   DataTypeVector inputs, outputs;
711   TF_RETURN_IF_ERROR(InOutTypesForNode(node_def, op_def, &inputs, &outputs));
712 
713   if (num_inputs != inputs.size()) {
714     return errors::InvalidArgument(
715         "NodeDef expected inputs '", DataTypeVectorString(inputs),
716         "' do not match ", num_inputs, " inputs specified; ",
717         SummarizeOpDef(op_def), "; NodeDef: ", FormatNodeDefForError(node_def));
718   }
719 
720   return OkStatus();
721 }
722 
723 namespace {  // Helpers for NameRangesForNode()
724 
ComputeArgRange(const AttrSlice & attrs,const OpDef::ArgDef & arg_def,const OpDef & op_def,int * num)725 Status ComputeArgRange(const AttrSlice& attrs, const OpDef::ArgDef& arg_def,
726                        const OpDef& op_def, int* num) {
727   if (!arg_def.number_attr().empty()) {
728     // Same type repeated "num" times.
729     return GetNodeAttr(attrs, arg_def.number_attr(), num);
730   } else if (!arg_def.type_list_attr().empty()) {
731     const AttrValue* attr_value;
732     TF_RETURN_IF_ERROR(attrs.Find(arg_def.type_list_attr(), &attr_value));
733     *num = attr_value->list().type_size();
734   } else if (!arg_def.type_attr().empty() || arg_def.type() != DT_INVALID) {
735     *num = 1;
736   } else {
737     return errors::InvalidArgument(
738         "Argument '", arg_def.name(),
739         "' incorrectly specified in op definition: ", SummarizeOpDef(op_def));
740   }
741   return OkStatus();
742 }
743 
NameRangesHelper(const AttrSlice & attrs,const protobuf::RepeatedPtrField<OpDef::ArgDef> & args,const OpDef & op_def,NameRangeMap * result)744 Status NameRangesHelper(const AttrSlice& attrs,
745                         const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
746                         const OpDef& op_def, NameRangeMap* result) {
747   int start = 0;
748   int num;
749   for (const auto& arg : args) {
750     TF_RETURN_IF_ERROR(ComputeArgRange(attrs, arg, op_def, &num));
751     (*result)[arg.name()] = std::make_pair(start, start + num);
752     start += num;
753   }
754   return OkStatus();
755 }
756 
757 }  // namespace
758 
NameRangesForNode(const AttrSlice & attrs,const OpDef & op_def,NameRangeMap * inputs,NameRangeMap * outputs)759 Status NameRangesForNode(const AttrSlice& attrs, const OpDef& op_def,
760                          NameRangeMap* inputs, NameRangeMap* outputs) {
761   if (inputs != nullptr) {
762     TF_RETURN_IF_ERROR(
763         NameRangesHelper(attrs, op_def.input_arg(), op_def, inputs));
764   }
765   if (outputs != nullptr) {
766     return NameRangesHelper(attrs, op_def.output_arg(), op_def, outputs);
767   }
768   return OkStatus();
769 }
770 
AddDefaultsToNodeDef(const OpDef & op_def,NodeDef * node_def)771 void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def) {
772   for (const auto& attr_def : op_def.attr()) {
773     AttrSlice attrs(*node_def);
774     if (attr_def.has_default_value() && !attrs.Find(attr_def.name())) {
775       AddNodeAttr(attr_def.name(), attr_def.default_value(), node_def);
776     }
777   }
778 }
779 
StripDefaultsFromNodeDef(const OpDef & op_def,NodeDef * node_def)780 void StripDefaultsFromNodeDef(const OpDef& op_def, NodeDef* node_def) {
781   AttrSlice attrs(*node_def);
782   for (const auto& attr_def : op_def.attr()) {
783     if (attr_def.has_default_value()) {
784       const AttrValue* attr = attrs.Find(attr_def.name());
785       if (attr && AreAttrValuesEqual(*attr, attr_def.default_value()))
786         node_def->mutable_attr()->erase(attr_def.name());
787     }
788   }
789 }
790 
791 namespace {
792 
793 using ::tensorflow::tstring;
794 using ::tensorflow::strings::Scanner;
795 
IsValidNodeName(StringPiece sp)796 bool IsValidNodeName(StringPiece sp) {
797   Scanner scanner(sp);
798   scanner.One(Scanner::LETTER_DIGIT_DOT)
799       .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
800 
801   while (true) {
802     if (!scanner.GetResult())  // Some error in previous iteration.
803       return false;
804     if (scanner.empty())  // No error, but nothing left, good.
805       return true;
806 
807     // Absorb another name/namespace, starting with a '>'
808     scanner.One(Scanner::RANGLE)
809         .One(Scanner::LETTER_DIGIT_DOT)
810         .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
811   }
812 }
813 
IsValidDataInputName(StringPiece sp)814 bool IsValidDataInputName(StringPiece sp) {
815   // Data inputs are op_name, op_name:0, or op_name:12345.
816   Scanner scan(sp);
817   scan.One(Scanner::LETTER_DIGIT_DOT)
818       .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
819 
820   while (true) {
821     if (!scan.GetResult())  // Some error in previous iteration.
822       return false;
823     if (scan.empty())  // No error, but nothing left, good.
824       return true;
825 
826     if (scan.Peek() == ':') {  // Absorb identifier after the colon
827       scan.OneLiteral(":");
828       if (scan.Peek() == '0') {
829         scan.OneLiteral("0");  // :0
830       } else {
831         scan.Many(Scanner::DIGIT);  // :[1-9][0-9]*
832       }
833     } else {
834       // Absorb another name/namespace, starting with a '>'
835       scan.One(Scanner::RANGLE)
836           .One(Scanner::LETTER_DIGIT_DOT)
837           .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
838     }
839   }
840 }
841 
IsValidControlInputName(StringPiece sp)842 bool IsValidControlInputName(StringPiece sp) {
843   Scanner scan(sp);
844   scan.OneLiteral("^")
845       .One(Scanner::LETTER_DIGIT_DOT)
846       .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
847 
848   while (true) {
849     if (!scan.GetResult())  // Some error in previous iteration.
850       return false;
851     if (scan.empty())  // No error, but nothing left, good.
852       return true;
853 
854     // Absorb another name/namespace, starting with a '>'
855     scan.One(Scanner::RANGLE)
856         .One(Scanner::LETTER_DIGIT_DOT)
857         .Any(Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
858   }
859 }
860 
// StringPiece view of kColocationGroupPrefix. The colocation-constraint
// helpers in this file pass it as the prefix argument to absl::ConsumePrefix.
const StringPiece kColocationGroupPrefixStringPiece(kColocationGroupPrefix);
862 
863 }  // namespace
864 
ValidateOpInput(const string & input_name,bool * is_control_input)865 Status ValidateOpInput(const string& input_name, bool* is_control_input) {
866   *is_control_input = false;
867   if (IsValidDataInputName(input_name)) {
868     return OkStatus();
869   } else if (IsValidControlInputName(input_name)) {
870     *is_control_input = true;
871     return OkStatus();
872   } else {
873     return errors::InvalidArgument("Illegal op input name '", input_name, "'");
874   }
875 }
876 
ValidateNodeName(const string & node_name)877 Status ValidateNodeName(const string& node_name) {
878   if (IsValidNodeName(node_name)) {
879     return OkStatus();
880   } else {
881     return errors::InvalidArgument("Illegal op name '", node_name, "'");
882   }
883 }
884 
ValidateExternalNodeDefSyntax(const NodeDef & node_def)885 Status ValidateExternalNodeDefSyntax(const NodeDef& node_def) {
886   Status s = ValidateNodeName(node_def.name());
887   if (!s.ok()) {
888     return AttachDef(s, node_def);
889   }
890   bool in_control_inputs = false;
891   for (const string& input_name : node_def.input()) {
892     bool is_control_input;
893     s = ValidateOpInput(input_name, &is_control_input);
894     if (!s.ok()) {
895       return AttachDef(s, node_def);
896     }
897 
898     if (in_control_inputs && !is_control_input) {
899       return AttachDef(errors::InvalidArgument(
900                            "All control inputs must follow all data inputs"),
901                        node_def);
902     }
903     in_control_inputs = is_control_input;
904   }
905   return OkStatus();
906 }
907 
// Returns a copy of `status` with " [[<node>]]" appended to its message.
//
// <node> is only the node's name when `allow_multiple_formatted_node` is false
// and the message already contains "{{node "; in every other case it is the
// output of FormatNodeDefForError(node_def).
Status AttachDef(const Status& status, const NodeDef& node_def,
                 bool allow_multiple_formatted_node) {
  const bool name_only =
      !allow_multiple_formatted_node &&
      status.error_message().find("{{node ") != string::npos;
  const string node_error =
      name_only ? node_def.name() : FormatNodeDefForError(node_def);

  Status annotated = status;
  errors::AppendToMessage(&annotated,
                          strings::StrCat(" [[", node_error, "]]"));
  return annotated;
}
921 
AddNodeAttr(StringPiece name,const AttrValue & value,NodeDef * node_def)922 void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def) {
923   node_def->mutable_attr()->insert(
924       AttrValueMap::value_type(string(name), value));
925 }
926 
AddNodeAttr(StringPiece name,AttrValue && value,NodeDef * node_def)927 void AddNodeAttr(StringPiece name, AttrValue&& value, NodeDef* node_def) {
928   (*node_def->mutable_attr())[string(name)] = std::move(value);
929 }
930 
931 #define ADD_NODE_ATTR(T)                                           \
932   void AddNodeAttr(StringPiece name, T value, NodeDef* node_def) { \
933     AttrValue attr_value;                                          \
934     SetAttrValue(value, &attr_value);                              \
935     AddNodeAttr(name, attr_value, node_def);                       \
936   }
937 ADD_NODE_ATTR(StringPiece)
ADD_NODE_ATTR(const char *)938 ADD_NODE_ATTR(const char*)
939 ADD_NODE_ATTR(int32_t)
940 ADD_NODE_ATTR(int64_t)
941 ADD_NODE_ATTR(float)
942 ADD_NODE_ATTR(double)
943 ADD_NODE_ATTR(bool)
944 ADD_NODE_ATTR(DataType)
945 ADD_NODE_ATTR(const PartialTensorShape&)
946 ADD_NODE_ATTR(const Tensor&)
947 ADD_NODE_ATTR(const TensorProto&)
948 ADD_NODE_ATTR(const NameAttrList&)
949 ADD_NODE_ATTR(gtl::ArraySlice<StringPiece>)
950 ADD_NODE_ATTR(gtl::ArraySlice<const char*>)
951 ADD_NODE_ATTR(gtl::ArraySlice<string>)
952 ADD_NODE_ATTR(gtl::ArraySlice<int32>)
953 ADD_NODE_ATTR(gtl::ArraySlice<int64_t>)
954 ADD_NODE_ATTR(gtl::ArraySlice<float>)
955 ADD_NODE_ATTR(gtl::ArraySlice<bool>)
956 ADD_NODE_ATTR(const std::vector<bool>&)
957 ADD_NODE_ATTR(gtl::ArraySlice<DataType>)
958 ADD_NODE_ATTR(gtl::ArraySlice<TensorShape>)
959 ADD_NODE_ATTR(gtl::ArraySlice<PartialTensorShape>)
960 ADD_NODE_ATTR(gtl::ArraySlice<TensorShapeProto>)
961 ADD_NODE_ATTR(gtl::ArraySlice<Tensor>)
962 ADD_NODE_ATTR(gtl::ArraySlice<NameAttrList>)
963 #undef ADD_NODE_ATTR
964 
965 void AddAttr(StringPiece name, const AttrValue& value, AttrValueMap* map) {
966   map->insert(AttrValueMap::value_type(string(name), value));
967 }
968 
969 #define ADD_ATTR(T)                                            \
970   void AddAttr(StringPiece name, T value, AttrValueMap* map) { \
971     AttrValue attr_value;                                      \
972     SetAttrValue(value, &attr_value);                          \
973     AddAttr(name, attr_value, map);                            \
974   }
ADD_ATTR(bool)975 ADD_ATTR(bool)
976 #undef ADD_ATTR
977 
978 Status AddPrefixAndSuffixToNode(StringPiece prefix, StringPiece suffix,
979                                 NodeDef* node_def, bool uniquify_frame_name) {
980   node_def->set_name(strings::StrCat(prefix, node_def->name(), suffix));
981 
982   // Update frame name to avoid multiple LoopCond nodes in one frame.
983   if (uniquify_frame_name &&
984       (node_def->op() == "Enter" || node_def->op() == "RefEnter")) {
985     string frame_name;
986     TF_RETURN_IF_ERROR(GetNodeAttr(*node_def, "frame_name", &frame_name));
987     AttrValue& attr = (*node_def->mutable_attr())["frame_name"];
988     frame_name = strings::StrCat(prefix, frame_name, suffix);
989     attr.set_s(frame_name);
990   }
991 
992   return OkStatus();
993 }
994 
MaybeAddPrefixToColocationConstraints(const std::unordered_set<string> & match,StringPiece prefix,NodeDef * node_def)995 Status MaybeAddPrefixToColocationConstraints(
996     const std::unordered_set<string>& match, StringPiece prefix,
997     NodeDef* node_def) {
998   auto attr = node_def->mutable_attr()->find(kColocationAttrName);
999   if (attr == node_def->mutable_attr()->end()) {
1000     return OkStatus();
1001   }
1002   auto constraints_list = attr->second.mutable_list();
1003   auto constraints_size = constraints_list->s_size();
1004   for (size_t i = 0; i < constraints_size; ++i) {
1005     StringPiece original(constraints_list->s(i));
1006     if (absl::ConsumePrefix(&original, kColocationGroupPrefixStringPiece)) {
1007       if (match.find(string(original)) != match.end()) {
1008         (*constraints_list->mutable_s(i)) =
1009             strings::StrCat(kColocationGroupPrefix, prefix, original);
1010       }
1011     }
1012   }
1013   return OkStatus();
1014 }
1015 
MaybeUpdateColocationConstraintsWithMap(const std::map<absl::string_view,absl::string_view> & node_name_map,NodeDef * node_def)1016 Status MaybeUpdateColocationConstraintsWithMap(
1017     const std::map<absl::string_view, absl::string_view>& node_name_map,
1018     NodeDef* node_def) {
1019   auto attr = node_def->mutable_attr()->find(kColocationAttrName);
1020   if (attr == node_def->mutable_attr()->end()) {
1021     return OkStatus();
1022   }
1023   auto constraints_list = attr->second.mutable_list();
1024   auto constraints_size = constraints_list->s_size();
1025   for (size_t i = 0; i < constraints_size; ++i) {
1026     StringPiece original(constraints_list->s(i));
1027     if (absl::ConsumePrefix(&original, kColocationGroupPrefixStringPiece)) {
1028       if (node_name_map.find(original) != node_name_map.end()) {
1029         (*constraints_list->mutable_s(i)) =
1030             strings::StrCat(kColocationGroupPrefix, node_name_map.at(original));
1031       }
1032     }
1033   }
1034   return OkStatus();
1035 }
1036 
1037 }  // namespace tensorflow