• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/op_def_util.h"
17 
18 #include <set>
19 #include <unordered_map>
20 #include <unordered_set>
21 #include "tensorflow/core/framework/attr_value.pb.h"
22 #include "tensorflow/core/framework/attr_value_util.h"
23 #include "tensorflow/core/framework/op_def.pb_text.h"
24 #include "tensorflow/core/framework/types.h"
25 #include "tensorflow/core/lib/core/errors.h"
26 #include "tensorflow/core/lib/core/stringpiece.h"
27 #include "tensorflow/core/lib/gtl/map_util.h"
28 #include "tensorflow/core/lib/hash/hash.h"
29 #include "tensorflow/core/lib/strings/proto_serialization.h"
30 #include "tensorflow/core/lib/strings/scanner.h"
31 #include "tensorflow/core/lib/strings/str_util.h"
32 #include "tensorflow/core/lib/strings/strcat.h"
33 #include "tensorflow/core/platform/mutex.h"
34 #include "tensorflow/core/platform/protobuf.h"
35 #include "tensorflow/core/platform/types.h"
36 
37 namespace tensorflow {
38 namespace {  // ------ Helper functions ------
39 
HasAttrStyleType(const OpDef::ArgDef & arg)40 bool HasAttrStyleType(const OpDef::ArgDef& arg) {
41   return arg.type() != DT_INVALID || !arg.type_attr().empty() ||
42          !arg.type_list_attr().empty();
43 }
44 
AllowedTypeValue(DataType dt,const OpDef::AttrDef & attr)45 Status AllowedTypeValue(DataType dt, const OpDef::AttrDef& attr) {
46   const AttrValue& allowed_values(attr.allowed_values());
47   for (auto allowed : allowed_values.list().type()) {
48     if (dt == allowed) {
49       return Status::OK();
50     }
51   }
52   string allowed_str;
53   for (int i = 0; i < allowed_values.list().type_size(); ++i) {
54     if (!allowed_str.empty()) {
55       strings::StrAppend(&allowed_str, ", ");
56     }
57     strings::StrAppend(&allowed_str,
58                        DataTypeString(allowed_values.list().type(i)));
59   }
60   return errors::InvalidArgument(
61       "Value for attr '", attr.name(), "' of ", DataTypeString(dt),
62       " is not in the list of allowed values: ", allowed_str);
63 }
64 
AllowedStringValue(const string & str,const OpDef::AttrDef & attr)65 Status AllowedStringValue(const string& str, const OpDef::AttrDef& attr) {
66   const AttrValue& allowed_values(attr.allowed_values());
67   for (const auto& allowed : allowed_values.list().s()) {
68     if (str == allowed) {
69       return Status::OK();
70     }
71   }
72   string allowed_str;
73   for (const string& allowed : allowed_values.list().s()) {
74     if (!allowed_str.empty()) {
75       strings::StrAppend(&allowed_str, ", ");
76     }
77     strings::StrAppend(&allowed_str, "\"", allowed, "\"");
78   }
79   return errors::InvalidArgument(
80       "Value for attr '", attr.name(), "' of \"", str,
81       "\" is not in the list of allowed values: ", allowed_str);
82 }
83 
84 }  // namespace
85 
86 // Requires: attr has already been validated.
ValidateAttrValue(const AttrValue & attr_value,const OpDef::AttrDef & attr)87 Status ValidateAttrValue(const AttrValue& attr_value,
88                          const OpDef::AttrDef& attr) {
89   // Is it a valid value?
90   TF_RETURN_WITH_CONTEXT_IF_ERROR(AttrValueHasType(attr_value, attr.type()),
91                                   " for attr '", attr.name(), "'");
92 
93   // Does the value satisfy the minimum constraint in the AttrDef?
94   if (attr.has_minimum()) {
95     if (attr.type() == "int") {
96       if (attr_value.i() < attr.minimum()) {
97         return errors::InvalidArgument(
98             "Value for attr '", attr.name(), "' of ", attr_value.i(),
99             " must be at least minimum ", attr.minimum());
100       }
101     } else {
102       int length = -1;
103       if (attr.type() == "list(string)") {
104         length = attr_value.list().s_size();
105       } else if (attr.type() == "list(int)") {
106         length = attr_value.list().i_size();
107       } else if (attr.type() == "list(float)") {
108         length = attr_value.list().f_size();
109       } else if (attr.type() == "list(bool)") {
110         length = attr_value.list().b_size();
111       } else if (attr.type() == "list(type)") {
112         length = attr_value.list().type_size();
113       } else if (attr.type() == "list(shape)") {
114         length = attr_value.list().shape_size();
115       } else if (attr.type() == "list(tensor)") {
116         length = attr_value.list().tensor_size();
117       } else if (attr.type() == "list(func)") {
118         length = attr_value.list().func_size();
119       }
120       if (length < attr.minimum()) {
121         return errors::InvalidArgument(
122             "Length for attr '", attr.name(), "' of ", length,
123             " must be at least minimum ", attr.minimum());
124       }
125     }
126   }
127 
128   // Does the value satisfy the allowed_value constraint in the AttrDef?
129   if (attr.has_allowed_values()) {
130     if (attr.type() == "type") {
131       TF_RETURN_IF_ERROR(AllowedTypeValue(attr_value.type(), attr));
132     } else if (attr.type() == "list(type)") {
133       for (int dt : attr_value.list().type()) {
134         TF_RETURN_IF_ERROR(AllowedTypeValue(static_cast<DataType>(dt), attr));
135       }
136     } else if (attr.type() == "string") {
137       TF_RETURN_IF_ERROR(AllowedStringValue(attr_value.s(), attr));
138     } else if (attr.type() == "list(string)") {
139       for (const string& str : attr_value.list().s()) {
140         TF_RETURN_IF_ERROR(AllowedStringValue(str, attr));
141       }
142     } else {
143       return errors::Unimplemented(
144           "Support for allowed_values not implemented for type ", attr.type());
145     }
146   }
147   return Status::OK();
148 }
149 
FindAttr(StringPiece name,const OpDef & op_def)150 const OpDef::AttrDef* FindAttr(StringPiece name, const OpDef& op_def) {
151   for (int i = 0; i < op_def.attr_size(); ++i) {
152     if (op_def.attr(i).name() == name) {
153       return &op_def.attr(i);
154     }
155   }
156   return nullptr;
157 }
158 
FindAttrMutable(StringPiece name,OpDef * op_def)159 OpDef::AttrDef* FindAttrMutable(StringPiece name, OpDef* op_def) {
160   for (int i = 0; i < op_def->attr_size(); ++i) {
161     if (op_def->attr(i).name() == name) {
162       return op_def->mutable_attr(i);
163     }
164   }
165   return nullptr;
166 }
167 
FindInputArg(StringPiece name,const OpDef & op_def)168 const OpDef::ArgDef* FindInputArg(StringPiece name, const OpDef& op_def) {
169   for (int i = 0; i < op_def.input_arg_size(); ++i) {
170     if (op_def.input_arg(i).name() == name) {
171       return &op_def.input_arg(i);
172     }
173   }
174   return nullptr;
175 }
176 
FindInputArg(StringPiece name,const ApiDef & api_def)177 const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) {
178   for (int i = 0; i < api_def.in_arg_size(); ++i) {
179     if (api_def.in_arg(i).name() == name) {
180       return &api_def.in_arg(i);
181     }
182   }
183   return nullptr;
184 }
185 
// If COND is false, returns InvalidArgument built from the remaining macro
// arguments followed by "; in OpDef: " and the short debug string of
// `op_def`. Expands to a single statement; requires a variable named `op_def`
// to be in scope at the expansion site.
#define VALIDATE(COND, ...)                                            \
  do {                                                                 \
    if (!(COND)) {                                                     \
      return errors::InvalidArgument(                                  \
          __VA_ARGS__, "; in OpDef: ", ProtoShortDebugString(op_def)); \
    }                                                                  \
  } while (false)
