1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/op_def_util.h"
17
18 #include <set>
19 #include <unordered_map>
20 #include <unordered_set>
21 #include "tensorflow/core/framework/attr_value.pb.h"
22 #include "tensorflow/core/framework/attr_value_util.h"
23 #include "tensorflow/core/framework/op_def.pb_text.h"
24 #include "tensorflow/core/framework/types.h"
25 #include "tensorflow/core/lib/core/errors.h"
26 #include "tensorflow/core/lib/core/stringpiece.h"
27 #include "tensorflow/core/lib/gtl/map_util.h"
28 #include "tensorflow/core/lib/hash/hash.h"
29 #include "tensorflow/core/lib/strings/proto_serialization.h"
30 #include "tensorflow/core/lib/strings/scanner.h"
31 #include "tensorflow/core/lib/strings/str_util.h"
32 #include "tensorflow/core/lib/strings/strcat.h"
33 #include "tensorflow/core/platform/mutex.h"
34 #include "tensorflow/core/platform/protobuf.h"
35 #include "tensorflow/core/platform/types.h"
36
37 namespace tensorflow {
38 namespace { // ------ Helper functions ------
39
HasAttrStyleType(const OpDef::ArgDef & arg)40 bool HasAttrStyleType(const OpDef::ArgDef& arg) {
41 return arg.type() != DT_INVALID || !arg.type_attr().empty() ||
42 !arg.type_list_attr().empty();
43 }
44
AllowedTypeValue(DataType dt,const OpDef::AttrDef & attr)45 Status AllowedTypeValue(DataType dt, const OpDef::AttrDef& attr) {
46 const AttrValue& allowed_values(attr.allowed_values());
47 for (auto allowed : allowed_values.list().type()) {
48 if (dt == allowed) {
49 return Status::OK();
50 }
51 }
52 string allowed_str;
53 for (int i = 0; i < allowed_values.list().type_size(); ++i) {
54 if (!allowed_str.empty()) {
55 strings::StrAppend(&allowed_str, ", ");
56 }
57 strings::StrAppend(&allowed_str,
58 DataTypeString(allowed_values.list().type(i)));
59 }
60 return errors::InvalidArgument(
61 "Value for attr '", attr.name(), "' of ", DataTypeString(dt),
62 " is not in the list of allowed values: ", allowed_str);
63 }
64
AllowedStringValue(const string & str,const OpDef::AttrDef & attr)65 Status AllowedStringValue(const string& str, const OpDef::AttrDef& attr) {
66 const AttrValue& allowed_values(attr.allowed_values());
67 for (const auto& allowed : allowed_values.list().s()) {
68 if (str == allowed) {
69 return Status::OK();
70 }
71 }
72 string allowed_str;
73 for (const string& allowed : allowed_values.list().s()) {
74 if (!allowed_str.empty()) {
75 strings::StrAppend(&allowed_str, ", ");
76 }
77 strings::StrAppend(&allowed_str, "\"", allowed, "\"");
78 }
79 return errors::InvalidArgument(
80 "Value for attr '", attr.name(), "' of \"", str,
81 "\" is not in the list of allowed values: ", allowed_str);
82 }
83
84 } // namespace
85
86 // Requires: attr has already been validated.
ValidateAttrValue(const AttrValue & attr_value,const OpDef::AttrDef & attr)87 Status ValidateAttrValue(const AttrValue& attr_value,
88 const OpDef::AttrDef& attr) {
89 // Is it a valid value?
90 TF_RETURN_WITH_CONTEXT_IF_ERROR(AttrValueHasType(attr_value, attr.type()),
91 " for attr '", attr.name(), "'");
92
93 // Does the value satisfy the minimum constraint in the AttrDef?
94 if (attr.has_minimum()) {
95 if (attr.type() == "int") {
96 if (attr_value.i() < attr.minimum()) {
97 return errors::InvalidArgument(
98 "Value for attr '", attr.name(), "' of ", attr_value.i(),
99 " must be at least minimum ", attr.minimum());
100 }
101 } else {
102 int length = -1;
103 if (attr.type() == "list(string)") {
104 length = attr_value.list().s_size();
105 } else if (attr.type() == "list(int)") {
106 length = attr_value.list().i_size();
107 } else if (attr.type() == "list(float)") {
108 length = attr_value.list().f_size();
109 } else if (attr.type() == "list(bool)") {
110 length = attr_value.list().b_size();
111 } else if (attr.type() == "list(type)") {
112 length = attr_value.list().type_size();
113 } else if (attr.type() == "list(shape)") {
114 length = attr_value.list().shape_size();
115 } else if (attr.type() == "list(tensor)") {
116 length = attr_value.list().tensor_size();
117 } else if (attr.type() == "list(func)") {
118 length = attr_value.list().func_size();
119 }
120 if (length < attr.minimum()) {
121 return errors::InvalidArgument(
122 "Length for attr '", attr.name(), "' of ", length,
123 " must be at least minimum ", attr.minimum());
124 }
125 }
126 }
127
128 // Does the value satisfy the allowed_value constraint in the AttrDef?
129 if (attr.has_allowed_values()) {
130 if (attr.type() == "type") {
131 TF_RETURN_IF_ERROR(AllowedTypeValue(attr_value.type(), attr));
132 } else if (attr.type() == "list(type)") {
133 for (int dt : attr_value.list().type()) {
134 TF_RETURN_IF_ERROR(AllowedTypeValue(static_cast<DataType>(dt), attr));
135 }
136 } else if (attr.type() == "string") {
137 TF_RETURN_IF_ERROR(AllowedStringValue(attr_value.s(), attr));
138 } else if (attr.type() == "list(string)") {
139 for (const string& str : attr_value.list().s()) {
140 TF_RETURN_IF_ERROR(AllowedStringValue(str, attr));
141 }
142 } else {
143 return errors::Unimplemented(
144 "Support for allowed_values not implemented for type ", attr.type());
145 }
146 }
147 return Status::OK();
148 }
149
FindAttr(StringPiece name,const OpDef & op_def)150 const OpDef::AttrDef* FindAttr(StringPiece name, const OpDef& op_def) {
151 for (int i = 0; i < op_def.attr_size(); ++i) {
152 if (op_def.attr(i).name() == name) {
153 return &op_def.attr(i);
154 }
155 }
156 return nullptr;
157 }
158
FindAttrMutable(StringPiece name,OpDef * op_def)159 OpDef::AttrDef* FindAttrMutable(StringPiece name, OpDef* op_def) {
160 for (int i = 0; i < op_def->attr_size(); ++i) {
161 if (op_def->attr(i).name() == name) {
162 return op_def->mutable_attr(i);
163 }
164 }
165 return nullptr;
166 }
167
FindInputArg(StringPiece name,const OpDef & op_def)168 const OpDef::ArgDef* FindInputArg(StringPiece name, const OpDef& op_def) {
169 for (int i = 0; i < op_def.input_arg_size(); ++i) {
170 if (op_def.input_arg(i).name() == name) {
171 return &op_def.input_arg(i);
172 }
173 }
174 return nullptr;
175 }
176
FindInputArg(StringPiece name,const ApiDef & api_def)177 const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) {
178 for (int i = 0; i < api_def.in_arg_size(); ++i) {
179 if (api_def.in_arg(i).name() == name) {
180 return &api_def.in_arg(i);
181 }
182 }
183 return nullptr;
184 }
185
// VALIDATE(EXPR, message-parts...): if EXPR is false, returns from the
// enclosing function with an InvalidArgument error built from the message
// parts followed by "; in OpDef: " and a short debug string of `op_def`.
// The expansion refers to a variable named `op_def`, so one must be in scope
// at every use. The do/while(false) wrapper makes the macro behave as a single
// statement that requires a trailing semicolon.
#define VALIDATE(EXPR, ...)                                            \
  do {                                                                 \
    if (!(EXPR)) {                                                     \
      return errors::InvalidArgument(                                  \
          __VA_ARGS__, "; in OpDef: ", ProtoShortDebugString(op_def)); \
    }                                                                  \
  } while (false)
