• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/op_def_util.h"
17 
18 #include <set>
19 #include <unordered_map>
20 #include <unordered_set>
21 
22 #include "tensorflow/core/framework/attr_value.pb.h"
23 #include "tensorflow/core/framework/attr_value_util.h"
24 #include "tensorflow/core/framework/op_def.pb.h"
25 #include "tensorflow/core/framework/types.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/lib/core/stringpiece.h"
28 #include "tensorflow/core/lib/gtl/map_util.h"
29 #include "tensorflow/core/lib/hash/hash.h"
30 #include "tensorflow/core/lib/strings/proto_serialization.h"
31 #include "tensorflow/core/lib/strings/scanner.h"
32 #include "tensorflow/core/lib/strings/str_util.h"
33 #include "tensorflow/core/lib/strings/strcat.h"
34 #include "tensorflow/core/platform/mutex.h"
35 #include "tensorflow/core/platform/protobuf.h"
36 #include "tensorflow/core/platform/types.h"
37 
38 namespace tensorflow {
39 namespace {  // ------ Helper functions ------
40 
HasAttrStyleType(const OpDef::ArgDef & arg)41 bool HasAttrStyleType(const OpDef::ArgDef& arg) {
42   return arg.type() != DT_INVALID || !arg.type_attr().empty() ||
43          !arg.type_list_attr().empty();
44 }
45 
AllowedTypeValue(DataType dt,const OpDef::AttrDef & attr)46 Status AllowedTypeValue(DataType dt, const OpDef::AttrDef& attr) {
47   const AttrValue& allowed_values(attr.allowed_values());
48   for (auto allowed : allowed_values.list().type()) {
49     if (dt == allowed) {
50       return Status::OK();
51     }
52   }
53   string allowed_str;
54   for (int i = 0; i < allowed_values.list().type_size(); ++i) {
55     if (!allowed_str.empty()) {
56       strings::StrAppend(&allowed_str, ", ");
57     }
58     strings::StrAppend(&allowed_str,
59                        DataTypeString(allowed_values.list().type(i)));
60   }
61   return errors::InvalidArgument(
62       "Value for attr '", attr.name(), "' of ", DataTypeString(dt),
63       " is not in the list of allowed values: ", allowed_str);
64 }
65 
AllowedStringValue(const string & str,const OpDef::AttrDef & attr)66 Status AllowedStringValue(const string& str, const OpDef::AttrDef& attr) {
67   const AttrValue& allowed_values(attr.allowed_values());
68   for (const auto& allowed : allowed_values.list().s()) {
69     if (str == allowed) {
70       return Status::OK();
71     }
72   }
73   string allowed_str;
74   for (const string& allowed : allowed_values.list().s()) {
75     if (!allowed_str.empty()) {
76       strings::StrAppend(&allowed_str, ", ");
77     }
78     strings::StrAppend(&allowed_str, "\"", allowed, "\"");
79   }
80   return errors::InvalidArgument(
81       "Value for attr '", attr.name(), "' of \"", str,
82       "\" is not in the list of allowed values: ", allowed_str);
83 }
84 
85 }  // namespace
86 
87 // Requires: attr has already been validated.
ValidateAttrValue(const AttrValue & attr_value,const OpDef::AttrDef & attr)88 Status ValidateAttrValue(const AttrValue& attr_value,
89                          const OpDef::AttrDef& attr) {
90   // Is it a valid value?
91   TF_RETURN_WITH_CONTEXT_IF_ERROR(AttrValueHasType(attr_value, attr.type()),
92                                   " for attr '", attr.name(), "'");
93 
94   // Does the value satisfy the minimum constraint in the AttrDef?
95   if (attr.has_minimum()) {
96     if (attr.type() == "int") {
97       if (attr_value.i() < attr.minimum()) {
98         return errors::InvalidArgument(
99             "Value for attr '", attr.name(), "' of ", attr_value.i(),
100             " must be at least minimum ", attr.minimum());
101       }
102     } else {
103       int length = -1;
104       if (attr.type() == "list(string)") {
105         length = attr_value.list().s_size();
106       } else if (attr.type() == "list(int)") {
107         length = attr_value.list().i_size();
108       } else if (attr.type() == "list(float)") {
109         length = attr_value.list().f_size();
110       } else if (attr.type() == "list(bool)") {
111         length = attr_value.list().b_size();
112       } else if (attr.type() == "list(type)") {
113         length = attr_value.list().type_size();
114       } else if (attr.type() == "list(shape)") {
115         length = attr_value.list().shape_size();
116       } else if (attr.type() == "list(tensor)") {
117         length = attr_value.list().tensor_size();
118       } else if (attr.type() == "list(func)") {
119         length = attr_value.list().func_size();
120       }
121       if (length < attr.minimum()) {
122         return errors::InvalidArgument(
123             "Length for attr '", attr.name(), "' of ", length,
124             " must be at least minimum ", attr.minimum());
125       }
126     }
127   }
128 
129   // Does the value satisfy the allowed_value constraint in the AttrDef?
130   if (attr.has_allowed_values()) {
131     if (attr.type() == "type") {
132       TF_RETURN_IF_ERROR(AllowedTypeValue(attr_value.type(), attr));
133     } else if (attr.type() == "list(type)") {
134       for (int dt : attr_value.list().type()) {
135         TF_RETURN_IF_ERROR(AllowedTypeValue(static_cast<DataType>(dt), attr));
136       }
137     } else if (attr.type() == "string") {
138       TF_RETURN_IF_ERROR(AllowedStringValue(attr_value.s(), attr));
139     } else if (attr.type() == "list(string)") {
140       for (const string& str : attr_value.list().s()) {
141         TF_RETURN_IF_ERROR(AllowedStringValue(str, attr));
142       }
143     } else {
144       return errors::Unimplemented(
145           "Support for allowed_values not implemented for type ", attr.type());
146     }
147   }
148   return Status::OK();
149 }
150 
FindAttr(StringPiece name,const OpDef & op_def)151 const OpDef::AttrDef* FindAttr(StringPiece name, const OpDef& op_def) {
152   for (int i = 0; i < op_def.attr_size(); ++i) {
153     if (op_def.attr(i).name() == name) {
154       return &op_def.attr(i);
155     }
156   }
157   return nullptr;
158 }
159 
FindAttrMutable(StringPiece name,OpDef * op_def)160 OpDef::AttrDef* FindAttrMutable(StringPiece name, OpDef* op_def) {
161   for (int i = 0; i < op_def->attr_size(); ++i) {
162     if (op_def->attr(i).name() == name) {
163       return op_def->mutable_attr(i);
164     }
165   }
166   return nullptr;
167 }
168 
FindInputArg(StringPiece name,const OpDef & op_def)169 const OpDef::ArgDef* FindInputArg(StringPiece name, const OpDef& op_def) {
170   for (int i = 0; i < op_def.input_arg_size(); ++i) {
171     if (op_def.input_arg(i).name() == name) {
172       return &op_def.input_arg(i);
173     }
174   }
175   return nullptr;
176 }
177 
FindInputArg(StringPiece name,const ApiDef & api_def)178 const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) {
179   for (int i = 0; i < api_def.in_arg_size(); ++i) {
180     if (api_def.in_arg(i).name() == name) {
181       return &api_def.in_arg(i);
182     }
183   }
184   return nullptr;
185 }
186 
187 #define VALIDATE(EXPR, ...)                                        \
188   do {                                                             \
189     if (!(EXPR)) {                                                 \
190       return errors::InvalidArgument(                              \
191           __VA_ARGS__, "; in OpDef: ", op_def.ShortDebugString()); \
192     }                                                              \
193   } while (false)
194 
ValidateArg(const OpDef::ArgDef & arg,const OpDef & op_def,bool output,std::set<string> * names)195 static Status ValidateArg(const OpDef::ArgDef& arg, const OpDef& op_def,
196                           bool output, std::set<string>* names) {
197   const string suffix = strings::StrCat(
198       output ? " for output '" : " for input '", arg.name(), "'");
199   VALIDATE(gtl::InsertIfNotPresent(names, arg.name()),
200            "Duplicate name: ", arg.name());
201   VALIDATE(HasAttrStyleType(arg), "Missing type", suffix);
202 
203   if (!arg.number_attr().empty()) {
204     const OpDef::AttrDef* attr = FindAttr(arg.number_attr(), op_def);
205     VALIDATE(attr != nullptr, "No attr with name '", arg.number_attr(), "'",
206              suffix);
207     VALIDATE(attr->type() == "int", "Attr '", attr->name(), "' used as length",
208              suffix, " has type ", attr->type(), " != int");
209     VALIDATE(attr->has_minimum(), "Attr '", attr->name(), "' used as length",
210              suffix, " must have minimum");
211     VALIDATE(attr->minimum() >= 0, "Attr '", attr->name(), "' used as length",
212              suffix, " must have minimum >= 0");
213     VALIDATE(arg.type_list_attr().empty(),
214              "Can't have both number_attr and type_list_attr", suffix);
215     VALIDATE((arg.type() != DT_INVALID ? 1 : 0) +
216                      (!arg.type_attr().empty() ? 1 : 0) ==
217                  1,
218              "Exactly one of type, type_attr must be set", suffix);
219   } else {
220     const int num_type_fields = (arg.type() != DT_INVALID ? 1 : 0) +
221                                 (!arg.type_attr().empty() ? 1 : 0) +
222                                 (!arg.type_list_attr().empty() ? 1 : 0);
223     VALIDATE(num_type_fields == 1,
224              "Exactly one of type, type_attr, type_list_attr must be set",
225              suffix);
226   }
227 
228   if (!arg.type_attr().empty()) {
229     const OpDef::AttrDef* attr = FindAttr(arg.type_attr(), op_def);
230     VALIDATE(attr != nullptr, "No attr with name '", arg.type_attr(), "'",
231              suffix);
232     VALIDATE(attr->type() == "type", "Attr '", attr->name(),
233              "' used as type_attr", suffix, " has type ", attr->type(),
234              " != type");
235   } else if (!arg.type_list_attr().empty()) {
236     const OpDef::AttrDef* attr = FindAttr(arg.type_list_attr(), op_def);
237     VALIDATE(attr != nullptr, "No attr with name '", arg.type_list_attr(), "'",
238              suffix);
239     VALIDATE(attr->type() == "list(type)", "Attr '", attr->name(),
240              "' used as type_list_attr", suffix, " has type ", attr->type(),
241              " != list(type)");
242   } else {
243     // All argument types should be non-reference types at this point.
244     // ArgDef.is_ref is set to true for reference arguments.
245     VALIDATE(!IsRefType(arg.type()), "Illegal use of ref type '",
246              DataTypeString(arg.type()), "'. Use 'Ref(type)' instead", suffix);
247   }
248 
249   return Status::OK();
250 }
251 
IsValidOpName(StringPiece sp)252 bool IsValidOpName(StringPiece sp) {
253   using ::tensorflow::strings::Scanner;
254 
255   Scanner scanner(sp);
256   scanner.One(Scanner::UPPERLETTER).Any(Scanner::LETTER_DIGIT_UNDERSCORE);
257 
258   while (true) {
259     if (!scanner.GetResult())  // Some error in previous iteration.
260       return false;
261     if (scanner.empty())  // No error, but nothing left, good.
262       return true;
263 
264     // Absorb another name/namespace, starting with a '>'
265     scanner.One(Scanner::RANGLE)
266         .One(Scanner::UPPERLETTER)
267         .Any(Scanner::LETTER_DIGIT_UNDERSCORE);
268   }
269 }
270 
ValidateOpDef(const OpDef & op_def)271 Status ValidateOpDef(const OpDef& op_def) {
272   if (!absl::StartsWith(op_def.name(), "_")) {
273     VALIDATE(IsValidOpName(op_def.name()), "Invalid name: ", op_def.name(),
274              " (Did you use CamelCase?)");
275   }
276 
277   std::set<string> names;  // for detecting duplicate names
278   for (const auto& attr : op_def.attr()) {
279     // Validate name
280     VALIDATE(gtl::InsertIfNotPresent(&names, attr.name()),
281              "Duplicate name: ", attr.name());
282     DataType dt;
283     VALIDATE(!DataTypeFromString(attr.name(), &dt), "Attr can't have name ",
284              attr.name(), " that matches a data type");
285 
286     // Validate type
287     StringPiece type(attr.type());
288     bool is_list = absl::ConsumePrefix(&type, "list(");
289     bool found = false;
290     for (StringPiece valid : {"string", "int", "float", "bool", "type", "shape",
291                               "tensor", "func"}) {
292       if (absl::ConsumePrefix(&type, valid)) {
293         found = true;
294         break;
295       }
296     }
297     VALIDATE(found, "Unrecognized type '", type, "' in attr '", attr.name(),
298              "'");
299     if (is_list) {
300       VALIDATE(absl::ConsumePrefix(&type, ")"),
301                "'list(' is missing ')' in attr ", attr.name(), "'s type ",
302                attr.type());
303     }
304     VALIDATE(type.empty(), "Extra '", type, "' at the end of attr ",
305              attr.name(), "'s type ", attr.type());
306 
307     // Validate minimum
308     if (attr.has_minimum()) {
309       VALIDATE(attr.type() == "int" || is_list, "Attr '", attr.name(),
310                "' has minimum for unsupported type ", attr.type());
311       if (is_list) {
312         VALIDATE(attr.minimum() >= 0, "Attr '", attr.name(),
313                  "' with list type must have a non-negative minimum, not ",
314                  attr.minimum());
315       }
316     } else {
317       VALIDATE(attr.minimum() == 0, "Attr '", attr.name(),
318                "' with has_minimum = false but minimum ", attr.minimum(),
319                " not equal to default of 0");
320     }
321 
322     // Validate allowed_values
323     if (attr.has_allowed_values()) {
324       const string list_type =
325           is_list ? attr.type() : strings::StrCat("list(", attr.type(), ")");
326       TF_RETURN_WITH_CONTEXT_IF_ERROR(
327           AttrValueHasType(attr.allowed_values(), list_type), " for attr '",
328           attr.name(), "' in Op '", op_def.name(), "'");
329     }
330 
331     // Validate default_value (after we have validated the rest of the attr,
332     // so we can use ValidateAttrValue()).
333     if (attr.has_default_value()) {
334       TF_RETURN_WITH_CONTEXT_IF_ERROR(
335           ValidateAttrValue(attr.default_value(), attr), " in Op '",
336           op_def.name(), "'");
337     }
338   }
339 
340   for (const auto& arg : op_def.input_arg()) {
341     TF_RETURN_IF_ERROR(ValidateArg(arg, op_def, false, &names));
342   }
343 
344   for (const auto& arg : op_def.output_arg()) {
345     TF_RETURN_IF_ERROR(ValidateArg(arg, op_def, true, &names));
346   }
347 
348   return Status::OK();
349 }
350 
351 #undef VALIDATE
352 
CheckOpDeprecation(const OpDef & op_def,int graph_def_version)353 Status CheckOpDeprecation(const OpDef& op_def, int graph_def_version) {
354   if (op_def.has_deprecation()) {
355     const OpDeprecation& dep = op_def.deprecation();
356     if (graph_def_version >= dep.version()) {
357       return errors::Unimplemented(
358           "Op ", op_def.name(), " is not available in GraphDef version ",
359           graph_def_version, ". It has been removed in version ", dep.version(),
360           ". ", dep.explanation(), ".");
361     } else {
362       // Warn only once for each op name, and do it in a threadsafe manner.
363       static mutex mu(LINKER_INITIALIZED);
364       static std::unordered_set<string> warned;
365       bool warn;
366       {
367         mutex_lock lock(mu);
368         warn = warned.insert(op_def.name()).second;
369       }
370       if (warn) {
371         LOG(WARNING) << "Op " << op_def.name() << " is deprecated."
372                      << " It will cease to work in GraphDef version "
373                      << dep.version() << ". " << dep.explanation() << ".";
374       }
375     }
376   }
377   return Status::OK();
378 }
379 
380 namespace {
381 
SummarizeArgs(const protobuf::RepeatedPtrField<OpDef::ArgDef> & args)382 string SummarizeArgs(const protobuf::RepeatedPtrField<OpDef::ArgDef>& args) {
383   string ret;
384   for (const OpDef::ArgDef& arg : args) {
385     if (!ret.empty()) strings::StrAppend(&ret, ", ");
386     strings::StrAppend(&ret, arg.name(), ":");
387     if (arg.is_ref()) strings::StrAppend(&ret, "Ref(");
388     if (!arg.number_attr().empty()) {
389       strings::StrAppend(&ret, arg.number_attr(), "*");
390     }
391     if (arg.type() != DT_INVALID) {
392       strings::StrAppend(&ret, DataTypeString(arg.type()));
393     } else {
394       strings::StrAppend(&ret, arg.type_attr());
395     }
396     if (arg.is_ref()) strings::StrAppend(&ret, ")");
397   }
398   return ret;
399 }
400 
401 }  // namespace
402 
SummarizeOpDef(const OpDef & op_def)403 string SummarizeOpDef(const OpDef& op_def) {
404   string ret = strings::StrCat("Op<name=", op_def.name());
405   strings::StrAppend(&ret, "; signature=", SummarizeArgs(op_def.input_arg()),
406                      " -> ", SummarizeArgs(op_def.output_arg()));
407   for (int i = 0; i < op_def.attr_size(); ++i) {
408     strings::StrAppend(&ret, "; attr=", op_def.attr(i).name(), ":",
409                        op_def.attr(i).type());
410     if (op_def.attr(i).has_default_value()) {
411       strings::StrAppend(&ret, ",default=",
412                          SummarizeAttrValue(op_def.attr(i).default_value()));
413     }
414     if (op_def.attr(i).has_minimum()) {
415       strings::StrAppend(&ret, ",min=", op_def.attr(i).minimum());
416     }
417     if (op_def.attr(i).has_allowed_values()) {
418       strings::StrAppend(&ret, ",allowed=",
419                          SummarizeAttrValue(op_def.attr(i).allowed_values()));
420     }
421   }
422   if (op_def.is_commutative()) {
423     strings::StrAppend(&ret, "; is_commutative=true");
424   }
425   if (op_def.is_aggregate()) {
426     strings::StrAppend(&ret, "; is_aggregate=true");
427   }
428   if (op_def.is_stateful()) {
429     strings::StrAppend(&ret, "; is_stateful=true");
430   }
431   if (op_def.allows_uninitialized_input()) {
432     strings::StrAppend(&ret, "; allows_uninitialized_input=true");
433   }
434   strings::StrAppend(&ret, ">");
435   return ret;
436 }
437 
438 namespace {
439 
// Returns true if every element of `sub` is contained in `super`, compared
// with operator==. Duplicates are not counted, and an empty `sub` is always
// a subset. Runs in O(|sub| * |super|).
template <class T>
bool IsSubsetOf(const T& sub, const T& super) {
  for (const auto& needle : sub) {
    auto it = super.begin();
    while (it != super.end() && !(needle == *it)) ++it;
    if (it == super.end()) return false;  // `needle` missing from `super`
  }
  return true;
}
455 
MoreRestrictive(const OpDef::AttrDef & old_attr,const OpDef::AttrDef & new_attr)456 bool MoreRestrictive(const OpDef::AttrDef& old_attr,
457                      const OpDef::AttrDef& new_attr) {
458   // Anything -> no restriction : not more restrictive.
459   if (!new_attr.has_allowed_values()) return false;
460   // No restriction -> restriction : more restrictive.
461   if (!old_attr.has_allowed_values()) return true;
462   // If anything that was previously allowed is no longer allowed:
463   // more restrictive.
464   if (!IsSubsetOf(old_attr.allowed_values().list().type(),
465                   new_attr.allowed_values().list().type())) {
466     return true;
467   }
468   if (!IsSubsetOf(old_attr.allowed_values().list().s(),
469                   new_attr.allowed_values().list().s())) {
470     return true;
471   }
472   return false;
473 }
474 
AllowedStr(const OpDef::AttrDef & attr)475 string AllowedStr(const OpDef::AttrDef& attr) {
476   if (!attr.has_allowed_values()) return "no restriction";
477   return SummarizeAttrValue(attr.allowed_values());
478 }
479 
DefaultAttrStr(const OpDef::AttrDef & attr)480 string DefaultAttrStr(const OpDef::AttrDef& attr) {
481   if (!attr.has_default_value()) return "no default";
482   return SummarizeAttrValue(attr.default_value());
483 }
484 
HigherMinimum(const OpDef::AttrDef & old_attr,const OpDef::AttrDef & new_attr)485 bool HigherMinimum(const OpDef::AttrDef& old_attr,
486                    const OpDef::AttrDef& new_attr) {
487   // Anything -> no restriction : not more restrictive.
488   if (!new_attr.has_minimum()) return false;
489   // No restriction -> restriction : more restrictive.
490   if (!old_attr.has_minimum()) return true;
491   // If anything that was previously allowed is no longer allowed:
492   // more restrictive.
493   return new_attr.minimum() > old_attr.minimum();
494 }
495 
MinStr(const OpDef::AttrDef & attr)496 string MinStr(const OpDef::AttrDef& attr) {
497   if (!attr.has_minimum()) return "no minimum";
498   return strings::StrCat(attr.minimum());
499 }
500 
501 typedef std::unordered_map<string, const OpDef::AttrDef*> AttrMap;
FillAttrMap(const OpDef & op_def,AttrMap * attr_map)502 void FillAttrMap(const OpDef& op_def, AttrMap* attr_map) {
503   for (const auto& attr : op_def.attr()) {
504     (*attr_map)[attr.name()] = &attr;
505   }
506 }
507 
508 // Add a comma to *s every call but the first (*add_comma should be
509 // initialized to false).
AddComma(string * s,bool * add_comma)510 void AddComma(string* s, bool* add_comma) {
511   if (*add_comma) {
512     strings::StrAppend(s, ", ");
513   } else {
514     *add_comma = true;
515   }
516 }
517 
518 // Will add the `name` from arg if name is true.
AddName(string * s,bool name,const OpDef::ArgDef & arg)519 void AddName(string* s, bool name, const OpDef::ArgDef& arg) {
520   if (name) {
521     strings::StrAppend(s, arg.name(), ":");
522   }
523 }
524 
525 // Compute a signature for either inputs or outputs that will be the
526 // same for both the old and new OpDef if they are compatible.  We
527 // assume that new_attrs is a superset of old_attrs, and that any attr
528 // in the difference has a default.  Our strategy is to make a list of
529 // types, where the types are things like:
530 // * "int32", "float", etc.,
531 // * "T" for some attr "T" in old_attrs, or
532 // * "N * type" for "N" either some attr in old_attrs.
533 //
534 // We get the types by either using the attrs in args if they are in
535 // old_attrs, or substituting the default value from new_attrs.
ComputeArgSignature(const protobuf::RepeatedPtrField<OpDef::ArgDef> & args,const AttrMap & old_attrs,const AttrMap & new_attrs,std::vector<bool> * ref,bool names)536 string ComputeArgSignature(
537     const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
538     const AttrMap& old_attrs, const AttrMap& new_attrs, std::vector<bool>* ref,
539     bool names) {
540   string s;
541   bool add_comma = false;
542   for (const OpDef::ArgDef& arg : args) {
543     if (!arg.type_list_attr().empty()) {
544       const OpDef::AttrDef* old_attr =
545           gtl::FindPtrOrNull(old_attrs, arg.type_list_attr());
546       if (old_attr) {
547         // Both old and new have the list(type) attr, so can use it directly.
548         AddComma(&s, &add_comma);
549         AddName(&s, names, arg);
550         strings::StrAppend(&s, arg.type_list_attr());
551         ref->push_back(arg.is_ref());
552       } else {
553         // Missing the list(type) attr in the old, so use the default
554         // value for the attr from new instead.
555         const OpDef::AttrDef* new_attr =
556             gtl::FindPtrOrNull(new_attrs, arg.type_list_attr());
557         const auto& type_list = new_attr->default_value().list().type();
558         if (type_list.empty()) continue;
559         for (int i = 0; i < type_list.size(); ++i) {
560           AddComma(&s, &add_comma);
561           AddName(&s, names, arg);
562           strings::StrAppend(
563               &s, DataTypeString(static_cast<DataType>(type_list.Get(i))));
564           ref->push_back(arg.is_ref());
565         }
566       }
567     } else {
568       int num = 1;  // How many input/outputs does this represent?
569       string type;  // What is the type of this arg?
570       AddName(&type, names, arg);
571       if (!arg.number_attr().empty()) {
572         // N * type case.
573         const OpDef::AttrDef* old_attr =
574             gtl::FindPtrOrNull(old_attrs, arg.number_attr());
575         if (old_attr) {
576           // Both old and new have the number attr, so can use it directly.
577           strings::StrAppend(&type, arg.number_attr(), " * ");
578         } else {
579           // Missing the number attr in the old, so use the default
580           // value for the attr from new instead.
581           const OpDef::AttrDef* new_attr =
582               gtl::FindPtrOrNull(new_attrs, arg.number_attr());
583           num = new_attr->default_value().i();
584         }
585       }
586 
587       if (arg.type() != DT_INVALID) {
588         // int32, float, etc. case
589         strings::StrAppend(&type, DataTypeString(arg.type()));
590       } else {
591         const OpDef::AttrDef* old_attr =
592             gtl::FindPtrOrNull(old_attrs, arg.type_attr());
593         if (old_attr) {
594           // Both old and new have the type attr, so can use it directly.
595           strings::StrAppend(&type, arg.type_attr());
596         } else {
597           // Missing the type attr in the old, so use the default
598           // value for the attr from new instead.
599           const OpDef::AttrDef* new_attr =
600               gtl::FindPtrOrNull(new_attrs, arg.type_attr());
601           strings::StrAppend(&type,
602                              DataTypeString(new_attr->default_value().type()));
603         }
604       }
605 
606       // Record `num` * `type` in the signature.
607       for (int i = 0; i < num; ++i) {
608         AddComma(&s, &add_comma);
609         strings::StrAppend(&s, type);
610         ref->push_back(arg.is_ref());
611       }
612     }
613   }
614 
615   return s;
616 }
617 
618 }  // namespace
619 
OpDefCompatible(const OpDef & old_op,const OpDef & new_op)620 Status OpDefCompatible(const OpDef& old_op, const OpDef& new_op) {
621 #define VALIDATE(CONDITION, ...)                                            \
622   if (!(CONDITION)) {                                                       \
623     return errors::InvalidArgument("Incompatible Op change: ", __VA_ARGS__, \
624                                    "; old: ", SummarizeOpDef(old_op),       \
625                                    "; new: ", SummarizeOpDef(new_op));      \
626   }
627 
628   VALIDATE(old_op.name() == new_op.name(), "Name mismatch");
629 
630   AttrMap new_attrs, old_attrs;
631   FillAttrMap(old_op, &old_attrs);
632   FillAttrMap(new_op, &new_attrs);
633   for (const auto& old_attr : old_op.attr()) {
634     const OpDef::AttrDef* new_attr =
635         gtl::FindPtrOrNull(new_attrs, old_attr.name());
636     VALIDATE(new_attr != nullptr, "Attr '", old_attr.name(), "' removed");
637     VALIDATE(old_attr.type() == new_attr->type(), "Attr '", old_attr.name(),
638              "' changed type '", old_attr.type(), "' -> '", new_attr->type(),
639              "'");
640     VALIDATE(!MoreRestrictive(old_attr, *new_attr), "Attr '", old_attr.name(),
641              "' has a stricter set of allowed values; from ",
642              AllowedStr(old_attr), " to ", AllowedStr(*new_attr));
643     VALIDATE(!HigherMinimum(old_attr, *new_attr), "Attr '", old_attr.name(),
644              "' has a higher minimum; from ", MinStr(old_attr), " to ",
645              MinStr(*new_attr));
646   }
647 
648   for (const auto& new_attr : new_op.attr()) {
649     const OpDef::AttrDef* old_attr =
650         gtl::FindPtrOrNull(old_attrs, new_attr.name());
651     VALIDATE(old_attr != nullptr || new_attr.has_default_value(), "Attr '",
652              new_attr.name(), "' added without default");
653   }
654 
655   std::vector<bool> old_in_ref, new_in_ref, old_out_ref, new_out_ref;
656   const string old_in_sig = ComputeArgSignature(
657       old_op.input_arg(), old_attrs, new_attrs, &old_in_ref, false /* names */);
658   const string new_in_sig = ComputeArgSignature(
659       new_op.input_arg(), old_attrs, new_attrs, &new_in_ref, false /* names */);
660   VALIDATE(old_in_sig == new_in_sig, "Input signature mismatch '", old_in_sig,
661            "' vs. '", new_in_sig, "'");
662   VALIDATE(old_in_ref.size() == new_in_ref.size(),  // Should not happen
663            "Unexpected change in input ref lists.");
664   for (int i = 0, end = old_in_ref.size(); i < end; ++i) {
665     // Allowed to remove "ref" from an input (or leave it unchanged).
666     VALIDATE(old_in_ref[i] || !new_in_ref[i], "Input ", i,
667              " changed from non-ref to ref");
668   }
669 
670   const string old_out_sig =
671       ComputeArgSignature(old_op.output_arg(), old_attrs, new_attrs,
672                           &old_out_ref, true /* names */);
673   const string new_out_sig =
674       ComputeArgSignature(new_op.output_arg(), old_attrs, new_attrs,
675                           &new_out_ref, true /* names */);
676   VALIDATE(old_out_sig == new_out_sig, "Output signature mismatch '",
677            old_out_sig, "' vs. '", new_out_sig, "'");
678   VALIDATE(old_out_ref.size() == new_out_ref.size(),  // Should not happen
679            "Unexpected change in output ref lists");
680   for (int i = 0, end = old_out_ref.size(); i < end; ++i) {
681     // Allowed to add "ref" to an output (or leave it unchanged).
682     VALIDATE(!old_out_ref[i] || new_out_ref[i], "Output ", i,
683              " changed from ref to non-ref");
684   }
685 
686   return Status::OK();
687 }
688 
OpDefAddedDefaultsUnchanged(const OpDef & old_op,const OpDef & penultimate_op,const OpDef & new_op)689 Status OpDefAddedDefaultsUnchanged(const OpDef& old_op,
690                                    const OpDef& penultimate_op,
691                                    const OpDef& new_op) {
692   AttrMap new_attrs, old_attrs;
693   FillAttrMap(old_op, &old_attrs);
694   FillAttrMap(new_op, &new_attrs);
695 
696   for (const auto& penultimate_attr : penultimate_op.attr()) {
697     const OpDef::AttrDef* old_attr =
698         gtl::FindPtrOrNull(old_attrs, penultimate_attr.name());
699     if (old_attr != nullptr) continue;  // attr wasn't added
700     const OpDef::AttrDef* new_attr =
701         gtl::FindPtrOrNull(new_attrs, penultimate_attr.name());
702 
703     // These shouldn't happen if the op passed OpDefCompatible().
704     if (new_attr == nullptr) {
705       return errors::InvalidArgument("Missing attr '", penultimate_attr.name(),
706                                      "' in op: ", SummarizeOpDef(new_op));
707     }
708     if (!penultimate_attr.has_default_value() ||
709         !new_attr->has_default_value()) {
710       return errors::InvalidArgument("Missing default for attr '",
711                                      penultimate_attr.name(),
712                                      "' in op: ", SummarizeOpDef(new_op));
713     }
714 
715     // Actually test that the attr's default value hasn't changed.
716     if (!AreAttrValuesEqual(penultimate_attr.default_value(),
717                             new_attr->default_value())) {
718       return errors::InvalidArgument(
719           "Can't change default value for attr '", penultimate_attr.name(),
720           "' from ", SummarizeAttrValue(penultimate_attr.default_value()),
721           " in op: ", SummarizeOpDef(new_op));
722     }
723   }
724 
725   return Status::OK();
726 }
727 
OpDefAttrDefaultsUnchanged(const OpDef & old_op,const OpDef & new_op)728 Status OpDefAttrDefaultsUnchanged(const OpDef& old_op, const OpDef& new_op) {
729   AttrMap new_attrs, old_attrs;
730   FillAttrMap(old_op, &old_attrs);
731   FillAttrMap(new_op, &new_attrs);
732 
733   for (const auto& old_attr : old_op.attr()) {
734     const OpDef::AttrDef* new_attr =
735         gtl::FindPtrOrNull(new_attrs, old_attr.name());
736     if (new_attr == nullptr) continue;
737     if (new_attr->has_default_value() && !old_attr.has_default_value()) {
738       continue;  // Adding new default values is safe.
739     }
740     if (old_attr.has_default_value() && !new_attr->has_default_value()) {
741       return errors::InvalidArgument(
742           "Attr '", old_attr.name(), "' has removed it's default; ", "from ",
743           DefaultAttrStr(old_attr), " to ", DefaultAttrStr(*new_attr));
744     }
745     if (old_attr.has_default_value() &&
746         !AreAttrValuesEqual(old_attr.default_value(),
747                             new_attr->default_value())) {
748       return errors::InvalidArgument(
749           "Attr '", old_attr.name(), "' has changed it's default value; ",
750           "from ", DefaultAttrStr(old_attr), " to ", DefaultAttrStr(*new_attr));
751     }
752   }
753 
754   return Status::OK();
755 }
756 
RemoveNonDeprecationDescriptionsFromOpDef(OpDef * op_def)757 void RemoveNonDeprecationDescriptionsFromOpDef(OpDef* op_def) {
758   for (int i = 0; i < op_def->input_arg_size(); ++i) {
759     op_def->mutable_input_arg(i)->clear_description();
760   }
761   for (int i = 0; i < op_def->output_arg_size(); ++i) {
762     op_def->mutable_output_arg(i)->clear_description();
763   }
764   for (int i = 0; i < op_def->attr_size(); ++i) {
765     op_def->mutable_attr(i)->clear_description();
766   }
767   op_def->clear_summary();
768   op_def->clear_description();
769 }
770 
RemoveDescriptionsFromOpDef(OpDef * op_def)771 void RemoveDescriptionsFromOpDef(OpDef* op_def) {
772   RemoveNonDeprecationDescriptionsFromOpDef(op_def);
773   if (op_def->has_deprecation()) {
774     op_def->mutable_deprecation()->clear_explanation();
775   }
776 }
777 
RemoveDescriptionsFromOpList(OpList * op_list)778 void RemoveDescriptionsFromOpList(OpList* op_list) {
779   for (int i = 0; i < op_list->op_size(); ++i) {
780     OpDef* op_def = op_list->mutable_op(i);
781     RemoveDescriptionsFromOpDef(op_def);
782   }
783 }
784 
AttrDefEqual(const OpDef::AttrDef & a1,const OpDef::AttrDef & a2)785 bool AttrDefEqual(const OpDef::AttrDef& a1, const OpDef::AttrDef& a2) {
786   if (std::is_base_of<protobuf::Message, OpDef::AttrDef>()) {
787     DCHECK_EQ(7, reinterpret_cast<const protobuf::Message*>(&a1)
788                      ->GetDescriptor()
789                      ->field_count())
790         << "Please modify these equality and hash functions to reflect the "
791            "changes to the AttrDef protobuf";
792   }
793 
794   if (a1.name() != a2.name()) return false;
795   if (a1.type() != a2.type()) return false;
796   if (a1.description() != a2.description()) return false;
797   if (a1.has_minimum() != a2.has_minimum()) return false;
798   if (a1.has_minimum() && a1.minimum() != a2.minimum()) return false;
799   if (!AreAttrValuesEqual(a1.default_value(), a2.default_value())) return false;
800   if (!AreAttrValuesEqual(a1.allowed_values(), a2.allowed_values()))
801     return false;
802   return true;
803 }
804 
AttrDefHash(const OpDef::AttrDef & a)805 uint64 AttrDefHash(const OpDef::AttrDef& a) {
806   uint64 h = Hash64(a.name());
807   h = Hash64(a.type().data(), a.type().size(), h);
808   h = Hash64Combine(AttrValueHash(a.default_value()), h);
809   h = Hash64(a.description().data(), a.description().size(), h);
810   h = Hash64Combine(static_cast<uint64>(a.has_minimum()), h);
811   h = Hash64Combine(static_cast<uint64>(a.minimum()), h);
812   h = Hash64Combine(AttrValueHash(a.allowed_values()), h);
813   return h;
814 }
815 
RepeatedAttrDefEqual(const protobuf::RepeatedPtrField<OpDef::AttrDef> & a1,const protobuf::RepeatedPtrField<OpDef::AttrDef> & a2)816 bool RepeatedAttrDefEqual(
817     const protobuf::RepeatedPtrField<OpDef::AttrDef>& a1,
818     const protobuf::RepeatedPtrField<OpDef::AttrDef>& a2) {
819   std::unordered_map<string, const OpDef::AttrDef*> a1_set;
820   for (const OpDef::AttrDef& def : a1) {
821     DCHECK(a1_set.find(def.name()) == a1_set.end())
822         << "AttrDef names must be unique, but '" << def.name()
823         << "' appears more than once";
824     a1_set[def.name()] = &def;
825   }
826   for (const OpDef::AttrDef& def : a2) {
827     auto iter = a1_set.find(def.name());
828     if (iter == a1_set.end()) return false;
829     if (!AttrDefEqual(*iter->second, def)) return false;
830     a1_set.erase(iter);
831   }
832   if (!a1_set.empty()) return false;
833   return true;
834 }
835 
RepeatedAttrDefHash(const protobuf::RepeatedPtrField<OpDef::AttrDef> & a)836 uint64 RepeatedAttrDefHash(
837     const protobuf::RepeatedPtrField<OpDef::AttrDef>& a) {
838   // Insert AttrDefs into map to deterministically sort by name
839   std::map<string, const OpDef::AttrDef*> a_set;
840   for (const OpDef::AttrDef& def : a) {
841     a_set[def.name()] = &def;
842   }
843   // Iterate and combines hashes of keys and values
844   uint64 h = 0xDECAFCAFFE;
845   for (const auto& pair : a_set) {
846     h = Hash64(pair.first.data(), pair.first.size(), h);
847     h = Hash64Combine(AttrDefHash(*pair.second), h);
848   }
849   return h;
850 }
851 
OpDefEqual(const OpDef & o1,const OpDef & o2)852 bool OpDefEqual(const OpDef& o1, const OpDef& o2) {
853   // attr order doesn't matter.
854   // Compare it separately here instead of serializing below.
855   if (!RepeatedAttrDefEqual(o1.attr(), o2.attr())) return false;
856 
857   // `control_output` order doesn't matter.
858   std::set<string> control_output1(o1.control_output().begin(),
859                                    o1.control_output().end());
860   std::set<string> control_output2(o2.control_output().begin(),
861                                    o2.control_output().end());
862   if (control_output1 != control_output2) return false;
863 
864   // Clear `attr` and `control_output` fields, serialize, and compare serialized
865   // strings.
866   OpDef o1_copy = o1;
867   OpDef o2_copy = o2;
868   o1_copy.clear_attr();
869   o1_copy.clear_control_output();
870   o2_copy.clear_attr();
871   o2_copy.clear_control_output();
872 
873   return AreSerializedProtosEqual(o1_copy, o2_copy);
874 }
875 
OpDefHash(const OpDef & o)876 uint64 OpDefHash(const OpDef& o) {
877   uint64 h = RepeatedAttrDefHash(o.attr());
878 
879   // Compute deterministic order-independent control outputs hash.
880   std::set<string> control_output(o.control_output().begin(),
881                                   o.control_output().end());
882   for (const auto& co : control_output) h = Hash64Combine(h, Hash64(co));
883 
884   OpDef o_copy = o;
885   o_copy.clear_attr();
886   o_copy.clear_control_output();
887   return DeterministicProtoHash64(o_copy, h);
888 }
889 
890 }  // namespace tensorflow
891