193 
ValidateArg(const OpDef::ArgDef & arg,const OpDef & op_def,bool output,std::set<string> * names)194 static Status ValidateArg(const OpDef::ArgDef& arg, const OpDef& op_def,
195                           bool output, std::set<string>* names) {
196   const string suffix = strings::StrCat(
197       output ? " for output '" : " for input '", arg.name(), "'");
198   VALIDATE(gtl::InsertIfNotPresent(names, arg.name()),
199            "Duplicate name: ", arg.name());
200   VALIDATE(HasAttrStyleType(arg), "Missing type", suffix);
201 
202   if (!arg.number_attr().empty()) {
203     const OpDef::AttrDef* attr = FindAttr(arg.number_attr(), op_def);
204     VALIDATE(attr != nullptr, "No attr with name '", arg.number_attr(), "'",
205              suffix);
206     VALIDATE(attr->type() == "int", "Attr '", attr->name(), "' used as length",
207              suffix, " has type ", attr->type(), " != int");
208     VALIDATE(attr->has_minimum(), "Attr '", attr->name(), "' used as length",
209              suffix, " must have minimum");
210     VALIDATE(attr->minimum() >= 0, "Attr '", attr->name(), "' used as length",
211              suffix, " must have minimum >= 0");
212     VALIDATE(arg.type_list_attr().empty(),
213              "Can't have both number_attr and type_list_attr", suffix);
214     VALIDATE((arg.type() != DT_INVALID ? 1 : 0) +
215                      (!arg.type_attr().empty() ? 1 : 0) ==
216                  1,
217              "Exactly one of type, type_attr must be set", suffix);
218   } else {
219     const int num_type_fields = (arg.type() != DT_INVALID ? 1 : 0) +
220                                 (!arg.type_attr().empty() ? 1 : 0) +
221                                 (!arg.type_list_attr().empty() ? 1 : 0);
222     VALIDATE(num_type_fields == 1,
223              "Exactly one of type, type_attr, type_list_attr must be set",
224              suffix);
225   }
226 
227   if (!arg.type_attr().empty()) {
228     const OpDef::AttrDef* attr = FindAttr(arg.type_attr(), op_def);
229     VALIDATE(attr != nullptr, "No attr with name '", arg.type_attr(), "'",
230              suffix);
231     VALIDATE(attr->type() == "type", "Attr '", attr->name(),
232              "' used as type_attr", suffix, " has type ", attr->type(),
233              " != type");
234   } else if (!arg.type_list_attr().empty()) {
235     const OpDef::AttrDef* attr = FindAttr(arg.type_list_attr(), op_def);
236     VALIDATE(attr != nullptr, "No attr with name '", arg.type_list_attr(), "'",
237              suffix);
238     VALIDATE(attr->type() == "list(type)", "Attr '", attr->name(),
239              "' used as type_list_attr", suffix, " has type ", attr->type(),
240              " != list(type)");
241   } else {
242     // All argument types should be non-reference types at this point.
243     // ArgDef.is_ref is set to true for reference arguments.
244     VALIDATE(!IsRefType(arg.type()), "Illegal use of ref type '",
245              DataTypeString(arg.type()), "'. Use 'Ref(type)' instead", suffix);
246   }
247 
248   return Status::OK();
249 }
250 
ValidateOpDef(const OpDef & op_def)251 Status ValidateOpDef(const OpDef& op_def) {
252   using ::tensorflow::strings::Scanner;
253 
254   if (!str_util::StartsWith(op_def.name(), "_")) {
255     VALIDATE(Scanner(op_def.name())
256                  .One(Scanner::UPPERLETTER)
257                  .Any(Scanner::LETTER_DIGIT)
258                  .Eos()
259                  .GetResult(),
260              "Invalid name: ", op_def.name(), " (Did you use CamelCase?)");
261   }
262 
263   std::set<string> names;  // for detecting duplicate names
264   for (const auto& attr : op_def.attr()) {
265     // Validate name
266     VALIDATE(gtl::InsertIfNotPresent(&names, attr.name()),
267              "Duplicate name: ", attr.name());
268     DataType dt;
269     VALIDATE(!DataTypeFromString(attr.name(), &dt), "Attr can't have name ",
270              attr.name(), " that matches a data type");
271 
272     // Validate type
273     StringPiece type(attr.type());
274     bool is_list = str_util::ConsumePrefix(&type, "list(");
275     bool found = false;
276     for (StringPiece valid : {"string", "int", "float", "bool", "type", "shape",
277                               "tensor", "func"}) {
278       if (str_util::ConsumePrefix(&type, valid)) {
279         found = true;
280         break;
281       }
282     }
283     VALIDATE(found, "Unrecognized type '", type, "' in attr '", attr.name(),
284              "'");
285     if (is_list) {
286       VALIDATE(str_util::ConsumePrefix(&type, ")"),
287                "'list(' is missing ')' in attr ", attr.name(), "'s type ",
288                attr.type());
289     }
290     VALIDATE(type.empty(), "Extra '", type, "' at the end of attr ",
291              attr.name(), "'s type ", attr.type());
292 
293     // Validate minimum
294     if (attr.has_minimum()) {
295       VALIDATE(attr.type() == "int" || is_list, "Attr '", attr.name(),
296                "' has minimum for unsupported type ", attr.type());
297       if (is_list) {
298         VALIDATE(attr.minimum() >= 0, "Attr '", attr.name(),
299                  "' with list type must have a non-negative minimum, not ",
300                  attr.minimum());
301       }
302     } else {
303       VALIDATE(attr.minimum() == 0, "Attr '", attr.name(),
304                "' with has_minimum = false but minimum ", attr.minimum(),
305                " not equal to default of 0");
306     }
307 
308     // Validate allowed_values
309     if (attr.has_allowed_values()) {
310       const string list_type =
311           is_list ? attr.type() : strings::StrCat("list(", attr.type(), ")");
312       TF_RETURN_WITH_CONTEXT_IF_ERROR(
313           AttrValueHasType(attr.allowed_values(), list_type), " for attr '",
314           attr.name(), "' in Op '", op_def.name(), "'");
315     }
316 
317     // Validate default_value (after we have validated the rest of the attr,
318     // so we can use ValidateAttrValue()).
319     if (attr.has_default_value()) {
320       TF_RETURN_WITH_CONTEXT_IF_ERROR(
321           ValidateAttrValue(attr.default_value(), attr), " in Op '",
322           op_def.name(), "'");
323     }
324   }
325 
326   for (const auto& arg : op_def.input_arg()) {
327     TF_RETURN_IF_ERROR(ValidateArg(arg, op_def, false, &names));
328   }
329 
330   for (const auto& arg : op_def.output_arg()) {
331     TF_RETURN_IF_ERROR(ValidateArg(arg, op_def, true, &names));
332   }
333 
334   return Status::OK();
335 }
336 
337 #undef VALIDATE
338 
CheckOpDeprecation(const OpDef & op_def,int graph_def_version)339 Status CheckOpDeprecation(const OpDef& op_def, int graph_def_version) {
340   if (op_def.has_deprecation()) {
341     const OpDeprecation& dep = op_def.deprecation();
342     if (graph_def_version >= dep.version()) {
343       return errors::Unimplemented(
344           "Op ", op_def.name(), " is not available in GraphDef version ",
345           graph_def_version, ". It has been removed in version ", dep.version(),
346           ". ", dep.explanation(), ".");
347     } else {
348       // Warn only once for each op name, and do it in a threadsafe manner.
349       static mutex mu(LINKER_INITIALIZED);
350       static std::unordered_set<string> warned;
351       bool warn;
352       {
353         mutex_lock lock(mu);
354         warn = warned.insert(op_def.name()).second;
355       }
356       if (warn) {
357         LOG(WARNING) << "Op " << op_def.name() << " is deprecated."
358                      << " It will cease to work in GraphDef version "
359                      << dep.version() << ". " << dep.explanation() << ".";
360       }
361     }
362   }
363   return Status::OK();
364 }
365 
366 namespace {
367 
SummarizeArgs(const protobuf::RepeatedPtrField<OpDef::ArgDef> & args)368 string SummarizeArgs(const protobuf::RepeatedPtrField<OpDef::ArgDef>& args) {
369   string ret;
370   for (const OpDef::ArgDef& arg : args) {
371     if (!ret.empty()) strings::StrAppend(&ret, ", ");
372     strings::StrAppend(&ret, arg.name(), ":");
373     if (arg.is_ref()) strings::StrAppend(&ret, "Ref(");
374     if (!arg.number_attr().empty()) {
375       strings::StrAppend(&ret, arg.number_attr(), "*");
376     }
377     if (arg.type() != DT_INVALID) {
378       strings::StrAppend(&ret, DataTypeString(arg.type()));
379     } else {
380       strings::StrAppend(&ret, arg.type_attr());
381     }
382     if (arg.is_ref()) strings::StrAppend(&ret, ")");
383   }
384   return ret;
385 }
386 
387 }  // namespace
388 
SummarizeOpDef(const OpDef & op_def)389 string SummarizeOpDef(const OpDef& op_def) {
390   string ret = strings::StrCat("Op<name=", op_def.name());
391   strings::StrAppend(&ret, "; signature=", SummarizeArgs(op_def.input_arg()),
392                      " -> ", SummarizeArgs(op_def.output_arg()));
393   for (int i = 0; i < op_def.attr_size(); ++i) {
394     strings::StrAppend(&ret, "; attr=", op_def.attr(i).name(), ":",
395                        op_def.attr(i).type());
396     if (op_def.attr(i).has_default_value()) {
397       strings::StrAppend(&ret, ",default=",
398                          SummarizeAttrValue(op_def.attr(i).default_value()));
399     }
400     if (op_def.attr(i).has_minimum()) {
401       strings::StrAppend(&ret, ",min=", op_def.attr(i).minimum());
402     }
403     if (op_def.attr(i).has_allowed_values()) {
404       strings::StrAppend(&ret, ",allowed=",
405                          SummarizeAttrValue(op_def.attr(i).allowed_values()));
406     }
407   }
408   if (op_def.is_commutative()) {
409     strings::StrAppend(&ret, "; is_commutative=true");
410   }
411   if (op_def.is_aggregate()) {
412     strings::StrAppend(&ret, "; is_aggregate=true");
413   }
414   if (op_def.is_stateful()) {
415     strings::StrAppend(&ret, "; is_stateful=true");
416   }
417   if (op_def.allows_uninitialized_input()) {
418     strings::StrAppend(&ret, "; allows_uninitialized_input=true");
419   }
420   strings::StrAppend(&ret, ">");
421   return ret;
422 }
423 
424 namespace {
425 
// Returns true if every element of `sub` compares equal (==) to some element
// of `super`. Duplicates are not counted; an empty `sub` is always a subset.
// Runs a linear scan of `super` for each element of `sub`.
template <class T>
bool IsSubsetOf(const T& sub, const T& super) {
  for (const auto& needle : sub) {
    auto it = super.begin();
    while (it != super.end() && !(needle == *it)) ++it;
    if (it == super.end()) return false;
  }
  return true;
}
441 
MoreRestrictive(const OpDef::AttrDef & old_attr,const OpDef::AttrDef & new_attr)442 bool MoreRestrictive(const OpDef::AttrDef& old_attr,
443                      const OpDef::AttrDef& new_attr) {
444   // Anything -> no restriction : not more restrictive.
445   if (!new_attr.has_allowed_values()) return false;
446   // No restriction -> restriction : more restrictive.
447   if (!old_attr.has_allowed_values()) return true;
448   // If anything that was previously allowed is no longer allowed:
449   // more restrictive.
450   if (!IsSubsetOf(old_attr.allowed_values().list().type(),
451                   new_attr.allowed_values().list().type())) {
452     return true;
453   }
454   if (!IsSubsetOf(old_attr.allowed_values().list().s(),
455                   new_attr.allowed_values().list().s())) {
456     return true;
457   }
458   return false;
459 }
460 
AllowedStr(const OpDef::AttrDef & attr)461 string AllowedStr(const OpDef::AttrDef& attr) {
462   if (!attr.has_allowed_values()) return "no restriction";
463   return SummarizeAttrValue(attr.allowed_values());
464 }
465 
DefaultAttrStr(const OpDef::AttrDef & attr)466 string DefaultAttrStr(const OpDef::AttrDef& attr) {
467   if (!attr.has_default_value()) return "no default";
468   return SummarizeAttrValue(attr.default_value());
469 }
470 
HigherMinimum(const OpDef::AttrDef & old_attr,const OpDef::AttrDef & new_attr)471 bool HigherMinimum(const OpDef::AttrDef& old_attr,
472                    const OpDef::AttrDef& new_attr) {
473   // Anything -> no restriction : not more restrictive.
474   if (!new_attr.has_minimum()) return false;
475   // No restriction -> restriction : more restrictive.
476   if (!old_attr.has_minimum()) return true;
477   // If anything that was previously allowed is no longer allowed:
478   // more restrictive.
479   return new_attr.minimum() > old_attr.minimum();
480 }
481 
MinStr(const OpDef::AttrDef & attr)482 string MinStr(const OpDef::AttrDef& attr) {
483   if (!attr.has_minimum()) return "no minimum";
484   return strings::StrCat(attr.minimum());
485 }
486 
487 typedef std::unordered_map<string, const OpDef::AttrDef*> AttrMap;
FillAttrMap(const OpDef & op_def,AttrMap * attr_map)488 void FillAttrMap(const OpDef& op_def, AttrMap* attr_map) {
489   for (const auto& attr : op_def.attr()) {
490     (*attr_map)[attr.name()] = &attr;
491   }
492 }
493 
494 // Add a comma to *s every call but the first (*add_comma should be
495 // initialized to false).
AddComma(string * s,bool * add_comma)496 void AddComma(string* s, bool* add_comma) {
497   if (*add_comma) {
498     strings::StrAppend(s, ", ");
499   } else {
500     *add_comma = true;
501   }
502 }
503 
504 // Will add the `name` from arg if name is true.
AddName(string * s,bool name,const OpDef::ArgDef & arg)505 void AddName(string* s, bool name, const OpDef::ArgDef& arg) {
506   if (name) {
507     strings::StrAppend(s, arg.name(), ":");
508   }
509 }
510 
511 // Compute a signature for either inputs or outputs that will be the
512 // same for both the old and new OpDef if they are compatible.  We
513 // assume that new_attrs is a superset of old_attrs, and that any attr
514 // in the difference has a default.  Our strategy is to make a list of
515 // types, where the types are things like:
516 // * "int32", "float", etc.,
517 // * "T" for some attr "T" in old_attrs, or
518 // * "N * type" for "N" either some attr in old_attrs.
519 //
520 // We get the types by either using the attrs in args if they are in
521 // old_attrs, or substituting the default value from new_attrs.
ComputeArgSignature(const protobuf::RepeatedPtrField<OpDef::ArgDef> & args,const AttrMap & old_attrs,const AttrMap & new_attrs,std::vector<bool> * ref,bool names)522 string ComputeArgSignature(
523     const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
524     const AttrMap& old_attrs, const AttrMap& new_attrs, std::vector<bool>* ref,
525     bool names) {
526   string s;
527   bool add_comma = false;
528   for (const OpDef::ArgDef& arg : args) {
529     if (!arg.type_list_attr().empty()) {
530       const OpDef::AttrDef* old_attr =
531           gtl::FindPtrOrNull(old_attrs, arg.type_list_attr());
532       if (old_attr) {
533         // Both old and new have the list(type) attr, so can use it directly.
534         AddComma(&s, &add_comma);
535         AddName(&s, names, arg);
536         strings::StrAppend(&s, arg.type_list_attr());
537         ref->push_back(arg.is_ref());
538       } else {
539         // Missing the list(type) attr in the old, so use the default
540         // value for the attr from new instead.
541         const OpDef::AttrDef* new_attr =
542             gtl::FindPtrOrNull(new_attrs, arg.type_list_attr());
543         const auto& type_list = new_attr->default_value().list().type();
544         if (type_list.empty()) continue;
545         for (int i = 0; i < type_list.size(); ++i) {
546           AddComma(&s, &add_comma);
547           AddName(&s, names, arg);
548           strings::StrAppend(
549               &s, DataTypeString(static_cast<DataType>(type_list.Get(i))));
550           ref->push_back(arg.is_ref());
551         }
552       }
553     } else {
554       int num = 1;  // How many input/outputs does this represent?
555       string type;  // What is the type of this arg?
556       AddName(&type, names, arg);
557       if (!arg.number_attr().empty()) {
558         // N * type case.
559         const OpDef::AttrDef* old_attr =
560             gtl::FindPtrOrNull(old_attrs, arg.number_attr());
561         if (old_attr) {
562           // Both old and new have the number attr, so can use it directly.
563           strings::StrAppend(&type, arg.number_attr(), " * ");
564         } else {
565           // Missing the number attr in the old, so use the default
566           // value for the attr from new instead.
567           const OpDef::AttrDef* new_attr =
568               gtl::FindPtrOrNull(new_attrs, arg.number_attr());
569           num = new_attr->default_value().i();
570         }
571       }
572 
573       if (arg.type() != DT_INVALID) {
574         // int32, float, etc. case
575         strings::StrAppend(&type, DataTypeString(arg.type()));
576       } else {
577         const OpDef::AttrDef* old_attr =
578             gtl::FindPtrOrNull(old_attrs, arg.type_attr());
579         if (old_attr) {
580           // Both old and new have the type attr, so can use it directly.
581           strings::StrAppend(&type, arg.type_attr());
582         } else {
583           // Missing the type attr in the old, so use the default
584           // value for the attr from new instead.
585           const OpDef::AttrDef* new_attr =
586               gtl::FindPtrOrNull(new_attrs, arg.type_attr());
587           strings::StrAppend(&type,
588                              DataTypeString(new_attr->default_value().type()));
589         }
590       }
591 
592       // Record `num` * `type` in the signature.
593       for (int i = 0; i < num; ++i) {
594         AddComma(&s, &add_comma);
595         strings::StrAppend(&s, type);
596         ref->push_back(arg.is_ref());
597       }
598     }
599   }
600 
601   return s;
602 }
603 
604 }  // namespace
605 
OpDefCompatible(const OpDef & old_op,const OpDef & new_op)606 Status OpDefCompatible(const OpDef& old_op, const OpDef& new_op) {
607 #define VALIDATE(CONDITION, ...)                                            \
608   if (!(CONDITION)) {                                                       \
609     return errors::InvalidArgument("Incompatible Op change: ", __VA_ARGS__, \
610                                    "; old: ", SummarizeOpDef(old_op),       \
611                                    "; new: ", SummarizeOpDef(new_op));      \
612   }
613 
614   VALIDATE(old_op.name() == new_op.name(), "Name mismatch");
615 
616   AttrMap new_attrs, old_attrs;
617   FillAttrMap(old_op, &old_attrs);
618   FillAttrMap(new_op, &new_attrs);
619   for (const auto& old_attr : old_op.attr()) {
620     const OpDef::AttrDef* new_attr =
621         gtl::FindPtrOrNull(new_attrs, old_attr.name());
622     VALIDATE(new_attr != nullptr, "Attr '", old_attr.name(), "' removed");
623     VALIDATE(old_attr.type() == new_attr->type(), "Attr '", old_attr.name(),
624              "' changed type '", old_attr.type(), "' -> '", new_attr->type(),
625              "'");
626     VALIDATE(!MoreRestrictive(old_attr, *new_attr), "Attr '", old_attr.name(),
627              "' has a stricter set of allowed values; from ",
628              AllowedStr(old_attr), " to ", AllowedStr(*new_attr));
629     VALIDATE(!HigherMinimum(old_attr, *new_attr), "Attr '", old_attr.name(),
630              "' has a higher minimum; from ", MinStr(old_attr), " to ",
631              MinStr(*new_attr));
632   }
633 
634   for (const auto& new_attr : new_op.attr()) {
635     const OpDef::AttrDef* old_attr =
636         gtl::FindPtrOrNull(old_attrs, new_attr.name());
637     VALIDATE(old_attr != nullptr || new_attr.has_default_value(), "Attr '",
638              new_attr.name(), "' added without default");
639   }
640 
641   std::vector<bool> old_in_ref, new_in_ref, old_out_ref, new_out_ref;
642   const string old_in_sig = ComputeArgSignature(
643       old_op.input_arg(), old_attrs, new_attrs, &old_in_ref, false /* names */);
644   const string new_in_sig = ComputeArgSignature(
645       new_op.input_arg(), old_attrs, new_attrs, &new_in_ref, false /* names */);
646   VALIDATE(old_in_sig == new_in_sig, "Input signature mismatch '", old_in_sig,
647            "' vs. '", new_in_sig, "'");
648   VALIDATE(old_in_ref.size() == new_in_ref.size(),  // Should not happen
649            "Unexpected change in input ref lists.");
650   for (int i = 0; i < old_in_ref.size(); ++i) {
651     // Allowed to remove "ref" from an input (or leave it unchanged).
652     VALIDATE(old_in_ref[i] || !new_in_ref[i], "Input ", i,
653              " changed from non-ref to ref");
654   }
655 
656   const string old_out_sig =
657       ComputeArgSignature(old_op.output_arg(), old_attrs, new_attrs,
658                           &old_out_ref, true /* names */);
659   const string new_out_sig =
660       ComputeArgSignature(new_op.output_arg(), old_attrs, new_attrs,
661                           &new_out_ref, true /* names */);
662   VALIDATE(old_out_sig == new_out_sig, "Output signature mismatch '",
663            old_out_sig, "' vs. '", new_out_sig, "'");
664   VALIDATE(old_out_ref.size() == new_out_ref.size(),  // Should not happen
665            "Unexpected change in output ref lists");
666   for (int i = 0; i < old_out_ref.size(); ++i) {
667     // Allowed to add "ref" to an output (or leave it unchanged).
668     VALIDATE(!old_out_ref[i] || new_out_ref[i], "Output ", i,
669              " changed from ref to non-ref");
670   }
671 
672   return Status::OK();
673 }
674 
OpDefAddedDefaultsUnchanged(const OpDef & old_op,const OpDef & penultimate_op,const OpDef & new_op)675 Status OpDefAddedDefaultsUnchanged(const OpDef& old_op,
676                                    const OpDef& penultimate_op,
677                                    const OpDef& new_op) {
678   AttrMap new_attrs, old_attrs;
679   FillAttrMap(old_op, &old_attrs);
680   FillAttrMap(new_op, &new_attrs);
681 
682   for (const auto& penultimate_attr : penultimate_op.attr()) {
683     const OpDef::AttrDef* old_attr =
684         gtl::FindPtrOrNull(old_attrs, penultimate_attr.name());
685     if (old_attr != nullptr) continue;  // attr wasn't added
686     const OpDef::AttrDef* new_attr =
687         gtl::FindPtrOrNull(new_attrs, penultimate_attr.name());
688 
689     // These shouldn't happen if the op passed OpDefCompatible().
690     if (new_attr == nullptr) {
691       return errors::InvalidArgument("Missing attr '", penultimate_attr.name(),
692                                      "' in op: ", SummarizeOpDef(new_op));
693     }
694     if (!penultimate_attr.has_default_value() ||
695         !new_attr->has_default_value()) {
696       return errors::InvalidArgument("Missing default for attr '",
697                                      penultimate_attr.name(),
698                                      "' in op: ", SummarizeOpDef(new_op));
699     }
700 
701     // Actually test that the attr's default value hasn't changed.
702     if (!AreAttrValuesEqual(penultimate_attr.default_value(),
703                             new_attr->default_value())) {
704       return errors::InvalidArgument(
705           "Can't change default value for attr '", penultimate_attr.name(),
706           "' from ", SummarizeAttrValue(penultimate_attr.default_value()),
707           " in op: ", SummarizeOpDef(new_op));
708     }
709   }
710 
711   return Status::OK();
712 }
713 
OpDefAttrDefaultsUnchanged(const OpDef & old_op,const OpDef & new_op)714 Status OpDefAttrDefaultsUnchanged(const OpDef& old_op, const OpDef& new_op) {
715   AttrMap new_attrs, old_attrs;
716   FillAttrMap(old_op, &old_attrs);
717   FillAttrMap(new_op, &new_attrs);
718 
719   for (const auto& old_attr : old_op.attr()) {
720     const OpDef::AttrDef* new_attr =
721         gtl::FindPtrOrNull(new_attrs, old_attr.name());
722     if (new_attr == nullptr) continue;
723     if (old_attr.has_default_value() != new_attr->has_default_value()) {
724       return errors::InvalidArgument(
725           "Attr '", old_attr.name(), "' has added/removed it's default; ",
726           "from ", DefaultAttrStr(old_attr), " to ", DefaultAttrStr(*new_attr));
727     }
728     if (old_attr.has_default_value() &&
729         !AreAttrValuesEqual(old_attr.default_value(),
730                             new_attr->default_value())) {
731       return errors::InvalidArgument(
732           "Attr '", old_attr.name(), "' has changed it's default value; ",
733           "from ", DefaultAttrStr(old_attr), " to ", DefaultAttrStr(*new_attr));
734     }
735   }
736 
737   return Status::OK();
738 }
739 
RemoveNonDeprecationDescriptionsFromOpDef(OpDef * op_def)740 void RemoveNonDeprecationDescriptionsFromOpDef(OpDef* op_def) {
741   for (int i = 0; i < op_def->input_arg_size(); ++i) {
742     op_def->mutable_input_arg(i)->clear_description();
743   }
744   for (int i = 0; i < op_def->output_arg_size(); ++i) {
745     op_def->mutable_output_arg(i)->clear_description();
746   }
747   for (int i = 0; i < op_def->attr_size(); ++i) {
748     op_def->mutable_attr(i)->clear_description();
749   }
750   op_def->clear_summary();
751   op_def->clear_description();
752 }
753 
RemoveDescriptionsFromOpDef(OpDef * op_def)754 void RemoveDescriptionsFromOpDef(OpDef* op_def) {
755   RemoveNonDeprecationDescriptionsFromOpDef(op_def);
756   if (op_def->has_deprecation()) {
757     op_def->mutable_deprecation()->clear_explanation();
758   }
759 }
760 
RemoveDescriptionsFromOpList(OpList * op_list)761 void RemoveDescriptionsFromOpList(OpList* op_list) {
762   for (int i = 0; i < op_list->op_size(); ++i) {
763     OpDef* op_def = op_list->mutable_op(i);
764     RemoveDescriptionsFromOpDef(op_def);
765   }
766 }
767 
AttrDefEqual(const OpDef::AttrDef & a1,const OpDef::AttrDef & a2)768 bool AttrDefEqual(const OpDef::AttrDef& a1, const OpDef::AttrDef& a2) {
769 #ifndef TENSORFLOW_LITE_PROTOS
770   DCHECK_EQ(7, a1.GetDescriptor()->field_count())
771       << "Please modify these equality and hash functions to reflect the "
772          "changes to the AttrDef protobuf";
773 #endif  // TENSORFLOW_LITE_PROTOS
774 
775   if (a1.name() != a2.name()) return false;
776   if (a1.type() != a2.type()) return false;
777   if (a1.description() != a2.description()) return false;
778   if (a1.has_minimum() != a2.has_minimum()) return false;
779   if (a1.has_minimum() && a1.minimum() != a2.minimum()) return false;
780   if (!AreAttrValuesEqual(a1.default_value(), a2.default_value())) return false;
781   if (!AreAttrValuesEqual(a1.allowed_values(), a2.allowed_values()))
782     return false;
783   return true;
784 }
785 
AttrDefHash(const OpDef::AttrDef & a)786 uint64 AttrDefHash(const OpDef::AttrDef& a) {
787   uint64 h = Hash64(a.name());
788   h = Hash64(a.type().data(), a.type().size(), h);
789   h = Hash64Combine(AttrValueHash(a.default_value()), h);
790   h = Hash64(a.description().data(), a.description().size(), h);
791   h = Hash64Combine(static_cast<uint64>(a.has_minimum()), h);
792   h = Hash64Combine(static_cast<uint64>(a.minimum()), h);
793   h = Hash64Combine(AttrValueHash(a.allowed_values()), h);
794   return h;
795 }
796 
RepeatedAttrDefEqual(const protobuf::RepeatedPtrField<OpDef::AttrDef> & a1,const protobuf::RepeatedPtrField<OpDef::AttrDef> & a2)797 bool RepeatedAttrDefEqual(
798     const protobuf::RepeatedPtrField<OpDef::AttrDef>& a1,
799     const protobuf::RepeatedPtrField<OpDef::AttrDef>& a2) {
800   std::unordered_map<string, const OpDef::AttrDef*> a1_set;
801   for (const OpDef::AttrDef& def : a1) {
802     DCHECK(a1_set.find(def.name()) == a1_set.end())
803         << "AttrDef names must be unique, but '" << def.name()
804         << "' appears more than once";
805     a1_set[def.name()] = &def;
806   }
807   for (const OpDef::AttrDef& def : a2) {
808     auto iter = a1_set.find(def.name());
809     if (iter == a1_set.end()) return false;
810     if (!AttrDefEqual(*iter->second, def)) return false;
811     a1_set.erase(iter);
812   }
813   if (!a1_set.empty()) return false;
814   return true;
815 }
816 
RepeatedAttrDefHash(const protobuf::RepeatedPtrField<OpDef::AttrDef> & a)817 uint64 RepeatedAttrDefHash(
818     const protobuf::RepeatedPtrField<OpDef::AttrDef>& a) {
819   // Insert AttrDefs into map to deterministically sort by name
820   std::map<string, const OpDef::AttrDef*> a_set;
821   for (const OpDef::AttrDef& def : a) {
822     a_set[def.name()] = &def;
823   }
824   // Iterate and combines hashes of keys and values
825   uint64 h = 0xDECAFCAFFE;
826   for (const auto& pair : a_set) {
827     h = Hash64(pair.first.data(), pair.first.size(), h);
828     h = Hash64Combine(AttrDefHash(*pair.second), h);
829   }
830   return h;
831 }
832 
OpDefEqual(const OpDef & o1,const OpDef & o2)833 bool OpDefEqual(const OpDef& o1, const OpDef& o2) {
834   // attr order doesn't matter.
835   // Compare it separately here instead of serializing below.
836   if (!RepeatedAttrDefEqual(o1.attr(), o2.attr())) return false;
837 
838   // `control_output` order doesn't matter.
839   std::set<string> control_output1(o1.control_output().begin(),
840                                    o1.control_output().end());
841   std::set<string> control_output2(o2.control_output().begin(),
842                                    o2.control_output().end());
843   if (control_output1 != control_output2) return false;
844 
845   // Clear `attr` and `control_output` fields, serialize, and compare serialized
846   // strings.
847   OpDef o1_copy = o1;
848   OpDef o2_copy = o2;
849   o1_copy.clear_attr();
850   o1_copy.clear_control_output();
851   o2_copy.clear_attr();
852   o2_copy.clear_control_output();
853 
854   return AreSerializedProtosEqual(o1_copy, o2_copy);
855 }
856 
OpDefHash(const OpDef & o)857 uint64 OpDefHash(const OpDef& o) {
858   uint64 h = RepeatedAttrDefHash(o.attr());
859 
860   // Compute deterministic order-independent control outputs hash.
861   std::set<string> control_output(o.control_output().begin(),
862                                   o.control_output().end());
863   for (const auto& co : control_output) h = Hash64Combine(h, Hash64(co));
864 
865   OpDef o_copy = o;
866   o_copy.clear_attr();
867   o_copy.clear_control_output();
868   return DeterministicProtoHash64(o_copy, h);
869 }
870 
871 }  // namespace tensorflow
872