• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/op_def_util.h"
17 
18 #include <set>
19 #include <unordered_map>
20 #include <unordered_set>
21 
22 #include "tensorflow/core/framework/attr_value.pb.h"
23 #include "tensorflow/core/framework/attr_value_util.h"
24 #include "tensorflow/core/framework/op_def.pb.h"
25 #include "tensorflow/core/framework/types.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/lib/core/stringpiece.h"
28 #include "tensorflow/core/lib/gtl/map_util.h"
29 #include "tensorflow/core/lib/hash/hash.h"
30 #include "tensorflow/core/lib/strings/proto_serialization.h"
31 #include "tensorflow/core/lib/strings/scanner.h"
32 #include "tensorflow/core/lib/strings/str_util.h"
33 #include "tensorflow/core/lib/strings/strcat.h"
34 #include "tensorflow/core/platform/mutex.h"
35 #include "tensorflow/core/platform/protobuf.h"
36 #include "tensorflow/core/platform/types.h"
37 
38 namespace tensorflow {
39 namespace {  // ------ Helper functions ------
40 
HasAttrStyleType(const OpDef::ArgDef & arg)41 bool HasAttrStyleType(const OpDef::ArgDef& arg) {
42   return arg.type() != DT_INVALID || !arg.type_attr().empty() ||
43          !arg.type_list_attr().empty();
44 }
45 
AllowedTypeValue(DataType dt,const OpDef::AttrDef & attr)46 Status AllowedTypeValue(DataType dt, const OpDef::AttrDef& attr) {
47   const AttrValue& allowed_values(attr.allowed_values());
48   for (auto allowed : allowed_values.list().type()) {
49     if (dt == allowed) {
50       return OkStatus();
51     }
52   }
53   string allowed_str;
54   for (int i = 0; i < allowed_values.list().type_size(); ++i) {
55     if (!allowed_str.empty()) {
56       strings::StrAppend(&allowed_str, ", ");
57     }
58     strings::StrAppend(&allowed_str,
59                        DataTypeString(allowed_values.list().type(i)));
60   }
61   return errors::InvalidArgument(
62       "Value for attr '", attr.name(), "' of ", DataTypeString(dt),
63       " is not in the list of allowed values: ", allowed_str);
64 }
65 
AllowedStringValue(const string & str,const OpDef::AttrDef & attr)66 Status AllowedStringValue(const string& str, const OpDef::AttrDef& attr) {
67   const AttrValue& allowed_values(attr.allowed_values());
68   for (const auto& allowed : allowed_values.list().s()) {
69     if (str == allowed) {
70       return OkStatus();
71     }
72   }
73   string allowed_str;
74   for (const string& allowed : allowed_values.list().s()) {
75     if (!allowed_str.empty()) {
76       strings::StrAppend(&allowed_str, ", ");
77     }
78     strings::StrAppend(&allowed_str, "\"", allowed, "\"");
79   }
80   return errors::InvalidArgument(
81       "Value for attr '", attr.name(), "' of \"", str,
82       "\" is not in the list of allowed values: ", allowed_str);
83 }
84 
85 }  // namespace
86 
87 // Requires: attr has already been validated.
ValidateAttrValue(const AttrValue & attr_value,const OpDef::AttrDef & attr)88 Status ValidateAttrValue(const AttrValue& attr_value,
89                          const OpDef::AttrDef& attr) {
90   // Is it a valid value?
91   TF_RETURN_WITH_CONTEXT_IF_ERROR(AttrValueHasType(attr_value, attr.type()),
92                                   " for attr '", attr.name(), "'");
93 
94   // Does the value satisfy the minimum constraint in the AttrDef?
95   if (attr.has_minimum()) {
96     if (attr.type() == "int") {
97       if (attr_value.i() < attr.minimum()) {
98         return errors::InvalidArgument(
99             "Value for attr '", attr.name(), "' of ", attr_value.i(),
100             " must be at least minimum ", attr.minimum());
101       }
102     } else {
103       int length = -1;
104       if (attr.type() == "list(string)") {
105         length = attr_value.list().s_size();
106       } else if (attr.type() == "list(int)") {
107         length = attr_value.list().i_size();
108       } else if (attr.type() == "list(float)") {
109         length = attr_value.list().f_size();
110       } else if (attr.type() == "list(bool)") {
111         length = attr_value.list().b_size();
112       } else if (attr.type() == "list(type)") {
113         length = attr_value.list().type_size();
114       } else if (attr.type() == "list(shape)") {
115         length = attr_value.list().shape_size();
116       } else if (attr.type() == "list(tensor)") {
117         length = attr_value.list().tensor_size();
118       } else if (attr.type() == "list(func)") {
119         length = attr_value.list().func_size();
120       }
121       if (length < attr.minimum()) {
122         return errors::InvalidArgument(
123             "Length for attr '", attr.name(), "' of ", length,
124             " must be at least minimum ", attr.minimum());
125       }
126     }
127   }
128 
129   // Does the value satisfy the allowed_value constraint in the AttrDef?
130   if (attr.has_allowed_values()) {
131     if (attr.type() == "type") {
132       TF_RETURN_IF_ERROR(AllowedTypeValue(attr_value.type(), attr));
133     } else if (attr.type() == "list(type)") {
134       for (int dt : attr_value.list().type()) {
135         TF_RETURN_IF_ERROR(AllowedTypeValue(static_cast<DataType>(dt), attr));
136       }
137     } else if (attr.type() == "string") {
138       TF_RETURN_IF_ERROR(AllowedStringValue(attr_value.s(), attr));
139     } else if (attr.type() == "list(string)") {
140       for (const string& str : attr_value.list().s()) {
141         TF_RETURN_IF_ERROR(AllowedStringValue(str, attr));
142       }
143     } else {
144       return errors::Unimplemented(
145           "Support for allowed_values not implemented for type ", attr.type());
146     }
147   }
148   return OkStatus();
149 }
150 
FindAttr(StringPiece name,const OpDef & op_def)151 const OpDef::AttrDef* FindAttr(StringPiece name, const OpDef& op_def) {
152   for (int i = 0; i < op_def.attr_size(); ++i) {
153     if (op_def.attr(i).name() == name) {
154       return &op_def.attr(i);
155     }
156   }
157   return nullptr;
158 }
159 
FindAttrMutable(StringPiece name,OpDef * op_def)160 OpDef::AttrDef* FindAttrMutable(StringPiece name, OpDef* op_def) {
161   for (int i = 0; i < op_def->attr_size(); ++i) {
162     if (op_def->attr(i).name() == name) {
163       return op_def->mutable_attr(i);
164     }
165   }
166   return nullptr;
167 }
168 
FindInputArg(StringPiece name,const OpDef & op_def)169 const OpDef::ArgDef* FindInputArg(StringPiece name, const OpDef& op_def) {
170   for (int i = 0; i < op_def.input_arg_size(); ++i) {
171     if (op_def.input_arg(i).name() == name) {
172       return &op_def.input_arg(i);
173     }
174   }
175   return nullptr;
176 }
177 
FindInputArg(StringPiece name,const ApiDef & api_def)178 const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) {
179   for (int i = 0; i < api_def.in_arg_size(); ++i) {
180     if (api_def.in_arg(i).name() == name) {
181       return &api_def.in_arg(i);
182     }
183   }
184   return nullptr;
185 }
186 
187 #define VALIDATE(EXPR, ...)                                        \
188   do {                                                             \
189     if (!(EXPR)) {                                                 \
190       return errors::InvalidArgument(                              \
191           __VA_ARGS__, "; in OpDef: ", op_def.ShortDebugString()); \
192     }                                                              \
193   } while (false)
194 
ValidateArg(const OpDef::ArgDef & arg,const OpDef & op_def,bool output,std::set<string> * names)195 static Status ValidateArg(const OpDef::ArgDef& arg, const OpDef& op_def,
196                           bool output, std::set<string>* names) {
197   const string suffix = strings::StrCat(
198       output ? " for output '" : " for input '", arg.name(), "'");
199   VALIDATE(gtl::InsertIfNotPresent(names, arg.name()),
200            "Duplicate name: ", arg.name());
201   VALIDATE(HasAttrStyleType(arg), "Missing type", suffix);
202 
203   if (!arg.number_attr().empty()) {
204     const OpDef::AttrDef* attr = FindAttr(arg.number_attr(), op_def);
205     VALIDATE(attr != nullptr, "No attr with name '", arg.number_attr(), "'",
206              suffix);
207     VALIDATE(attr->type() == "int", "Attr '", attr->name(), "' used as length",
208              suffix, " has type ", attr->type(), " != int");
209     VALIDATE(attr->has_minimum(), "Attr '", attr->name(), "' used as length",
210              suffix, " must have minimum");
211     VALIDATE(attr->minimum() >= 0, "Attr '", attr->name(), "' used as length",
212              suffix, " must have minimum >= 0");
213     VALIDATE(arg.type_list_attr().empty(),
214              "Can't have both number_attr and type_list_attr", suffix);
215     VALIDATE((arg.type() != DT_INVALID ? 1 : 0) +
216                      (!arg.type_attr().empty() ? 1 : 0) ==
217                  1,
218              "Exactly one of type, type_attr must be set", suffix);
219   } else {
220     const int num_type_fields = (arg.type() != DT_INVALID ? 1 : 0) +
221                                 (!arg.type_attr().empty() ? 1 : 0) +
222                                 (!arg.type_list_attr().empty() ? 1 : 0);
223     VALIDATE(num_type_fields == 1,
224              "Exactly one of type, type_attr, type_list_attr must be set",
225              suffix);
226   }
227 
228   if (!arg.type_attr().empty()) {
229     const OpDef::AttrDef* attr = FindAttr(arg.type_attr(), op_def);
230     VALIDATE(attr != nullptr, "No attr with name '", arg.type_attr(), "'",
231              suffix);
232     VALIDATE(attr->type() == "type", "Attr '", attr->name(),
233              "' used as type_attr", suffix, " has type ", attr->type(),
234              " != type");
235   } else if (!arg.type_list_attr().empty()) {
236     const OpDef::AttrDef* attr = FindAttr(arg.type_list_attr(), op_def);
237     VALIDATE(attr != nullptr, "No attr with name '", arg.type_list_attr(), "'",
238              suffix);
239     VALIDATE(attr->type() == "list(type)", "Attr '", attr->name(),
240              "' used as type_list_attr", suffix, " has type ", attr->type(),
241              " != list(type)");
242   } else {
243     // All argument types should be non-reference types at this point.
244     // ArgDef.is_ref is set to true for reference arguments.
245     VALIDATE(!IsRefType(arg.type()), "Illegal use of ref type '",
246              DataTypeString(arg.type()), "'. Use 'Ref(type)' instead", suffix);
247   }
248 
249   return OkStatus();
250 }
251 
IsValidOpName(StringPiece sp)252 bool IsValidOpName(StringPiece sp) {
253   using ::tensorflow::strings::Scanner;
254 
255   Scanner scanner(sp);
256   scanner.One(Scanner::UPPERLETTER).Any(Scanner::LETTER_DIGIT_UNDERSCORE);
257 
258   while (true) {
259     if (!scanner.GetResult())  // Some error in previous iteration.
260       return false;
261     if (scanner.empty())  // No error, but nothing left, good.
262       return true;
263 
264     // Absorb another name/namespace, starting with a '>'
265     scanner.One(Scanner::RANGLE)
266         .One(Scanner::UPPERLETTER)
267         .Any(Scanner::LETTER_DIGIT_UNDERSCORE);
268   }
269 }
270 
ValidateOpDef(const OpDef & op_def)271 Status ValidateOpDef(const OpDef& op_def) {
272   if (!absl::StartsWith(op_def.name(), "_")) {
273     VALIDATE(IsValidOpName(op_def.name()), "Invalid name: ", op_def.name(),
274              " (Did you use CamelCase?)");
275   }
276 
277   std::set<string> names;  // for detecting duplicate names
278   for (const auto& attr : op_def.attr()) {
279     // Validate name
280     VALIDATE(gtl::InsertIfNotPresent(&names, attr.name()),
281              "Duplicate name: ", attr.name());
282     DataType dt;
283     VALIDATE(!DataTypeFromString(attr.name(), &dt), "Attr can't have name ",
284              attr.name(), " that matches a data type");
285 
286     // Validate type
287     StringPiece type(attr.type());
288     bool is_list = absl::ConsumePrefix(&type, "list(");
289     bool found = false;
290     for (StringPiece valid : {"string", "int", "float", "bool", "type", "shape",
291                               "tensor", "func"}) {
292       if (absl::ConsumePrefix(&type, valid)) {
293         found = true;
294         break;
295       }
296     }
297     VALIDATE(found, "Unrecognized type '", type, "' in attr '", attr.name(),
298              "'");
299     if (is_list) {
300       VALIDATE(absl::ConsumePrefix(&type, ")"),
301                "'list(' is missing ')' in attr ", attr.name(), "'s type ",
302                attr.type());
303     }
304     VALIDATE(type.empty(), "Extra '", type, "' at the end of attr ",
305              attr.name(), "'s type ", attr.type());
306 
307     // Validate minimum
308     if (attr.has_minimum()) {
309       VALIDATE(attr.type() == "int" || is_list, "Attr '", attr.name(),
310                "' has minimum for unsupported type ", attr.type());
311       if (is_list) {
312         VALIDATE(attr.minimum() >= 0, "Attr '", attr.name(),
313                  "' with list type must have a non-negative minimum, not ",
314                  attr.minimum());
315       }
316     } else {
317       VALIDATE(attr.minimum() == 0, "Attr '", attr.name(),
318                "' with has_minimum = false but minimum ", attr.minimum(),
319                " not equal to default of 0");
320     }
321 
322     // Validate allowed_values
323     if (attr.has_allowed_values()) {
324       const string list_type =
325           is_list ? attr.type() : strings::StrCat("list(", attr.type(), ")");
326       TF_RETURN_WITH_CONTEXT_IF_ERROR(
327           AttrValueHasType(attr.allowed_values(), list_type), " for attr '",
328           attr.name(), "' in Op '", op_def.name(), "'");
329     }
330 
331     // Validate default_value (after we have validated the rest of the attr,
332     // so we can use ValidateAttrValue()).
333     if (attr.has_default_value()) {
334       TF_RETURN_WITH_CONTEXT_IF_ERROR(
335           ValidateAttrValue(attr.default_value(), attr), " in Op '",
336           op_def.name(), "'");
337     }
338   }
339 
340   for (const auto& arg : op_def.input_arg()) {
341     TF_RETURN_IF_ERROR(ValidateArg(arg, op_def, false, &names));
342   }
343 
344   for (const auto& arg : op_def.output_arg()) {
345     TF_RETURN_IF_ERROR(ValidateArg(arg, op_def, true, &names));
346   }
347 
348   return OkStatus();
349 }
350 
351 #undef VALIDATE
352 
CheckOpDeprecation(const OpDef & op_def,int graph_def_version)353 Status CheckOpDeprecation(const OpDef& op_def, int graph_def_version) {
354   if (op_def.has_deprecation()) {
355     const OpDeprecation& dep = op_def.deprecation();
356     if (graph_def_version >= dep.version()) {
357       return errors::Unimplemented(
358           "Op ", op_def.name(), " is not available in GraphDef version ",
359           graph_def_version, ". It has been removed in version ", dep.version(),
360           ". ", dep.explanation(), ".");
361     } else {
362       // Warn only once for each op name, and do it in a threadsafe manner.
363       static mutex mu(LINKER_INITIALIZED);
364       static std::unordered_set<string> warned;
365       bool warn;
366       {
367         mutex_lock lock(mu);
368         warn = warned.insert(op_def.name()).second;
369       }
370       if (warn) {
371         LOG(WARNING) << "Op " << op_def.name() << " is deprecated."
372                      << " It will cease to work in GraphDef version "
373                      << dep.version() << ". " << dep.explanation() << ".";
374       }
375     }
376   }
377   return OkStatus();
378 }
379 
380 namespace {
381 
SummarizeArgs(const protobuf::RepeatedPtrField<OpDef::ArgDef> & args)382 string SummarizeArgs(const protobuf::RepeatedPtrField<OpDef::ArgDef>& args) {
383   string ret;
384   for (const OpDef::ArgDef& arg : args) {
385     if (!ret.empty()) strings::StrAppend(&ret, ", ");
386     strings::StrAppend(&ret, arg.name(), ":");
387     if (arg.is_ref()) strings::StrAppend(&ret, "Ref(");
388     if (!arg.number_attr().empty()) {
389       strings::StrAppend(&ret, arg.number_attr(), "*");
390     }
391     if (arg.type() != DT_INVALID) {
392       strings::StrAppend(&ret, DataTypeString(arg.type()));
393     } else {
394       strings::StrAppend(&ret, arg.type_attr());
395     }
396     if (arg.is_ref()) strings::StrAppend(&ret, ")");
397   }
398   return ret;
399 }
400 
401 }  // namespace
402 
SummarizeOpDef(const OpDef & op_def)403 string SummarizeOpDef(const OpDef& op_def) {
404   string ret = strings::StrCat("Op<name=", op_def.name());
405   strings::StrAppend(&ret, "; signature=", SummarizeArgs(op_def.input_arg()),
406                      " -> ", SummarizeArgs(op_def.output_arg()));
407   for (int i = 0; i < op_def.attr_size(); ++i) {
408     strings::StrAppend(&ret, "; attr=", op_def.attr(i).name(), ":",
409                        op_def.attr(i).type());
410     if (op_def.attr(i).has_default_value()) {
411       strings::StrAppend(&ret, ",default=",
412                          SummarizeAttrValue(op_def.attr(i).default_value()));
413     }
414     if (op_def.attr(i).has_minimum()) {
415       strings::StrAppend(&ret, ",min=", op_def.attr(i).minimum());
416     }
417     if (op_def.attr(i).has_allowed_values()) {
418       strings::StrAppend(&ret, ",allowed=",
419                          SummarizeAttrValue(op_def.attr(i).allowed_values()));
420     }
421   }
422   if (op_def.is_commutative()) {
423     strings::StrAppend(&ret, "; is_commutative=true");
424   }
425   if (op_def.is_aggregate()) {
426     strings::StrAppend(&ret, "; is_aggregate=true");
427   }
428   if (op_def.is_stateful()) {
429     strings::StrAppend(&ret, "; is_stateful=true");
430   }
431   if (op_def.allows_uninitialized_input()) {
432     strings::StrAppend(&ret, "; allows_uninitialized_input=true");
433   }
434   if (op_def.is_distributed_communication()) {
435     strings::StrAppend(&ret, "; is_distributed_communication=true");
436   }
437   strings::StrAppend(&ret, ">");
438   return ret;
439 }
440 
441 namespace {
442 
// Returns true if every element of `sub` is equal (==) to some element of
// `super`. Multiplicity is ignored ({a, a} is a subset of {a}), and an empty
// `sub` is a subset of anything. Runs in O(|sub| * |super|).
template <class T>
bool IsSubsetOf(const T& sub, const T& super) {
  for (const auto& wanted : sub) {
    bool present = false;
    for (auto it = super.begin(); !present && it != super.end(); ++it) {
      present = (wanted == *it);
    }
    if (!present) return false;
  }
  return true;
}
458 
MoreRestrictive(const OpDef::AttrDef & old_attr,const OpDef::AttrDef & new_attr)459 bool MoreRestrictive(const OpDef::AttrDef& old_attr,
460                      const OpDef::AttrDef& new_attr) {
461   // Anything -> no restriction : not more restrictive.
462   if (!new_attr.has_allowed_values()) return false;
463   // No restriction -> restriction : more restrictive.
464   if (!old_attr.has_allowed_values()) return true;
465   // If anything that was previously allowed is no longer allowed:
466   // more restrictive.
467   if (!IsSubsetOf(old_attr.allowed_values().list().type(),
468                   new_attr.allowed_values().list().type())) {
469     return true;
470   }
471   if (!IsSubsetOf(old_attr.allowed_values().list().s(),
472                   new_attr.allowed_values().list().s())) {
473     return true;
474   }
475   return false;
476 }
477 
AllowedStr(const OpDef::AttrDef & attr)478 string AllowedStr(const OpDef::AttrDef& attr) {
479   if (!attr.has_allowed_values()) return "no restriction";
480   return SummarizeAttrValue(attr.allowed_values());
481 }
482 
DefaultAttrStr(const OpDef::AttrDef & attr)483 string DefaultAttrStr(const OpDef::AttrDef& attr) {
484   if (!attr.has_default_value()) return "no default";
485   return SummarizeAttrValue(attr.default_value());
486 }
487 
HigherMinimum(const OpDef::AttrDef & old_attr,const OpDef::AttrDef & new_attr)488 bool HigherMinimum(const OpDef::AttrDef& old_attr,
489                    const OpDef::AttrDef& new_attr) {
490   // Anything -> no restriction : not more restrictive.
491   if (!new_attr.has_minimum()) return false;
492   // No restriction -> restriction : more restrictive.
493   if (!old_attr.has_minimum()) return true;
494   // If anything that was previously allowed is no longer allowed:
495   // more restrictive.
496   return new_attr.minimum() > old_attr.minimum();
497 }
498 
MinStr(const OpDef::AttrDef & attr)499 string MinStr(const OpDef::AttrDef& attr) {
500   if (!attr.has_minimum()) return "no minimum";
501   return strings::StrCat(attr.minimum());
502 }
503 
504 typedef std::unordered_map<string, const OpDef::AttrDef*> AttrMap;
FillAttrMap(const OpDef & op_def,AttrMap * attr_map)505 void FillAttrMap(const OpDef& op_def, AttrMap* attr_map) {
506   for (const auto& attr : op_def.attr()) {
507     (*attr_map)[attr.name()] = &attr;
508   }
509 }
510 
511 // Add a comma to *s every call but the first (*add_comma should be
512 // initialized to false).
AddComma(string * s,bool * add_comma)513 void AddComma(string* s, bool* add_comma) {
514   if (*add_comma) {
515     strings::StrAppend(s, ", ");
516   } else {
517     *add_comma = true;
518   }
519 }
520 
521 // Will add the `name` from arg if name is true.
AddName(string * s,bool name,const OpDef::ArgDef & arg)522 void AddName(string* s, bool name, const OpDef::ArgDef& arg) {
523   if (name) {
524     strings::StrAppend(s, arg.name(), ":");
525   }
526 }
527 
528 // Compute a signature for either inputs or outputs that will be the
529 // same for both the old and new OpDef if they are compatible.  We
530 // assume that new_attrs is a superset of old_attrs, and that any attr
531 // in the difference has a default.  Our strategy is to make a list of
532 // types, where the types are things like:
533 // * "int32", "float", etc.,
534 // * "T" for some attr "T" in old_attrs, or
535 // * "N * type" for "N" either some attr in old_attrs.
536 //
537 // We get the types by either using the attrs in args if they are in
538 // old_attrs, or substituting the default value from new_attrs.
ComputeArgSignature(const protobuf::RepeatedPtrField<OpDef::ArgDef> & args,const AttrMap & old_attrs,const AttrMap & new_attrs,std::vector<bool> * ref,bool names)539 string ComputeArgSignature(
540     const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
541     const AttrMap& old_attrs, const AttrMap& new_attrs, std::vector<bool>* ref,
542     bool names) {
543   string s;
544   bool add_comma = false;
545   for (const OpDef::ArgDef& arg : args) {
546     if (!arg.type_list_attr().empty()) {
547       const OpDef::AttrDef* old_attr =
548           gtl::FindPtrOrNull(old_attrs, arg.type_list_attr());
549       if (old_attr) {
550         // Both old and new have the list(type) attr, so can use it directly.
551         AddComma(&s, &add_comma);
552         AddName(&s, names, arg);
553         strings::StrAppend(&s, arg.type_list_attr());
554         ref->push_back(arg.is_ref());
555       } else {
556         // Missing the list(type) attr in the old, so use the default
557         // value for the attr from new instead.
558         const OpDef::AttrDef* new_attr =
559             gtl::FindPtrOrNull(new_attrs, arg.type_list_attr());
560         const auto& type_list = new_attr->default_value().list().type();
561         if (type_list.empty()) continue;
562         for (int i = 0; i < type_list.size(); ++i) {
563           AddComma(&s, &add_comma);
564           AddName(&s, names, arg);
565           strings::StrAppend(
566               &s, DataTypeString(static_cast<DataType>(type_list.Get(i))));
567           ref->push_back(arg.is_ref());
568         }
569       }
570     } else {
571       int num = 1;  // How many input/outputs does this represent?
572       string type;  // What is the type of this arg?
573       AddName(&type, names, arg);
574       if (!arg.number_attr().empty()) {
575         // N * type case.
576         const OpDef::AttrDef* old_attr =
577             gtl::FindPtrOrNull(old_attrs, arg.number_attr());
578         if (old_attr) {
579           // Both old and new have the number attr, so can use it directly.
580           strings::StrAppend(&type, arg.number_attr(), " * ");
581         } else {
582           // Missing the number attr in the old, so use the default
583           // value for the attr from new instead.
584           const OpDef::AttrDef* new_attr =
585               gtl::FindPtrOrNull(new_attrs, arg.number_attr());
586           num = new_attr->default_value().i();
587         }
588       }
589 
590       if (arg.type() != DT_INVALID) {
591         // int32, float, etc. case
592         strings::StrAppend(&type, DataTypeString(arg.type()));
593       } else {
594         const OpDef::AttrDef* old_attr =
595             gtl::FindPtrOrNull(old_attrs, arg.type_attr());
596         if (old_attr) {
597           // Both old and new have the type attr, so can use it directly.
598           strings::StrAppend(&type, arg.type_attr());
599         } else {
600           // Missing the type attr in the old, so use the default
601           // value for the attr from new instead.
602           const OpDef::AttrDef* new_attr =
603               gtl::FindPtrOrNull(new_attrs, arg.type_attr());
604           strings::StrAppend(&type,
605                              DataTypeString(new_attr->default_value().type()));
606         }
607       }
608 
609       // Record `num` * `type` in the signature.
610       for (int i = 0; i < num; ++i) {
611         AddComma(&s, &add_comma);
612         strings::StrAppend(&s, type);
613         ref->push_back(arg.is_ref());
614       }
615     }
616   }
617 
618   return s;
619 }
620 
621 }  // namespace
622 
OpDefCompatible(const OpDef & old_op,const OpDef & new_op)623 Status OpDefCompatible(const OpDef& old_op, const OpDef& new_op) {
624 #define VALIDATE(CONDITION, ...)                                            \
625   if (!(CONDITION)) {                                                       \
626     return errors::InvalidArgument("Incompatible Op change: ", __VA_ARGS__, \
627                                    "; old: ", SummarizeOpDef(old_op),       \
628                                    "; new: ", SummarizeOpDef(new_op));      \
629   }
630 
631   VALIDATE(old_op.name() == new_op.name(), "Name mismatch");
632 
633   AttrMap new_attrs, old_attrs;
634   FillAttrMap(old_op, &old_attrs);
635   FillAttrMap(new_op, &new_attrs);
636   for (const auto& old_attr : old_op.attr()) {
637     const OpDef::AttrDef* new_attr =
638         gtl::FindPtrOrNull(new_attrs, old_attr.name());
639     VALIDATE(new_attr != nullptr, "Attr '", old_attr.name(), "' removed");
640     VALIDATE(old_attr.type() == new_attr->type(), "Attr '", old_attr.name(),
641              "' changed type '", old_attr.type(), "' -> '", new_attr->type(),
642              "'");
643     VALIDATE(!MoreRestrictive(old_attr, *new_attr), "Attr '", old_attr.name(),
644              "' has a stricter set of allowed values; from ",
645              AllowedStr(old_attr), " to ", AllowedStr(*new_attr));
646     VALIDATE(!HigherMinimum(old_attr, *new_attr), "Attr '", old_attr.name(),
647              "' has a higher minimum; from ", MinStr(old_attr), " to ",
648              MinStr(*new_attr));
649   }
650 
651   for (const auto& new_attr : new_op.attr()) {
652     const OpDef::AttrDef* old_attr =
653         gtl::FindPtrOrNull(old_attrs, new_attr.name());
654     VALIDATE(old_attr != nullptr || new_attr.has_default_value(), "Attr '",
655              new_attr.name(), "' added without default");
656   }
657 
658   std::vector<bool> old_in_ref, new_in_ref, old_out_ref, new_out_ref;
659   const string old_in_sig = ComputeArgSignature(
660       old_op.input_arg(), old_attrs, new_attrs, &old_in_ref, false /* names */);
661   const string new_in_sig = ComputeArgSignature(
662       new_op.input_arg(), old_attrs, new_attrs, &new_in_ref, false /* names */);
663   VALIDATE(old_in_sig == new_in_sig, "Input signature mismatch '", old_in_sig,
664            "' vs. '", new_in_sig, "'");
665   VALIDATE(old_in_ref.size() == new_in_ref.size(),  // Should not happen
666            "Unexpected change in input ref lists.");
667   for (int i = 0, end = old_in_ref.size(); i < end; ++i) {
668     // Allowed to remove "ref" from an input (or leave it unchanged).
669     VALIDATE(old_in_ref[i] || !new_in_ref[i], "Input ", i,
670              " changed from non-ref to ref");
671   }
672 
673   const string old_out_sig =
674       ComputeArgSignature(old_op.output_arg(), old_attrs, new_attrs,
675                           &old_out_ref, true /* names */);
676   const string new_out_sig =
677       ComputeArgSignature(new_op.output_arg(), old_attrs, new_attrs,
678                           &new_out_ref, true /* names */);
679   VALIDATE(old_out_sig == new_out_sig, "Output signature mismatch '",
680            old_out_sig, "' vs. '", new_out_sig, "'");
681   VALIDATE(old_out_ref.size() == new_out_ref.size(),  // Should not happen
682            "Unexpected change in output ref lists");
683   for (int i = 0, end = old_out_ref.size(); i < end; ++i) {
684     // Allowed to add "ref" to an output (or leave it unchanged).
685     VALIDATE(!old_out_ref[i] || new_out_ref[i], "Output ", i,
686              " changed from ref to non-ref");
687   }
688 
689   return OkStatus();
690 }
691 
OpDefAddedDefaultsUnchanged(const OpDef & old_op,const OpDef & penultimate_op,const OpDef & new_op)692 Status OpDefAddedDefaultsUnchanged(const OpDef& old_op,
693                                    const OpDef& penultimate_op,
694                                    const OpDef& new_op) {
695   AttrMap new_attrs, old_attrs;
696   FillAttrMap(old_op, &old_attrs);
697   FillAttrMap(new_op, &new_attrs);
698 
699   for (const auto& penultimate_attr : penultimate_op.attr()) {
700     const OpDef::AttrDef* old_attr =
701         gtl::FindPtrOrNull(old_attrs, penultimate_attr.name());
702     if (old_attr != nullptr) continue;  // attr wasn't added
703     const OpDef::AttrDef* new_attr =
704         gtl::FindPtrOrNull(new_attrs, penultimate_attr.name());
705 
706     // These shouldn't happen if the op passed OpDefCompatible().
707     if (new_attr == nullptr) {
708       return errors::InvalidArgument("Missing attr '", penultimate_attr.name(),
709                                      "' in op: ", SummarizeOpDef(new_op));
710     }
711     if (!penultimate_attr.has_default_value() ||
712         !new_attr->has_default_value()) {
713       return errors::InvalidArgument("Missing default for attr '",
714                                      penultimate_attr.name(),
715                                      "' in op: ", SummarizeOpDef(new_op));
716     }
717 
718     // Actually test that the attr's default value hasn't changed.
719     if (!AreAttrValuesEqual(penultimate_attr.default_value(),
720                             new_attr->default_value())) {
721       return errors::InvalidArgument(
722           "Can't change default value for attr '", penultimate_attr.name(),
723           "' from ", SummarizeAttrValue(penultimate_attr.default_value()),
724           " in op: ", SummarizeOpDef(new_op));
725     }
726   }
727 
728   return OkStatus();
729 }
730 
OpDefAttrDefaultsUnchanged(const OpDef & old_op,const OpDef & new_op)731 Status OpDefAttrDefaultsUnchanged(const OpDef& old_op, const OpDef& new_op) {
732   AttrMap new_attrs, old_attrs;
733   FillAttrMap(old_op, &old_attrs);
734   FillAttrMap(new_op, &new_attrs);
735 
736   for (const auto& old_attr : old_op.attr()) {
737     const OpDef::AttrDef* new_attr =
738         gtl::FindPtrOrNull(new_attrs, old_attr.name());
739     if (new_attr == nullptr) continue;
740     if (new_attr->has_default_value() && !old_attr.has_default_value()) {
741       continue;  // Adding new default values is safe.
742     }
743     if (old_attr.has_default_value() && !new_attr->has_default_value()) {
744       return errors::InvalidArgument(
745           "Attr '", old_attr.name(), "' has removed it's default; ", "from ",
746           DefaultAttrStr(old_attr), " to ", DefaultAttrStr(*new_attr));
747     }
748     if (old_attr.has_default_value() &&
749         !AreAttrValuesEqual(old_attr.default_value(),
750                             new_attr->default_value())) {
751       return errors::InvalidArgument(
752           "Attr '", old_attr.name(), "' has changed it's default value; ",
753           "from ", DefaultAttrStr(old_attr), " to ", DefaultAttrStr(*new_attr));
754     }
755   }
756 
757   return OkStatus();
758 }
759 
RemoveNonDeprecationDescriptionsFromOpDef(OpDef * op_def)760 void RemoveNonDeprecationDescriptionsFromOpDef(OpDef* op_def) {
761   for (int i = 0; i < op_def->input_arg_size(); ++i) {
762     op_def->mutable_input_arg(i)->clear_description();
763   }
764   for (int i = 0; i < op_def->output_arg_size(); ++i) {
765     op_def->mutable_output_arg(i)->clear_description();
766   }
767   for (int i = 0; i < op_def->attr_size(); ++i) {
768     op_def->mutable_attr(i)->clear_description();
769   }
770   op_def->clear_summary();
771   op_def->clear_description();
772 }
773 
RemoveDescriptionsFromOpDef(OpDef * op_def)774 void RemoveDescriptionsFromOpDef(OpDef* op_def) {
775   RemoveNonDeprecationDescriptionsFromOpDef(op_def);
776   if (op_def->has_deprecation()) {
777     op_def->mutable_deprecation()->clear_explanation();
778   }
779 }
780 
RemoveDescriptionsFromOpList(OpList * op_list)781 void RemoveDescriptionsFromOpList(OpList* op_list) {
782   for (int i = 0; i < op_list->op_size(); ++i) {
783     OpDef* op_def = op_list->mutable_op(i);
784     RemoveDescriptionsFromOpDef(op_def);
785   }
786 }
787 
AttrDefEqual(const OpDef::AttrDef & a1,const OpDef::AttrDef & a2)788 bool AttrDefEqual(const OpDef::AttrDef& a1, const OpDef::AttrDef& a2) {
789   if (std::is_base_of<protobuf::Message, OpDef::AttrDef>()) {
790     DCHECK_EQ(7, reinterpret_cast<const protobuf::Message*>(&a1)
791                      ->GetDescriptor()
792                      ->field_count())
793         << "Please modify these equality and hash functions to reflect the "
794            "changes to the AttrDef protobuf";
795   }
796 
797   if (a1.name() != a2.name()) return false;
798   if (a1.type() != a2.type()) return false;
799   if (a1.description() != a2.description()) return false;
800   if (a1.has_minimum() != a2.has_minimum()) return false;
801   if (a1.has_minimum() && a1.minimum() != a2.minimum()) return false;
802   if (!AreAttrValuesEqual(a1.default_value(), a2.default_value())) return false;
803   if (!AreAttrValuesEqual(a1.allowed_values(), a2.allowed_values()))
804     return false;
805   return true;
806 }
807 
AttrDefHash(const OpDef::AttrDef & a)808 uint64 AttrDefHash(const OpDef::AttrDef& a) {
809   uint64 h = Hash64(a.name());
810   h = Hash64(a.type().data(), a.type().size(), h);
811   h = Hash64Combine(AttrValueHash(a.default_value()), h);
812   h = Hash64(a.description().data(), a.description().size(), h);
813   h = Hash64Combine(static_cast<uint64>(a.has_minimum()), h);
814   h = Hash64Combine(static_cast<uint64>(a.minimum()), h);
815   h = Hash64Combine(AttrValueHash(a.allowed_values()), h);
816   return h;
817 }
818 
RepeatedAttrDefEqual(const protobuf::RepeatedPtrField<OpDef::AttrDef> & a1,const protobuf::RepeatedPtrField<OpDef::AttrDef> & a2)819 bool RepeatedAttrDefEqual(
820     const protobuf::RepeatedPtrField<OpDef::AttrDef>& a1,
821     const protobuf::RepeatedPtrField<OpDef::AttrDef>& a2) {
822   std::unordered_map<string, const OpDef::AttrDef*> a1_set;
823   for (const OpDef::AttrDef& def : a1) {
824     if (a1_set.find(def.name()) != a1_set.end()) {
825       LOG(ERROR) << "AttrDef names must be unique, but '" << def.name()
826                  << "' appears more than once";
827     }
828     a1_set[def.name()] = &def;
829   }
830   for (const OpDef::AttrDef& def : a2) {
831     auto iter = a1_set.find(def.name());
832     if (iter == a1_set.end()) return false;
833     if (!AttrDefEqual(*iter->second, def)) return false;
834     a1_set.erase(iter);
835   }
836   if (!a1_set.empty()) return false;
837   return true;
838 }
839 
RepeatedAttrDefHash(const protobuf::RepeatedPtrField<OpDef::AttrDef> & a)840 uint64 RepeatedAttrDefHash(
841     const protobuf::RepeatedPtrField<OpDef::AttrDef>& a) {
842   // Insert AttrDefs into map to deterministically sort by name
843   std::map<string, const OpDef::AttrDef*> a_set;
844   for (const OpDef::AttrDef& def : a) {
845     a_set[def.name()] = &def;
846   }
847   // Iterate and combines hashes of keys and values
848   uint64 h = 0xDECAFCAFFE;
849   for (const auto& pair : a_set) {
850     h = Hash64(pair.first.data(), pair.first.size(), h);
851     h = Hash64Combine(AttrDefHash(*pair.second), h);
852   }
853   return h;
854 }
855 
OpDefEqual(const OpDef & o1,const OpDef & o2)856 bool OpDefEqual(const OpDef& o1, const OpDef& o2) {
857   // attr order doesn't matter.
858   // Compare it separately here instead of serializing below.
859   if (!RepeatedAttrDefEqual(o1.attr(), o2.attr())) return false;
860 
861   // `control_output` order doesn't matter.
862   std::set<string> control_output1(o1.control_output().begin(),
863                                    o1.control_output().end());
864   std::set<string> control_output2(o2.control_output().begin(),
865                                    o2.control_output().end());
866   if (control_output1 != control_output2) return false;
867 
868   // Clear `attr` and `control_output` fields, serialize, and compare serialized
869   // strings.
870   OpDef o1_copy = o1;
871   OpDef o2_copy = o2;
872   o1_copy.clear_attr();
873   o1_copy.clear_control_output();
874   o2_copy.clear_attr();
875   o2_copy.clear_control_output();
876 
877   return AreSerializedProtosEqual(o1_copy, o2_copy);
878 }
879 
OpDefHash(const OpDef & o)880 uint64 OpDefHash(const OpDef& o) {
881   uint64 h = RepeatedAttrDefHash(o.attr());
882 
883   // Compute deterministic order-independent control outputs hash.
884   std::set<string> control_output(o.control_output().begin(),
885                                   o.control_output().end());
886   for (const auto& co : control_output) h = Hash64Combine(h, Hash64(co));
887 
888   OpDef o_copy = o;
889   o_copy.clear_attr();
890   o_copy.clear_control_output();
891   return DeterministicProtoHash64(o_copy, h);
892 }
893 
894 }  // namespace tensorflow
895