193
ValidateArg(const OpDef::ArgDef & arg,const OpDef & op_def,bool output,std::set<string> * names)194 static Status ValidateArg(const OpDef::ArgDef& arg, const OpDef& op_def,
195 bool output, std::set<string>* names) {
196 const string suffix = strings::StrCat(
197 output ? " for output '" : " for input '", arg.name(), "'");
198 VALIDATE(gtl::InsertIfNotPresent(names, arg.name()),
199 "Duplicate name: ", arg.name());
200 VALIDATE(HasAttrStyleType(arg), "Missing type", suffix);
201
202 if (!arg.number_attr().empty()) {
203 const OpDef::AttrDef* attr = FindAttr(arg.number_attr(), op_def);
204 VALIDATE(attr != nullptr, "No attr with name '", arg.number_attr(), "'",
205 suffix);
206 VALIDATE(attr->type() == "int", "Attr '", attr->name(), "' used as length",
207 suffix, " has type ", attr->type(), " != int");
208 VALIDATE(attr->has_minimum(), "Attr '", attr->name(), "' used as length",
209 suffix, " must have minimum");
210 VALIDATE(attr->minimum() >= 0, "Attr '", attr->name(), "' used as length",
211 suffix, " must have minimum >= 0");
212 VALIDATE(arg.type_list_attr().empty(),
213 "Can't have both number_attr and type_list_attr", suffix);
214 VALIDATE((arg.type() != DT_INVALID ? 1 : 0) +
215 (!arg.type_attr().empty() ? 1 : 0) ==
216 1,
217 "Exactly one of type, type_attr must be set", suffix);
218 } else {
219 const int num_type_fields = (arg.type() != DT_INVALID ? 1 : 0) +
220 (!arg.type_attr().empty() ? 1 : 0) +
221 (!arg.type_list_attr().empty() ? 1 : 0);
222 VALIDATE(num_type_fields == 1,
223 "Exactly one of type, type_attr, type_list_attr must be set",
224 suffix);
225 }
226
227 if (!arg.type_attr().empty()) {
228 const OpDef::AttrDef* attr = FindAttr(arg.type_attr(), op_def);
229 VALIDATE(attr != nullptr, "No attr with name '", arg.type_attr(), "'",
230 suffix);
231 VALIDATE(attr->type() == "type", "Attr '", attr->name(),
232 "' used as type_attr", suffix, " has type ", attr->type(),
233 " != type");
234 } else if (!arg.type_list_attr().empty()) {
235 const OpDef::AttrDef* attr = FindAttr(arg.type_list_attr(), op_def);
236 VALIDATE(attr != nullptr, "No attr with name '", arg.type_list_attr(), "'",
237 suffix);
238 VALIDATE(attr->type() == "list(type)", "Attr '", attr->name(),
239 "' used as type_list_attr", suffix, " has type ", attr->type(),
240 " != list(type)");
241 } else {
242 // All argument types should be non-reference types at this point.
243 // ArgDef.is_ref is set to true for reference arguments.
244 VALIDATE(!IsRefType(arg.type()), "Illegal use of ref type '",
245 DataTypeString(arg.type()), "'. Use 'Ref(type)' instead", suffix);
246 }
247
248 return Status::OK();
249 }
250
ValidateOpDef(const OpDef & op_def)251 Status ValidateOpDef(const OpDef& op_def) {
252 using ::tensorflow::strings::Scanner;
253
254 if (!str_util::StartsWith(op_def.name(), "_")) {
255 VALIDATE(Scanner(op_def.name())
256 .One(Scanner::UPPERLETTER)
257 .Any(Scanner::LETTER_DIGIT)
258 .Eos()
259 .GetResult(),
260 "Invalid name: ", op_def.name(), " (Did you use CamelCase?)");
261 }
262
263 std::set<string> names; // for detecting duplicate names
264 for (const auto& attr : op_def.attr()) {
265 // Validate name
266 VALIDATE(gtl::InsertIfNotPresent(&names, attr.name()),
267 "Duplicate name: ", attr.name());
268 DataType dt;
269 VALIDATE(!DataTypeFromString(attr.name(), &dt), "Attr can't have name ",
270 attr.name(), " that matches a data type");
271
272 // Validate type
273 StringPiece type(attr.type());
274 bool is_list = str_util::ConsumePrefix(&type, "list(");
275 bool found = false;
276 for (StringPiece valid : {"string", "int", "float", "bool", "type", "shape",
277 "tensor", "func"}) {
278 if (str_util::ConsumePrefix(&type, valid)) {
279 found = true;
280 break;
281 }
282 }
283 VALIDATE(found, "Unrecognized type '", type, "' in attr '", attr.name(),
284 "'");
285 if (is_list) {
286 VALIDATE(str_util::ConsumePrefix(&type, ")"),
287 "'list(' is missing ')' in attr ", attr.name(), "'s type ",
288 attr.type());
289 }
290 VALIDATE(type.empty(), "Extra '", type, "' at the end of attr ",
291 attr.name(), "'s type ", attr.type());
292
293 // Validate minimum
294 if (attr.has_minimum()) {
295 VALIDATE(attr.type() == "int" || is_list, "Attr '", attr.name(),
296 "' has minimum for unsupported type ", attr.type());
297 if (is_list) {
298 VALIDATE(attr.minimum() >= 0, "Attr '", attr.name(),
299 "' with list type must have a non-negative minimum, not ",
300 attr.minimum());
301 }
302 } else {
303 VALIDATE(attr.minimum() == 0, "Attr '", attr.name(),
304 "' with has_minimum = false but minimum ", attr.minimum(),
305 " not equal to default of 0");
306 }
307
308 // Validate allowed_values
309 if (attr.has_allowed_values()) {
310 const string list_type =
311 is_list ? attr.type() : strings::StrCat("list(", attr.type(), ")");
312 TF_RETURN_WITH_CONTEXT_IF_ERROR(
313 AttrValueHasType(attr.allowed_values(), list_type), " for attr '",
314 attr.name(), "' in Op '", op_def.name(), "'");
315 }
316
317 // Validate default_value (after we have validated the rest of the attr,
318 // so we can use ValidateAttrValue()).
319 if (attr.has_default_value()) {
320 TF_RETURN_WITH_CONTEXT_IF_ERROR(
321 ValidateAttrValue(attr.default_value(), attr), " in Op '",
322 op_def.name(), "'");
323 }
324 }
325
326 for (const auto& arg : op_def.input_arg()) {
327 TF_RETURN_IF_ERROR(ValidateArg(arg, op_def, false, &names));
328 }
329
330 for (const auto& arg : op_def.output_arg()) {
331 TF_RETURN_IF_ERROR(ValidateArg(arg, op_def, true, &names));
332 }
333
334 return Status::OK();
335 }
336
337 #undef VALIDATE
338
CheckOpDeprecation(const OpDef & op_def,int graph_def_version)339 Status CheckOpDeprecation(const OpDef& op_def, int graph_def_version) {
340 if (op_def.has_deprecation()) {
341 const OpDeprecation& dep = op_def.deprecation();
342 if (graph_def_version >= dep.version()) {
343 return errors::Unimplemented(
344 "Op ", op_def.name(), " is not available in GraphDef version ",
345 graph_def_version, ". It has been removed in version ", dep.version(),
346 ". ", dep.explanation(), ".");
347 } else {
348 // Warn only once for each op name, and do it in a threadsafe manner.
349 static mutex mu(LINKER_INITIALIZED);
350 static std::unordered_set<string> warned;
351 bool warn;
352 {
353 mutex_lock lock(mu);
354 warn = warned.insert(op_def.name()).second;
355 }
356 if (warn) {
357 LOG(WARNING) << "Op " << op_def.name() << " is deprecated."
358 << " It will cease to work in GraphDef version "
359 << dep.version() << ". " << dep.explanation() << ".";
360 }
361 }
362 }
363 return Status::OK();
364 }
365
366 namespace {
367
SummarizeArgs(const protobuf::RepeatedPtrField<OpDef::ArgDef> & args)368 string SummarizeArgs(const protobuf::RepeatedPtrField<OpDef::ArgDef>& args) {
369 string ret;
370 for (const OpDef::ArgDef& arg : args) {
371 if (!ret.empty()) strings::StrAppend(&ret, ", ");
372 strings::StrAppend(&ret, arg.name(), ":");
373 if (arg.is_ref()) strings::StrAppend(&ret, "Ref(");
374 if (!arg.number_attr().empty()) {
375 strings::StrAppend(&ret, arg.number_attr(), "*");
376 }
377 if (arg.type() != DT_INVALID) {
378 strings::StrAppend(&ret, DataTypeString(arg.type()));
379 } else {
380 strings::StrAppend(&ret, arg.type_attr());
381 }
382 if (arg.is_ref()) strings::StrAppend(&ret, ")");
383 }
384 return ret;
385 }
386
387 } // namespace
388
SummarizeOpDef(const OpDef & op_def)389 string SummarizeOpDef(const OpDef& op_def) {
390 string ret = strings::StrCat("Op<name=", op_def.name());
391 strings::StrAppend(&ret, "; signature=", SummarizeArgs(op_def.input_arg()),
392 " -> ", SummarizeArgs(op_def.output_arg()));
393 for (int i = 0; i < op_def.attr_size(); ++i) {
394 strings::StrAppend(&ret, "; attr=", op_def.attr(i).name(), ":",
395 op_def.attr(i).type());
396 if (op_def.attr(i).has_default_value()) {
397 strings::StrAppend(&ret, ",default=",
398 SummarizeAttrValue(op_def.attr(i).default_value()));
399 }
400 if (op_def.attr(i).has_minimum()) {
401 strings::StrAppend(&ret, ",min=", op_def.attr(i).minimum());
402 }
403 if (op_def.attr(i).has_allowed_values()) {
404 strings::StrAppend(&ret, ",allowed=",
405 SummarizeAttrValue(op_def.attr(i).allowed_values()));
406 }
407 }
408 if (op_def.is_commutative()) {
409 strings::StrAppend(&ret, "; is_commutative=true");
410 }
411 if (op_def.is_aggregate()) {
412 strings::StrAppend(&ret, "; is_aggregate=true");
413 }
414 if (op_def.is_stateful()) {
415 strings::StrAppend(&ret, "; is_stateful=true");
416 }
417 if (op_def.allows_uninitialized_input()) {
418 strings::StrAppend(&ret, "; allows_uninitialized_input=true");
419 }
420 strings::StrAppend(&ret, ">");
421 return ret;
422 }
423
424 namespace {
425
// Returns true if every element of `sub` compares equal (via operator==) to
// some element of `super`. An empty `sub` is a subset of anything. Runs a
// linear scan of `super` for each element of `sub`.
template <class T>
bool IsSubsetOf(const T& sub, const T& super) {
  for (auto needle = sub.begin(); needle != sub.end(); ++needle) {
    auto match = super.begin();
    while (match != super.end() && !(*needle == *match)) {
      ++match;
    }
    if (match == super.end()) return false;  // *needle is missing from super
  }
  return true;
}
441
MoreRestrictive(const OpDef::AttrDef & old_attr,const OpDef::AttrDef & new_attr)442 bool MoreRestrictive(const OpDef::AttrDef& old_attr,
443 const OpDef::AttrDef& new_attr) {
444 // Anything -> no restriction : not more restrictive.
445 if (!new_attr.has_allowed_values()) return false;
446 // No restriction -> restriction : more restrictive.
447 if (!old_attr.has_allowed_values()) return true;
448 // If anything that was previously allowed is no longer allowed:
449 // more restrictive.
450 if (!IsSubsetOf(old_attr.allowed_values().list().type(),
451 new_attr.allowed_values().list().type())) {
452 return true;
453 }
454 if (!IsSubsetOf(old_attr.allowed_values().list().s(),
455 new_attr.allowed_values().list().s())) {
456 return true;
457 }
458 return false;
459 }
460
AllowedStr(const OpDef::AttrDef & attr)461 string AllowedStr(const OpDef::AttrDef& attr) {
462 if (!attr.has_allowed_values()) return "no restriction";
463 return SummarizeAttrValue(attr.allowed_values());
464 }
465
DefaultAttrStr(const OpDef::AttrDef & attr)466 string DefaultAttrStr(const OpDef::AttrDef& attr) {
467 if (!attr.has_default_value()) return "no default";
468 return SummarizeAttrValue(attr.default_value());
469 }
470
HigherMinimum(const OpDef::AttrDef & old_attr,const OpDef::AttrDef & new_attr)471 bool HigherMinimum(const OpDef::AttrDef& old_attr,
472 const OpDef::AttrDef& new_attr) {
473 // Anything -> no restriction : not more restrictive.
474 if (!new_attr.has_minimum()) return false;
475 // No restriction -> restriction : more restrictive.
476 if (!old_attr.has_minimum()) return true;
477 // If anything that was previously allowed is no longer allowed:
478 // more restrictive.
479 return new_attr.minimum() > old_attr.minimum();
480 }
481
MinStr(const OpDef::AttrDef & attr)482 string MinStr(const OpDef::AttrDef& attr) {
483 if (!attr.has_minimum()) return "no minimum";
484 return strings::StrCat(attr.minimum());
485 }
486
487 typedef std::unordered_map<string, const OpDef::AttrDef*> AttrMap;
FillAttrMap(const OpDef & op_def,AttrMap * attr_map)488 void FillAttrMap(const OpDef& op_def, AttrMap* attr_map) {
489 for (const auto& attr : op_def.attr()) {
490 (*attr_map)[attr.name()] = &attr;
491 }
492 }
493
494 // Add a comma to *s every call but the first (*add_comma should be
495 // initialized to false).
AddComma(string * s,bool * add_comma)496 void AddComma(string* s, bool* add_comma) {
497 if (*add_comma) {
498 strings::StrAppend(s, ", ");
499 } else {
500 *add_comma = true;
501 }
502 }
503
504 // Will add the `name` from arg if name is true.
AddName(string * s,bool name,const OpDef::ArgDef & arg)505 void AddName(string* s, bool name, const OpDef::ArgDef& arg) {
506 if (name) {
507 strings::StrAppend(s, arg.name(), ":");
508 }
509 }
510
511 // Compute a signature for either inputs or outputs that will be the
512 // same for both the old and new OpDef if they are compatible. We
513 // assume that new_attrs is a superset of old_attrs, and that any attr
514 // in the difference has a default. Our strategy is to make a list of
515 // types, where the types are things like:
516 // * "int32", "float", etc.,
517 // * "T" for some attr "T" in old_attrs, or
518 // * "N * type" for "N" either some attr in old_attrs.
519 //
520 // We get the types by either using the attrs in args if they are in
521 // old_attrs, or substituting the default value from new_attrs.
ComputeArgSignature(const protobuf::RepeatedPtrField<OpDef::ArgDef> & args,const AttrMap & old_attrs,const AttrMap & new_attrs,std::vector<bool> * ref,bool names)522 string ComputeArgSignature(
523 const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
524 const AttrMap& old_attrs, const AttrMap& new_attrs, std::vector<bool>* ref,
525 bool names) {
526 string s;
527 bool add_comma = false;
528 for (const OpDef::ArgDef& arg : args) {
529 if (!arg.type_list_attr().empty()) {
530 const OpDef::AttrDef* old_attr =
531 gtl::FindPtrOrNull(old_attrs, arg.type_list_attr());
532 if (old_attr) {
533 // Both old and new have the list(type) attr, so can use it directly.
534 AddComma(&s, &add_comma);
535 AddName(&s, names, arg);
536 strings::StrAppend(&s, arg.type_list_attr());
537 ref->push_back(arg.is_ref());
538 } else {
539 // Missing the list(type) attr in the old, so use the default
540 // value for the attr from new instead.
541 const OpDef::AttrDef* new_attr =
542 gtl::FindPtrOrNull(new_attrs, arg.type_list_attr());
543 const auto& type_list = new_attr->default_value().list().type();
544 if (type_list.empty()) continue;
545 for (int i = 0; i < type_list.size(); ++i) {
546 AddComma(&s, &add_comma);
547 AddName(&s, names, arg);
548 strings::StrAppend(
549 &s, DataTypeString(static_cast<DataType>(type_list.Get(i))));
550 ref->push_back(arg.is_ref());
551 }
552 }
553 } else {
554 int num = 1; // How many input/outputs does this represent?
555 string type; // What is the type of this arg?
556 AddName(&type, names, arg);
557 if (!arg.number_attr().empty()) {
558 // N * type case.
559 const OpDef::AttrDef* old_attr =
560 gtl::FindPtrOrNull(old_attrs, arg.number_attr());
561 if (old_attr) {
562 // Both old and new have the number attr, so can use it directly.
563 strings::StrAppend(&type, arg.number_attr(), " * ");
564 } else {
565 // Missing the number attr in the old, so use the default
566 // value for the attr from new instead.
567 const OpDef::AttrDef* new_attr =
568 gtl::FindPtrOrNull(new_attrs, arg.number_attr());
569 num = new_attr->default_value().i();
570 }
571 }
572
573 if (arg.type() != DT_INVALID) {
574 // int32, float, etc. case
575 strings::StrAppend(&type, DataTypeString(arg.type()));
576 } else {
577 const OpDef::AttrDef* old_attr =
578 gtl::FindPtrOrNull(old_attrs, arg.type_attr());
579 if (old_attr) {
580 // Both old and new have the type attr, so can use it directly.
581 strings::StrAppend(&type, arg.type_attr());
582 } else {
583 // Missing the type attr in the old, so use the default
584 // value for the attr from new instead.
585 const OpDef::AttrDef* new_attr =
586 gtl::FindPtrOrNull(new_attrs, arg.type_attr());
587 strings::StrAppend(&type,
588 DataTypeString(new_attr->default_value().type()));
589 }
590 }
591
592 // Record `num` * `type` in the signature.
593 for (int i = 0; i < num; ++i) {
594 AddComma(&s, &add_comma);
595 strings::StrAppend(&s, type);
596 ref->push_back(arg.is_ref());
597 }
598 }
599 }
600
601 return s;
602 }
603
604 } // namespace
605
OpDefCompatible(const OpDef & old_op,const OpDef & new_op)606 Status OpDefCompatible(const OpDef& old_op, const OpDef& new_op) {
607 #define VALIDATE(CONDITION, ...) \
608 if (!(CONDITION)) { \
609 return errors::InvalidArgument("Incompatible Op change: ", __VA_ARGS__, \
610 "; old: ", SummarizeOpDef(old_op), \
611 "; new: ", SummarizeOpDef(new_op)); \
612 }
613
614 VALIDATE(old_op.name() == new_op.name(), "Name mismatch");
615
616 AttrMap new_attrs, old_attrs;
617 FillAttrMap(old_op, &old_attrs);
618 FillAttrMap(new_op, &new_attrs);
619 for (const auto& old_attr : old_op.attr()) {
620 const OpDef::AttrDef* new_attr =
621 gtl::FindPtrOrNull(new_attrs, old_attr.name());
622 VALIDATE(new_attr != nullptr, "Attr '", old_attr.name(), "' removed");
623 VALIDATE(old_attr.type() == new_attr->type(), "Attr '", old_attr.name(),
624 "' changed type '", old_attr.type(), "' -> '", new_attr->type(),
625 "'");
626 VALIDATE(!MoreRestrictive(old_attr, *new_attr), "Attr '", old_attr.name(),
627 "' has a stricter set of allowed values; from ",
628 AllowedStr(old_attr), " to ", AllowedStr(*new_attr));
629 VALIDATE(!HigherMinimum(old_attr, *new_attr), "Attr '", old_attr.name(),
630 "' has a higher minimum; from ", MinStr(old_attr), " to ",
631 MinStr(*new_attr));
632 }
633
634 for (const auto& new_attr : new_op.attr()) {
635 const OpDef::AttrDef* old_attr =
636 gtl::FindPtrOrNull(old_attrs, new_attr.name());
637 VALIDATE(old_attr != nullptr || new_attr.has_default_value(), "Attr '",
638 new_attr.name(), "' added without default");
639 }
640
641 std::vector<bool> old_in_ref, new_in_ref, old_out_ref, new_out_ref;
642 const string old_in_sig = ComputeArgSignature(
643 old_op.input_arg(), old_attrs, new_attrs, &old_in_ref, false /* names */);
644 const string new_in_sig = ComputeArgSignature(
645 new_op.input_arg(), old_attrs, new_attrs, &new_in_ref, false /* names */);
646 VALIDATE(old_in_sig == new_in_sig, "Input signature mismatch '", old_in_sig,
647 "' vs. '", new_in_sig, "'");
648 VALIDATE(old_in_ref.size() == new_in_ref.size(), // Should not happen
649 "Unexpected change in input ref lists.");
650 for (int i = 0; i < old_in_ref.size(); ++i) {
651 // Allowed to remove "ref" from an input (or leave it unchanged).
652 VALIDATE(old_in_ref[i] || !new_in_ref[i], "Input ", i,
653 " changed from non-ref to ref");
654 }
655
656 const string old_out_sig =
657 ComputeArgSignature(old_op.output_arg(), old_attrs, new_attrs,
658 &old_out_ref, true /* names */);
659 const string new_out_sig =
660 ComputeArgSignature(new_op.output_arg(), old_attrs, new_attrs,
661 &new_out_ref, true /* names */);
662 VALIDATE(old_out_sig == new_out_sig, "Output signature mismatch '",
663 old_out_sig, "' vs. '", new_out_sig, "'");
664 VALIDATE(old_out_ref.size() == new_out_ref.size(), // Should not happen
665 "Unexpected change in output ref lists");
666 for (int i = 0; i < old_out_ref.size(); ++i) {
667 // Allowed to add "ref" to an output (or leave it unchanged).
668 VALIDATE(!old_out_ref[i] || new_out_ref[i], "Output ", i,
669 " changed from ref to non-ref");
670 }
671
672 return Status::OK();
673 }
674
OpDefAddedDefaultsUnchanged(const OpDef & old_op,const OpDef & penultimate_op,const OpDef & new_op)675 Status OpDefAddedDefaultsUnchanged(const OpDef& old_op,
676 const OpDef& penultimate_op,
677 const OpDef& new_op) {
678 AttrMap new_attrs, old_attrs;
679 FillAttrMap(old_op, &old_attrs);
680 FillAttrMap(new_op, &new_attrs);
681
682 for (const auto& penultimate_attr : penultimate_op.attr()) {
683 const OpDef::AttrDef* old_attr =
684 gtl::FindPtrOrNull(old_attrs, penultimate_attr.name());
685 if (old_attr != nullptr) continue; // attr wasn't added
686 const OpDef::AttrDef* new_attr =
687 gtl::FindPtrOrNull(new_attrs, penultimate_attr.name());
688
689 // These shouldn't happen if the op passed OpDefCompatible().
690 if (new_attr == nullptr) {
691 return errors::InvalidArgument("Missing attr '", penultimate_attr.name(),
692 "' in op: ", SummarizeOpDef(new_op));
693 }
694 if (!penultimate_attr.has_default_value() ||
695 !new_attr->has_default_value()) {
696 return errors::InvalidArgument("Missing default for attr '",
697 penultimate_attr.name(),
698 "' in op: ", SummarizeOpDef(new_op));
699 }
700
701 // Actually test that the attr's default value hasn't changed.
702 if (!AreAttrValuesEqual(penultimate_attr.default_value(),
703 new_attr->default_value())) {
704 return errors::InvalidArgument(
705 "Can't change default value for attr '", penultimate_attr.name(),
706 "' from ", SummarizeAttrValue(penultimate_attr.default_value()),
707 " in op: ", SummarizeOpDef(new_op));
708 }
709 }
710
711 return Status::OK();
712 }
713
OpDefAttrDefaultsUnchanged(const OpDef & old_op,const OpDef & new_op)714 Status OpDefAttrDefaultsUnchanged(const OpDef& old_op, const OpDef& new_op) {
715 AttrMap new_attrs, old_attrs;
716 FillAttrMap(old_op, &old_attrs);
717 FillAttrMap(new_op, &new_attrs);
718
719 for (const auto& old_attr : old_op.attr()) {
720 const OpDef::AttrDef* new_attr =
721 gtl::FindPtrOrNull(new_attrs, old_attr.name());
722 if (new_attr == nullptr) continue;
723 if (old_attr.has_default_value() != new_attr->has_default_value()) {
724 return errors::InvalidArgument(
725 "Attr '", old_attr.name(), "' has added/removed it's default; ",
726 "from ", DefaultAttrStr(old_attr), " to ", DefaultAttrStr(*new_attr));
727 }
728 if (old_attr.has_default_value() &&
729 !AreAttrValuesEqual(old_attr.default_value(),
730 new_attr->default_value())) {
731 return errors::InvalidArgument(
732 "Attr '", old_attr.name(), "' has changed it's default value; ",
733 "from ", DefaultAttrStr(old_attr), " to ", DefaultAttrStr(*new_attr));
734 }
735 }
736
737 return Status::OK();
738 }
739
RemoveNonDeprecationDescriptionsFromOpDef(OpDef * op_def)740 void RemoveNonDeprecationDescriptionsFromOpDef(OpDef* op_def) {
741 for (int i = 0; i < op_def->input_arg_size(); ++i) {
742 op_def->mutable_input_arg(i)->clear_description();
743 }
744 for (int i = 0; i < op_def->output_arg_size(); ++i) {
745 op_def->mutable_output_arg(i)->clear_description();
746 }
747 for (int i = 0; i < op_def->attr_size(); ++i) {
748 op_def->mutable_attr(i)->clear_description();
749 }
750 op_def->clear_summary();
751 op_def->clear_description();
752 }
753
RemoveDescriptionsFromOpDef(OpDef * op_def)754 void RemoveDescriptionsFromOpDef(OpDef* op_def) {
755 RemoveNonDeprecationDescriptionsFromOpDef(op_def);
756 if (op_def->has_deprecation()) {
757 op_def->mutable_deprecation()->clear_explanation();
758 }
759 }
760
RemoveDescriptionsFromOpList(OpList * op_list)761 void RemoveDescriptionsFromOpList(OpList* op_list) {
762 for (int i = 0; i < op_list->op_size(); ++i) {
763 OpDef* op_def = op_list->mutable_op(i);
764 RemoveDescriptionsFromOpDef(op_def);
765 }
766 }
767
AttrDefEqual(const OpDef::AttrDef & a1,const OpDef::AttrDef & a2)768 bool AttrDefEqual(const OpDef::AttrDef& a1, const OpDef::AttrDef& a2) {
769 #ifndef TENSORFLOW_LITE_PROTOS
770 DCHECK_EQ(7, a1.GetDescriptor()->field_count())
771 << "Please modify these equality and hash functions to reflect the "
772 "changes to the AttrDef protobuf";
773 #endif // TENSORFLOW_LITE_PROTOS
774
775 if (a1.name() != a2.name()) return false;
776 if (a1.type() != a2.type()) return false;
777 if (a1.description() != a2.description()) return false;
778 if (a1.has_minimum() != a2.has_minimum()) return false;
779 if (a1.has_minimum() && a1.minimum() != a2.minimum()) return false;
780 if (!AreAttrValuesEqual(a1.default_value(), a2.default_value())) return false;
781 if (!AreAttrValuesEqual(a1.allowed_values(), a2.allowed_values()))
782 return false;
783 return true;
784 }
785
AttrDefHash(const OpDef::AttrDef & a)786 uint64 AttrDefHash(const OpDef::AttrDef& a) {
787 uint64 h = Hash64(a.name());
788 h = Hash64(a.type().data(), a.type().size(), h);
789 h = Hash64Combine(AttrValueHash(a.default_value()), h);
790 h = Hash64(a.description().data(), a.description().size(), h);
791 h = Hash64Combine(static_cast<uint64>(a.has_minimum()), h);
792 h = Hash64Combine(static_cast<uint64>(a.minimum()), h);
793 h = Hash64Combine(AttrValueHash(a.allowed_values()), h);
794 return h;
795 }
796
RepeatedAttrDefEqual(const protobuf::RepeatedPtrField<OpDef::AttrDef> & a1,const protobuf::RepeatedPtrField<OpDef::AttrDef> & a2)797 bool RepeatedAttrDefEqual(
798 const protobuf::RepeatedPtrField<OpDef::AttrDef>& a1,
799 const protobuf::RepeatedPtrField<OpDef::AttrDef>& a2) {
800 std::unordered_map<string, const OpDef::AttrDef*> a1_set;
801 for (const OpDef::AttrDef& def : a1) {
802 DCHECK(a1_set.find(def.name()) == a1_set.end())
803 << "AttrDef names must be unique, but '" << def.name()
804 << "' appears more than once";
805 a1_set[def.name()] = &def;
806 }
807 for (const OpDef::AttrDef& def : a2) {
808 auto iter = a1_set.find(def.name());
809 if (iter == a1_set.end()) return false;
810 if (!AttrDefEqual(*iter->second, def)) return false;
811 a1_set.erase(iter);
812 }
813 if (!a1_set.empty()) return false;
814 return true;
815 }
816
RepeatedAttrDefHash(const protobuf::RepeatedPtrField<OpDef::AttrDef> & a)817 uint64 RepeatedAttrDefHash(
818 const protobuf::RepeatedPtrField<OpDef::AttrDef>& a) {
819 // Insert AttrDefs into map to deterministically sort by name
820 std::map<string, const OpDef::AttrDef*> a_set;
821 for (const OpDef::AttrDef& def : a) {
822 a_set[def.name()] = &def;
823 }
824 // Iterate and combines hashes of keys and values
825 uint64 h = 0xDECAFCAFFE;
826 for (const auto& pair : a_set) {
827 h = Hash64(pair.first.data(), pair.first.size(), h);
828 h = Hash64Combine(AttrDefHash(*pair.second), h);
829 }
830 return h;
831 }
832
OpDefEqual(const OpDef & o1,const OpDef & o2)833 bool OpDefEqual(const OpDef& o1, const OpDef& o2) {
834 // attr order doesn't matter.
835 // Compare it separately here instead of serializing below.
836 if (!RepeatedAttrDefEqual(o1.attr(), o2.attr())) return false;
837
838 // `control_output` order doesn't matter.
839 std::set<string> control_output1(o1.control_output().begin(),
840 o1.control_output().end());
841 std::set<string> control_output2(o2.control_output().begin(),
842 o2.control_output().end());
843 if (control_output1 != control_output2) return false;
844
845 // Clear `attr` and `control_output` fields, serialize, and compare serialized
846 // strings.
847 OpDef o1_copy = o1;
848 OpDef o2_copy = o2;
849 o1_copy.clear_attr();
850 o1_copy.clear_control_output();
851 o2_copy.clear_attr();
852 o2_copy.clear_control_output();
853
854 return AreSerializedProtosEqual(o1_copy, o2_copy);
855 }
856
OpDefHash(const OpDef & o)857 uint64 OpDefHash(const OpDef& o) {
858 uint64 h = RepeatedAttrDefHash(o.attr());
859
860 // Compute deterministic order-independent control outputs hash.
861 std::set<string> control_output(o.control_output().begin(),
862 o.control_output().end());
863 for (const auto& co : control_output) h = Hash64Combine(h, Hash64(co));
864
865 OpDef o_copy = o;
866 o_copy.clear_attr();
867 o_copy.clear_control_output();
868 return DeterministicProtoHash64(o_copy, h);
869 }
870
871 } // namespace tensorflow
872