1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/op_def_util.h"
17
18 #include <set>
19 #include <unordered_map>
20 #include <unordered_set>
21
22 #include "tensorflow/core/framework/attr_value.pb.h"
23 #include "tensorflow/core/framework/attr_value_util.h"
24 #include "tensorflow/core/framework/op_def.pb.h"
25 #include "tensorflow/core/framework/types.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/lib/core/stringpiece.h"
28 #include "tensorflow/core/lib/gtl/map_util.h"
29 #include "tensorflow/core/lib/hash/hash.h"
30 #include "tensorflow/core/lib/strings/proto_serialization.h"
31 #include "tensorflow/core/lib/strings/scanner.h"
32 #include "tensorflow/core/lib/strings/str_util.h"
33 #include "tensorflow/core/lib/strings/strcat.h"
34 #include "tensorflow/core/platform/mutex.h"
35 #include "tensorflow/core/platform/protobuf.h"
36 #include "tensorflow/core/platform/types.h"
37
38 namespace tensorflow {
39 namespace { // ------ Helper functions ------
40
HasAttrStyleType(const OpDef::ArgDef & arg)41 bool HasAttrStyleType(const OpDef::ArgDef& arg) {
42 return arg.type() != DT_INVALID || !arg.type_attr().empty() ||
43 !arg.type_list_attr().empty();
44 }
45
AllowedTypeValue(DataType dt,const OpDef::AttrDef & attr)46 Status AllowedTypeValue(DataType dt, const OpDef::AttrDef& attr) {
47 const AttrValue& allowed_values(attr.allowed_values());
48 for (auto allowed : allowed_values.list().type()) {
49 if (dt == allowed) {
50 return Status::OK();
51 }
52 }
53 string allowed_str;
54 for (int i = 0; i < allowed_values.list().type_size(); ++i) {
55 if (!allowed_str.empty()) {
56 strings::StrAppend(&allowed_str, ", ");
57 }
58 strings::StrAppend(&allowed_str,
59 DataTypeString(allowed_values.list().type(i)));
60 }
61 return errors::InvalidArgument(
62 "Value for attr '", attr.name(), "' of ", DataTypeString(dt),
63 " is not in the list of allowed values: ", allowed_str);
64 }
65
AllowedStringValue(const string & str,const OpDef::AttrDef & attr)66 Status AllowedStringValue(const string& str, const OpDef::AttrDef& attr) {
67 const AttrValue& allowed_values(attr.allowed_values());
68 for (const auto& allowed : allowed_values.list().s()) {
69 if (str == allowed) {
70 return Status::OK();
71 }
72 }
73 string allowed_str;
74 for (const string& allowed : allowed_values.list().s()) {
75 if (!allowed_str.empty()) {
76 strings::StrAppend(&allowed_str, ", ");
77 }
78 strings::StrAppend(&allowed_str, "\"", allowed, "\"");
79 }
80 return errors::InvalidArgument(
81 "Value for attr '", attr.name(), "' of \"", str,
82 "\" is not in the list of allowed values: ", allowed_str);
83 }
84
85 } // namespace
86
87 // Requires: attr has already been validated.
ValidateAttrValue(const AttrValue & attr_value,const OpDef::AttrDef & attr)88 Status ValidateAttrValue(const AttrValue& attr_value,
89 const OpDef::AttrDef& attr) {
90 // Is it a valid value?
91 TF_RETURN_WITH_CONTEXT_IF_ERROR(AttrValueHasType(attr_value, attr.type()),
92 " for attr '", attr.name(), "'");
93
94 // Does the value satisfy the minimum constraint in the AttrDef?
95 if (attr.has_minimum()) {
96 if (attr.type() == "int") {
97 if (attr_value.i() < attr.minimum()) {
98 return errors::InvalidArgument(
99 "Value for attr '", attr.name(), "' of ", attr_value.i(),
100 " must be at least minimum ", attr.minimum());
101 }
102 } else {
103 int length = -1;
104 if (attr.type() == "list(string)") {
105 length = attr_value.list().s_size();
106 } else if (attr.type() == "list(int)") {
107 length = attr_value.list().i_size();
108 } else if (attr.type() == "list(float)") {
109 length = attr_value.list().f_size();
110 } else if (attr.type() == "list(bool)") {
111 length = attr_value.list().b_size();
112 } else if (attr.type() == "list(type)") {
113 length = attr_value.list().type_size();
114 } else if (attr.type() == "list(shape)") {
115 length = attr_value.list().shape_size();
116 } else if (attr.type() == "list(tensor)") {
117 length = attr_value.list().tensor_size();
118 } else if (attr.type() == "list(func)") {
119 length = attr_value.list().func_size();
120 }
121 if (length < attr.minimum()) {
122 return errors::InvalidArgument(
123 "Length for attr '", attr.name(), "' of ", length,
124 " must be at least minimum ", attr.minimum());
125 }
126 }
127 }
128
129 // Does the value satisfy the allowed_value constraint in the AttrDef?
130 if (attr.has_allowed_values()) {
131 if (attr.type() == "type") {
132 TF_RETURN_IF_ERROR(AllowedTypeValue(attr_value.type(), attr));
133 } else if (attr.type() == "list(type)") {
134 for (int dt : attr_value.list().type()) {
135 TF_RETURN_IF_ERROR(AllowedTypeValue(static_cast<DataType>(dt), attr));
136 }
137 } else if (attr.type() == "string") {
138 TF_RETURN_IF_ERROR(AllowedStringValue(attr_value.s(), attr));
139 } else if (attr.type() == "list(string)") {
140 for (const string& str : attr_value.list().s()) {
141 TF_RETURN_IF_ERROR(AllowedStringValue(str, attr));
142 }
143 } else {
144 return errors::Unimplemented(
145 "Support for allowed_values not implemented for type ", attr.type());
146 }
147 }
148 return Status::OK();
149 }
150
FindAttr(StringPiece name,const OpDef & op_def)151 const OpDef::AttrDef* FindAttr(StringPiece name, const OpDef& op_def) {
152 for (int i = 0; i < op_def.attr_size(); ++i) {
153 if (op_def.attr(i).name() == name) {
154 return &op_def.attr(i);
155 }
156 }
157 return nullptr;
158 }
159
FindAttrMutable(StringPiece name,OpDef * op_def)160 OpDef::AttrDef* FindAttrMutable(StringPiece name, OpDef* op_def) {
161 for (int i = 0; i < op_def->attr_size(); ++i) {
162 if (op_def->attr(i).name() == name) {
163 return op_def->mutable_attr(i);
164 }
165 }
166 return nullptr;
167 }
168
FindInputArg(StringPiece name,const OpDef & op_def)169 const OpDef::ArgDef* FindInputArg(StringPiece name, const OpDef& op_def) {
170 for (int i = 0; i < op_def.input_arg_size(); ++i) {
171 if (op_def.input_arg(i).name() == name) {
172 return &op_def.input_arg(i);
173 }
174 }
175 return nullptr;
176 }
177
FindInputArg(StringPiece name,const ApiDef & api_def)178 const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) {
179 for (int i = 0; i < api_def.in_arg_size(); ++i) {
180 if (api_def.in_arg(i).name() == name) {
181 return &api_def.in_arg(i);
182 }
183 }
184 return nullptr;
185 }
186
// If EXPR is false, returns an InvalidArgument error built from the remaining
// arguments followed by a dump of the OpDef. Expects a variable named
// `op_def` to be in scope wherever the macro is expanded.
#define VALIDATE(EXPR, ...)                                        \
  do {                                                             \
    const bool validate_ok_ = static_cast<bool>(EXPR);             \
    if (!validate_ok_) {                                           \
      return errors::InvalidArgument(__VA_ARGS__, "; in OpDef: ",  \
                                     op_def.ShortDebugString());   \
    }                                                              \
  } while (false)
194
ValidateArg(const OpDef::ArgDef & arg,const OpDef & op_def,bool output,std::set<string> * names)195 static Status ValidateArg(const OpDef::ArgDef& arg, const OpDef& op_def,
196 bool output, std::set<string>* names) {
197 const string suffix = strings::StrCat(
198 output ? " for output '" : " for input '", arg.name(), "'");
199 VALIDATE(gtl::InsertIfNotPresent(names, arg.name()),
200 "Duplicate name: ", arg.name());
201 VALIDATE(HasAttrStyleType(arg), "Missing type", suffix);
202
203 if (!arg.number_attr().empty()) {
204 const OpDef::AttrDef* attr = FindAttr(arg.number_attr(), op_def);
205 VALIDATE(attr != nullptr, "No attr with name '", arg.number_attr(), "'",
206 suffix);
207 VALIDATE(attr->type() == "int", "Attr '", attr->name(), "' used as length",
208 suffix, " has type ", attr->type(), " != int");
209 VALIDATE(attr->has_minimum(), "Attr '", attr->name(), "' used as length",
210 suffix, " must have minimum");
211 VALIDATE(attr->minimum() >= 0, "Attr '", attr->name(), "' used as length",
212 suffix, " must have minimum >= 0");
213 VALIDATE(arg.type_list_attr().empty(),
214 "Can't have both number_attr and type_list_attr", suffix);
215 VALIDATE((arg.type() != DT_INVALID ? 1 : 0) +
216 (!arg.type_attr().empty() ? 1 : 0) ==
217 1,
218 "Exactly one of type, type_attr must be set", suffix);
219 } else {
220 const int num_type_fields = (arg.type() != DT_INVALID ? 1 : 0) +
221 (!arg.type_attr().empty() ? 1 : 0) +
222 (!arg.type_list_attr().empty() ? 1 : 0);
223 VALIDATE(num_type_fields == 1,
224 "Exactly one of type, type_attr, type_list_attr must be set",
225 suffix);
226 }
227
228 if (!arg.type_attr().empty()) {
229 const OpDef::AttrDef* attr = FindAttr(arg.type_attr(), op_def);
230 VALIDATE(attr != nullptr, "No attr with name '", arg.type_attr(), "'",
231 suffix);
232 VALIDATE(attr->type() == "type", "Attr '", attr->name(),
233 "' used as type_attr", suffix, " has type ", attr->type(),
234 " != type");
235 } else if (!arg.type_list_attr().empty()) {
236 const OpDef::AttrDef* attr = FindAttr(arg.type_list_attr(), op_def);
237 VALIDATE(attr != nullptr, "No attr with name '", arg.type_list_attr(), "'",
238 suffix);
239 VALIDATE(attr->type() == "list(type)", "Attr '", attr->name(),
240 "' used as type_list_attr", suffix, " has type ", attr->type(),
241 " != list(type)");
242 } else {
243 // All argument types should be non-reference types at this point.
244 // ArgDef.is_ref is set to true for reference arguments.
245 VALIDATE(!IsRefType(arg.type()), "Illegal use of ref type '",
246 DataTypeString(arg.type()), "'. Use 'Ref(type)' instead", suffix);
247 }
248
249 return Status::OK();
250 }
251
IsValidOpName(StringPiece sp)252 bool IsValidOpName(StringPiece sp) {
253 using ::tensorflow::strings::Scanner;
254
255 Scanner scanner(sp);
256 scanner.One(Scanner::UPPERLETTER).Any(Scanner::LETTER_DIGIT_UNDERSCORE);
257
258 while (true) {
259 if (!scanner.GetResult()) // Some error in previous iteration.
260 return false;
261 if (scanner.empty()) // No error, but nothing left, good.
262 return true;
263
264 // Absorb another name/namespace, starting with a '>'
265 scanner.One(Scanner::RANGLE)
266 .One(Scanner::UPPERLETTER)
267 .Any(Scanner::LETTER_DIGIT_UNDERSCORE);
268 }
269 }
270
ValidateOpDef(const OpDef & op_def)271 Status ValidateOpDef(const OpDef& op_def) {
272 if (!absl::StartsWith(op_def.name(), "_")) {
273 VALIDATE(IsValidOpName(op_def.name()), "Invalid name: ", op_def.name(),
274 " (Did you use CamelCase?)");
275 }
276
277 std::set<string> names; // for detecting duplicate names
278 for (const auto& attr : op_def.attr()) {
279 // Validate name
280 VALIDATE(gtl::InsertIfNotPresent(&names, attr.name()),
281 "Duplicate name: ", attr.name());
282 DataType dt;
283 VALIDATE(!DataTypeFromString(attr.name(), &dt), "Attr can't have name ",
284 attr.name(), " that matches a data type");
285
286 // Validate type
287 StringPiece type(attr.type());
288 bool is_list = absl::ConsumePrefix(&type, "list(");
289 bool found = false;
290 for (StringPiece valid : {"string", "int", "float", "bool", "type", "shape",
291 "tensor", "func"}) {
292 if (absl::ConsumePrefix(&type, valid)) {
293 found = true;
294 break;
295 }
296 }
297 VALIDATE(found, "Unrecognized type '", type, "' in attr '", attr.name(),
298 "'");
299 if (is_list) {
300 VALIDATE(absl::ConsumePrefix(&type, ")"),
301 "'list(' is missing ')' in attr ", attr.name(), "'s type ",
302 attr.type());
303 }
304 VALIDATE(type.empty(), "Extra '", type, "' at the end of attr ",
305 attr.name(), "'s type ", attr.type());
306
307 // Validate minimum
308 if (attr.has_minimum()) {
309 VALIDATE(attr.type() == "int" || is_list, "Attr '", attr.name(),
310 "' has minimum for unsupported type ", attr.type());
311 if (is_list) {
312 VALIDATE(attr.minimum() >= 0, "Attr '", attr.name(),
313 "' with list type must have a non-negative minimum, not ",
314 attr.minimum());
315 }
316 } else {
317 VALIDATE(attr.minimum() == 0, "Attr '", attr.name(),
318 "' with has_minimum = false but minimum ", attr.minimum(),
319 " not equal to default of 0");
320 }
321
322 // Validate allowed_values
323 if (attr.has_allowed_values()) {
324 const string list_type =
325 is_list ? attr.type() : strings::StrCat("list(", attr.type(), ")");
326 TF_RETURN_WITH_CONTEXT_IF_ERROR(
327 AttrValueHasType(attr.allowed_values(), list_type), " for attr '",
328 attr.name(), "' in Op '", op_def.name(), "'");
329 }
330
331 // Validate default_value (after we have validated the rest of the attr,
332 // so we can use ValidateAttrValue()).
333 if (attr.has_default_value()) {
334 TF_RETURN_WITH_CONTEXT_IF_ERROR(
335 ValidateAttrValue(attr.default_value(), attr), " in Op '",
336 op_def.name(), "'");
337 }
338 }
339
340 for (const auto& arg : op_def.input_arg()) {
341 TF_RETURN_IF_ERROR(ValidateArg(arg, op_def, false, &names));
342 }
343
344 for (const auto& arg : op_def.output_arg()) {
345 TF_RETURN_IF_ERROR(ValidateArg(arg, op_def, true, &names));
346 }
347
348 return Status::OK();
349 }
350
351 #undef VALIDATE
352
CheckOpDeprecation(const OpDef & op_def,int graph_def_version)353 Status CheckOpDeprecation(const OpDef& op_def, int graph_def_version) {
354 if (op_def.has_deprecation()) {
355 const OpDeprecation& dep = op_def.deprecation();
356 if (graph_def_version >= dep.version()) {
357 return errors::Unimplemented(
358 "Op ", op_def.name(), " is not available in GraphDef version ",
359 graph_def_version, ". It has been removed in version ", dep.version(),
360 ". ", dep.explanation(), ".");
361 } else {
362 // Warn only once for each op name, and do it in a threadsafe manner.
363 static mutex mu(LINKER_INITIALIZED);
364 static std::unordered_set<string> warned;
365 bool warn;
366 {
367 mutex_lock lock(mu);
368 warn = warned.insert(op_def.name()).second;
369 }
370 if (warn) {
371 LOG(WARNING) << "Op " << op_def.name() << " is deprecated."
372 << " It will cease to work in GraphDef version "
373 << dep.version() << ". " << dep.explanation() << ".";
374 }
375 }
376 }
377 return Status::OK();
378 }
379
380 namespace {
381
SummarizeArgs(const protobuf::RepeatedPtrField<OpDef::ArgDef> & args)382 string SummarizeArgs(const protobuf::RepeatedPtrField<OpDef::ArgDef>& args) {
383 string ret;
384 for (const OpDef::ArgDef& arg : args) {
385 if (!ret.empty()) strings::StrAppend(&ret, ", ");
386 strings::StrAppend(&ret, arg.name(), ":");
387 if (arg.is_ref()) strings::StrAppend(&ret, "Ref(");
388 if (!arg.number_attr().empty()) {
389 strings::StrAppend(&ret, arg.number_attr(), "*");
390 }
391 if (arg.type() != DT_INVALID) {
392 strings::StrAppend(&ret, DataTypeString(arg.type()));
393 } else {
394 strings::StrAppend(&ret, arg.type_attr());
395 }
396 if (arg.is_ref()) strings::StrAppend(&ret, ")");
397 }
398 return ret;
399 }
400
401 } // namespace
402
SummarizeOpDef(const OpDef & op_def)403 string SummarizeOpDef(const OpDef& op_def) {
404 string ret = strings::StrCat("Op<name=", op_def.name());
405 strings::StrAppend(&ret, "; signature=", SummarizeArgs(op_def.input_arg()),
406 " -> ", SummarizeArgs(op_def.output_arg()));
407 for (int i = 0; i < op_def.attr_size(); ++i) {
408 strings::StrAppend(&ret, "; attr=", op_def.attr(i).name(), ":",
409 op_def.attr(i).type());
410 if (op_def.attr(i).has_default_value()) {
411 strings::StrAppend(&ret, ",default=",
412 SummarizeAttrValue(op_def.attr(i).default_value()));
413 }
414 if (op_def.attr(i).has_minimum()) {
415 strings::StrAppend(&ret, ",min=", op_def.attr(i).minimum());
416 }
417 if (op_def.attr(i).has_allowed_values()) {
418 strings::StrAppend(&ret, ",allowed=",
419 SummarizeAttrValue(op_def.attr(i).allowed_values()));
420 }
421 }
422 if (op_def.is_commutative()) {
423 strings::StrAppend(&ret, "; is_commutative=true");
424 }
425 if (op_def.is_aggregate()) {
426 strings::StrAppend(&ret, "; is_aggregate=true");
427 }
428 if (op_def.is_stateful()) {
429 strings::StrAppend(&ret, "; is_stateful=true");
430 }
431 if (op_def.allows_uninitialized_input()) {
432 strings::StrAppend(&ret, "; allows_uninitialized_input=true");
433 }
434 strings::StrAppend(&ret, ">");
435 return ret;
436 }
437
438 namespace {
439
// Returns true if every element of `sub` is contained in `super`, comparing
// elements with ==. Multiplicity is ignored. Runs in O(|sub| * |super|).
template <class T>
bool IsSubsetOf(const T& sub, const T& super) {
  for (const auto& needle : sub) {
    auto it = super.begin();
    const auto end = super.end();
    while (it != end && !(needle == *it)) ++it;
    if (it == end) return false;  // `needle` is missing from `super`.
  }
  return true;
}
455
MoreRestrictive(const OpDef::AttrDef & old_attr,const OpDef::AttrDef & new_attr)456 bool MoreRestrictive(const OpDef::AttrDef& old_attr,
457 const OpDef::AttrDef& new_attr) {
458 // Anything -> no restriction : not more restrictive.
459 if (!new_attr.has_allowed_values()) return false;
460 // No restriction -> restriction : more restrictive.
461 if (!old_attr.has_allowed_values()) return true;
462 // If anything that was previously allowed is no longer allowed:
463 // more restrictive.
464 if (!IsSubsetOf(old_attr.allowed_values().list().type(),
465 new_attr.allowed_values().list().type())) {
466 return true;
467 }
468 if (!IsSubsetOf(old_attr.allowed_values().list().s(),
469 new_attr.allowed_values().list().s())) {
470 return true;
471 }
472 return false;
473 }
474
AllowedStr(const OpDef::AttrDef & attr)475 string AllowedStr(const OpDef::AttrDef& attr) {
476 if (!attr.has_allowed_values()) return "no restriction";
477 return SummarizeAttrValue(attr.allowed_values());
478 }
479
DefaultAttrStr(const OpDef::AttrDef & attr)480 string DefaultAttrStr(const OpDef::AttrDef& attr) {
481 if (!attr.has_default_value()) return "no default";
482 return SummarizeAttrValue(attr.default_value());
483 }
484
HigherMinimum(const OpDef::AttrDef & old_attr,const OpDef::AttrDef & new_attr)485 bool HigherMinimum(const OpDef::AttrDef& old_attr,
486 const OpDef::AttrDef& new_attr) {
487 // Anything -> no restriction : not more restrictive.
488 if (!new_attr.has_minimum()) return false;
489 // No restriction -> restriction : more restrictive.
490 if (!old_attr.has_minimum()) return true;
491 // If anything that was previously allowed is no longer allowed:
492 // more restrictive.
493 return new_attr.minimum() > old_attr.minimum();
494 }
495
MinStr(const OpDef::AttrDef & attr)496 string MinStr(const OpDef::AttrDef& attr) {
497 if (!attr.has_minimum()) return "no minimum";
498 return strings::StrCat(attr.minimum());
499 }
500
501 typedef std::unordered_map<string, const OpDef::AttrDef*> AttrMap;
FillAttrMap(const OpDef & op_def,AttrMap * attr_map)502 void FillAttrMap(const OpDef& op_def, AttrMap* attr_map) {
503 for (const auto& attr : op_def.attr()) {
504 (*attr_map)[attr.name()] = &attr;
505 }
506 }
507
508 // Add a comma to *s every call but the first (*add_comma should be
509 // initialized to false).
AddComma(string * s,bool * add_comma)510 void AddComma(string* s, bool* add_comma) {
511 if (*add_comma) {
512 strings::StrAppend(s, ", ");
513 } else {
514 *add_comma = true;
515 }
516 }
517
518 // Will add the `name` from arg if name is true.
AddName(string * s,bool name,const OpDef::ArgDef & arg)519 void AddName(string* s, bool name, const OpDef::ArgDef& arg) {
520 if (name) {
521 strings::StrAppend(s, arg.name(), ":");
522 }
523 }
524
525 // Compute a signature for either inputs or outputs that will be the
526 // same for both the old and new OpDef if they are compatible. We
527 // assume that new_attrs is a superset of old_attrs, and that any attr
528 // in the difference has a default. Our strategy is to make a list of
529 // types, where the types are things like:
530 // * "int32", "float", etc.,
531 // * "T" for some attr "T" in old_attrs, or
532 // * "N * type" for "N" either some attr in old_attrs.
533 //
534 // We get the types by either using the attrs in args if they are in
535 // old_attrs, or substituting the default value from new_attrs.
ComputeArgSignature(const protobuf::RepeatedPtrField<OpDef::ArgDef> & args,const AttrMap & old_attrs,const AttrMap & new_attrs,std::vector<bool> * ref,bool names)536 string ComputeArgSignature(
537 const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
538 const AttrMap& old_attrs, const AttrMap& new_attrs, std::vector<bool>* ref,
539 bool names) {
540 string s;
541 bool add_comma = false;
542 for (const OpDef::ArgDef& arg : args) {
543 if (!arg.type_list_attr().empty()) {
544 const OpDef::AttrDef* old_attr =
545 gtl::FindPtrOrNull(old_attrs, arg.type_list_attr());
546 if (old_attr) {
547 // Both old and new have the list(type) attr, so can use it directly.
548 AddComma(&s, &add_comma);
549 AddName(&s, names, arg);
550 strings::StrAppend(&s, arg.type_list_attr());
551 ref->push_back(arg.is_ref());
552 } else {
553 // Missing the list(type) attr in the old, so use the default
554 // value for the attr from new instead.
555 const OpDef::AttrDef* new_attr =
556 gtl::FindPtrOrNull(new_attrs, arg.type_list_attr());
557 const auto& type_list = new_attr->default_value().list().type();
558 if (type_list.empty()) continue;
559 for (int i = 0; i < type_list.size(); ++i) {
560 AddComma(&s, &add_comma);
561 AddName(&s, names, arg);
562 strings::StrAppend(
563 &s, DataTypeString(static_cast<DataType>(type_list.Get(i))));
564 ref->push_back(arg.is_ref());
565 }
566 }
567 } else {
568 int num = 1; // How many input/outputs does this represent?
569 string type; // What is the type of this arg?
570 AddName(&type, names, arg);
571 if (!arg.number_attr().empty()) {
572 // N * type case.
573 const OpDef::AttrDef* old_attr =
574 gtl::FindPtrOrNull(old_attrs, arg.number_attr());
575 if (old_attr) {
576 // Both old and new have the number attr, so can use it directly.
577 strings::StrAppend(&type, arg.number_attr(), " * ");
578 } else {
579 // Missing the number attr in the old, so use the default
580 // value for the attr from new instead.
581 const OpDef::AttrDef* new_attr =
582 gtl::FindPtrOrNull(new_attrs, arg.number_attr());
583 num = new_attr->default_value().i();
584 }
585 }
586
587 if (arg.type() != DT_INVALID) {
588 // int32, float, etc. case
589 strings::StrAppend(&type, DataTypeString(arg.type()));
590 } else {
591 const OpDef::AttrDef* old_attr =
592 gtl::FindPtrOrNull(old_attrs, arg.type_attr());
593 if (old_attr) {
594 // Both old and new have the type attr, so can use it directly.
595 strings::StrAppend(&type, arg.type_attr());
596 } else {
597 // Missing the type attr in the old, so use the default
598 // value for the attr from new instead.
599 const OpDef::AttrDef* new_attr =
600 gtl::FindPtrOrNull(new_attrs, arg.type_attr());
601 strings::StrAppend(&type,
602 DataTypeString(new_attr->default_value().type()));
603 }
604 }
605
606 // Record `num` * `type` in the signature.
607 for (int i = 0; i < num; ++i) {
608 AddComma(&s, &add_comma);
609 strings::StrAppend(&s, type);
610 ref->push_back(arg.is_ref());
611 }
612 }
613 }
614
615 return s;
616 }
617
618 } // namespace
619
OpDefCompatible(const OpDef & old_op,const OpDef & new_op)620 Status OpDefCompatible(const OpDef& old_op, const OpDef& new_op) {
621 #define VALIDATE(CONDITION, ...) \
622 if (!(CONDITION)) { \
623 return errors::InvalidArgument("Incompatible Op change: ", __VA_ARGS__, \
624 "; old: ", SummarizeOpDef(old_op), \
625 "; new: ", SummarizeOpDef(new_op)); \
626 }
627
628 VALIDATE(old_op.name() == new_op.name(), "Name mismatch");
629
630 AttrMap new_attrs, old_attrs;
631 FillAttrMap(old_op, &old_attrs);
632 FillAttrMap(new_op, &new_attrs);
633 for (const auto& old_attr : old_op.attr()) {
634 const OpDef::AttrDef* new_attr =
635 gtl::FindPtrOrNull(new_attrs, old_attr.name());
636 VALIDATE(new_attr != nullptr, "Attr '", old_attr.name(), "' removed");
637 VALIDATE(old_attr.type() == new_attr->type(), "Attr '", old_attr.name(),
638 "' changed type '", old_attr.type(), "' -> '", new_attr->type(),
639 "'");
640 VALIDATE(!MoreRestrictive(old_attr, *new_attr), "Attr '", old_attr.name(),
641 "' has a stricter set of allowed values; from ",
642 AllowedStr(old_attr), " to ", AllowedStr(*new_attr));
643 VALIDATE(!HigherMinimum(old_attr, *new_attr), "Attr '", old_attr.name(),
644 "' has a higher minimum; from ", MinStr(old_attr), " to ",
645 MinStr(*new_attr));
646 }
647
648 for (const auto& new_attr : new_op.attr()) {
649 const OpDef::AttrDef* old_attr =
650 gtl::FindPtrOrNull(old_attrs, new_attr.name());
651 VALIDATE(old_attr != nullptr || new_attr.has_default_value(), "Attr '",
652 new_attr.name(), "' added without default");
653 }
654
655 std::vector<bool> old_in_ref, new_in_ref, old_out_ref, new_out_ref;
656 const string old_in_sig = ComputeArgSignature(
657 old_op.input_arg(), old_attrs, new_attrs, &old_in_ref, false /* names */);
658 const string new_in_sig = ComputeArgSignature(
659 new_op.input_arg(), old_attrs, new_attrs, &new_in_ref, false /* names */);
660 VALIDATE(old_in_sig == new_in_sig, "Input signature mismatch '", old_in_sig,
661 "' vs. '", new_in_sig, "'");
662 VALIDATE(old_in_ref.size() == new_in_ref.size(), // Should not happen
663 "Unexpected change in input ref lists.");
664 for (int i = 0, end = old_in_ref.size(); i < end; ++i) {
665 // Allowed to remove "ref" from an input (or leave it unchanged).
666 VALIDATE(old_in_ref[i] || !new_in_ref[i], "Input ", i,
667 " changed from non-ref to ref");
668 }
669
670 const string old_out_sig =
671 ComputeArgSignature(old_op.output_arg(), old_attrs, new_attrs,
672 &old_out_ref, true /* names */);
673 const string new_out_sig =
674 ComputeArgSignature(new_op.output_arg(), old_attrs, new_attrs,
675 &new_out_ref, true /* names */);
676 VALIDATE(old_out_sig == new_out_sig, "Output signature mismatch '",
677 old_out_sig, "' vs. '", new_out_sig, "'");
678 VALIDATE(old_out_ref.size() == new_out_ref.size(), // Should not happen
679 "Unexpected change in output ref lists");
680 for (int i = 0, end = old_out_ref.size(); i < end; ++i) {
681 // Allowed to add "ref" to an output (or leave it unchanged).
682 VALIDATE(!old_out_ref[i] || new_out_ref[i], "Output ", i,
683 " changed from ref to non-ref");
684 }
685
686 return Status::OK();
687 }
688
OpDefAddedDefaultsUnchanged(const OpDef & old_op,const OpDef & penultimate_op,const OpDef & new_op)689 Status OpDefAddedDefaultsUnchanged(const OpDef& old_op,
690 const OpDef& penultimate_op,
691 const OpDef& new_op) {
692 AttrMap new_attrs, old_attrs;
693 FillAttrMap(old_op, &old_attrs);
694 FillAttrMap(new_op, &new_attrs);
695
696 for (const auto& penultimate_attr : penultimate_op.attr()) {
697 const OpDef::AttrDef* old_attr =
698 gtl::FindPtrOrNull(old_attrs, penultimate_attr.name());
699 if (old_attr != nullptr) continue; // attr wasn't added
700 const OpDef::AttrDef* new_attr =
701 gtl::FindPtrOrNull(new_attrs, penultimate_attr.name());
702
703 // These shouldn't happen if the op passed OpDefCompatible().
704 if (new_attr == nullptr) {
705 return errors::InvalidArgument("Missing attr '", penultimate_attr.name(),
706 "' in op: ", SummarizeOpDef(new_op));
707 }
708 if (!penultimate_attr.has_default_value() ||
709 !new_attr->has_default_value()) {
710 return errors::InvalidArgument("Missing default for attr '",
711 penultimate_attr.name(),
712 "' in op: ", SummarizeOpDef(new_op));
713 }
714
715 // Actually test that the attr's default value hasn't changed.
716 if (!AreAttrValuesEqual(penultimate_attr.default_value(),
717 new_attr->default_value())) {
718 return errors::InvalidArgument(
719 "Can't change default value for attr '", penultimate_attr.name(),
720 "' from ", SummarizeAttrValue(penultimate_attr.default_value()),
721 " in op: ", SummarizeOpDef(new_op));
722 }
723 }
724
725 return Status::OK();
726 }
727
OpDefAttrDefaultsUnchanged(const OpDef & old_op,const OpDef & new_op)728 Status OpDefAttrDefaultsUnchanged(const OpDef& old_op, const OpDef& new_op) {
729 AttrMap new_attrs, old_attrs;
730 FillAttrMap(old_op, &old_attrs);
731 FillAttrMap(new_op, &new_attrs);
732
733 for (const auto& old_attr : old_op.attr()) {
734 const OpDef::AttrDef* new_attr =
735 gtl::FindPtrOrNull(new_attrs, old_attr.name());
736 if (new_attr == nullptr) continue;
737 if (new_attr->has_default_value() && !old_attr.has_default_value()) {
738 continue; // Adding new default values is safe.
739 }
740 if (old_attr.has_default_value() && !new_attr->has_default_value()) {
741 return errors::InvalidArgument(
742 "Attr '", old_attr.name(), "' has removed it's default; ", "from ",
743 DefaultAttrStr(old_attr), " to ", DefaultAttrStr(*new_attr));
744 }
745 if (old_attr.has_default_value() &&
746 !AreAttrValuesEqual(old_attr.default_value(),
747 new_attr->default_value())) {
748 return errors::InvalidArgument(
749 "Attr '", old_attr.name(), "' has changed it's default value; ",
750 "from ", DefaultAttrStr(old_attr), " to ", DefaultAttrStr(*new_attr));
751 }
752 }
753
754 return Status::OK();
755 }
756
RemoveNonDeprecationDescriptionsFromOpDef(OpDef * op_def)757 void RemoveNonDeprecationDescriptionsFromOpDef(OpDef* op_def) {
758 for (int i = 0; i < op_def->input_arg_size(); ++i) {
759 op_def->mutable_input_arg(i)->clear_description();
760 }
761 for (int i = 0; i < op_def->output_arg_size(); ++i) {
762 op_def->mutable_output_arg(i)->clear_description();
763 }
764 for (int i = 0; i < op_def->attr_size(); ++i) {
765 op_def->mutable_attr(i)->clear_description();
766 }
767 op_def->clear_summary();
768 op_def->clear_description();
769 }
770
RemoveDescriptionsFromOpDef(OpDef * op_def)771 void RemoveDescriptionsFromOpDef(OpDef* op_def) {
772 RemoveNonDeprecationDescriptionsFromOpDef(op_def);
773 if (op_def->has_deprecation()) {
774 op_def->mutable_deprecation()->clear_explanation();
775 }
776 }
777
RemoveDescriptionsFromOpList(OpList * op_list)778 void RemoveDescriptionsFromOpList(OpList* op_list) {
779 for (int i = 0; i < op_list->op_size(); ++i) {
780 OpDef* op_def = op_list->mutable_op(i);
781 RemoveDescriptionsFromOpDef(op_def);
782 }
783 }
784
AttrDefEqual(const OpDef::AttrDef & a1,const OpDef::AttrDef & a2)785 bool AttrDefEqual(const OpDef::AttrDef& a1, const OpDef::AttrDef& a2) {
786 if (std::is_base_of<protobuf::Message, OpDef::AttrDef>()) {
787 DCHECK_EQ(7, reinterpret_cast<const protobuf::Message*>(&a1)
788 ->GetDescriptor()
789 ->field_count())
790 << "Please modify these equality and hash functions to reflect the "
791 "changes to the AttrDef protobuf";
792 }
793
794 if (a1.name() != a2.name()) return false;
795 if (a1.type() != a2.type()) return false;
796 if (a1.description() != a2.description()) return false;
797 if (a1.has_minimum() != a2.has_minimum()) return false;
798 if (a1.has_minimum() && a1.minimum() != a2.minimum()) return false;
799 if (!AreAttrValuesEqual(a1.default_value(), a2.default_value())) return false;
800 if (!AreAttrValuesEqual(a1.allowed_values(), a2.allowed_values()))
801 return false;
802 return true;
803 }
804
AttrDefHash(const OpDef::AttrDef & a)805 uint64 AttrDefHash(const OpDef::AttrDef& a) {
806 uint64 h = Hash64(a.name());
807 h = Hash64(a.type().data(), a.type().size(), h);
808 h = Hash64Combine(AttrValueHash(a.default_value()), h);
809 h = Hash64(a.description().data(), a.description().size(), h);
810 h = Hash64Combine(static_cast<uint64>(a.has_minimum()), h);
811 h = Hash64Combine(static_cast<uint64>(a.minimum()), h);
812 h = Hash64Combine(AttrValueHash(a.allowed_values()), h);
813 return h;
814 }
815
RepeatedAttrDefEqual(const protobuf::RepeatedPtrField<OpDef::AttrDef> & a1,const protobuf::RepeatedPtrField<OpDef::AttrDef> & a2)816 bool RepeatedAttrDefEqual(
817 const protobuf::RepeatedPtrField<OpDef::AttrDef>& a1,
818 const protobuf::RepeatedPtrField<OpDef::AttrDef>& a2) {
819 std::unordered_map<string, const OpDef::AttrDef*> a1_set;
820 for (const OpDef::AttrDef& def : a1) {
821 DCHECK(a1_set.find(def.name()) == a1_set.end())
822 << "AttrDef names must be unique, but '" << def.name()
823 << "' appears more than once";
824 a1_set[def.name()] = &def;
825 }
826 for (const OpDef::AttrDef& def : a2) {
827 auto iter = a1_set.find(def.name());
828 if (iter == a1_set.end()) return false;
829 if (!AttrDefEqual(*iter->second, def)) return false;
830 a1_set.erase(iter);
831 }
832 if (!a1_set.empty()) return false;
833 return true;
834 }
835
RepeatedAttrDefHash(const protobuf::RepeatedPtrField<OpDef::AttrDef> & a)836 uint64 RepeatedAttrDefHash(
837 const protobuf::RepeatedPtrField<OpDef::AttrDef>& a) {
838 // Insert AttrDefs into map to deterministically sort by name
839 std::map<string, const OpDef::AttrDef*> a_set;
840 for (const OpDef::AttrDef& def : a) {
841 a_set[def.name()] = &def;
842 }
843 // Iterate and combines hashes of keys and values
844 uint64 h = 0xDECAFCAFFE;
845 for (const auto& pair : a_set) {
846 h = Hash64(pair.first.data(), pair.first.size(), h);
847 h = Hash64Combine(AttrDefHash(*pair.second), h);
848 }
849 return h;
850 }
851
OpDefEqual(const OpDef & o1,const OpDef & o2)852 bool OpDefEqual(const OpDef& o1, const OpDef& o2) {
853 // attr order doesn't matter.
854 // Compare it separately here instead of serializing below.
855 if (!RepeatedAttrDefEqual(o1.attr(), o2.attr())) return false;
856
857 // `control_output` order doesn't matter.
858 std::set<string> control_output1(o1.control_output().begin(),
859 o1.control_output().end());
860 std::set<string> control_output2(o2.control_output().begin(),
861 o2.control_output().end());
862 if (control_output1 != control_output2) return false;
863
864 // Clear `attr` and `control_output` fields, serialize, and compare serialized
865 // strings.
866 OpDef o1_copy = o1;
867 OpDef o2_copy = o2;
868 o1_copy.clear_attr();
869 o1_copy.clear_control_output();
870 o2_copy.clear_attr();
871 o2_copy.clear_control_output();
872
873 return AreSerializedProtosEqual(o1_copy, o2_copy);
874 }
875
OpDefHash(const OpDef & o)876 uint64 OpDefHash(const OpDef& o) {
877 uint64 h = RepeatedAttrDefHash(o.attr());
878
879 // Compute deterministic order-independent control outputs hash.
880 std::set<string> control_output(o.control_output().begin(),
881 o.control_output().end());
882 for (const auto& co : control_output) h = Hash64Combine(h, Hash64(co));
883
884 OpDef o_copy = o;
885 o_copy.clear_attr();
886 o_copy.clear_control_output();
887 return DeterministicProtoHash64(o_copy, h);
888 }
889
890 } // namespace tensorflow
891