• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/core/framework/attr_value_util.h"

#include <algorithm>
#include <limits>
#include <map>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/strings/escaping.h"
#include "tensorflow/core/framework/attr_value.pb_text.h"
#include "tensorflow/core/framework/tensor.pb_text.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb_text.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/protobuf.h"
35 
36 namespace tensorflow {
37 namespace {
38 
// Largest decoded tensor size, in bytes, that hashing and equality checks are
// willing to materialize in memory. Bigger tensors are handled through their
// TensorProto encoding instead.
constexpr int kMaxAttrValueTensorByteSize = 32 << 20;  // 32 MiB

// ParseAttrValue rejects tensor text whose '{' / '<' nesting reaches this
// depth before handing it to the text-format parser.
constexpr int kMaxTensorNestDepth = 100;
44 
45 // Return the size of the tensor represented by this TensorProto. If shape is
46 // not fully defined return -1.
TensorByteSize(const TensorProto & t)47 int64_t TensorByteSize(const TensorProto& t) {
48   // num_elements returns -1 if shape is not fully defined.
49   int64_t num_elems = PartialTensorShape(t.tensor_shape()).num_elements();
50   return num_elems < 0 ? -1 : num_elems * DataTypeSize(t.dtype());
51 }
52 
53 // Compute TensorProto hash by creating a Tensor, serializing it as tensor
54 // content, and computing a hash of it's string representation. This is unsafe
55 // operation, because large tensors can be represented as TensorProto, but can't
56 // be serialized to tensor content.
TensorProtoHash(const TensorProto & tp)57 uint64 TensorProtoHash(const TensorProto& tp) {
58   Tensor tensor(tp.dtype());
59   bool success = tensor.FromProto(tp);
60   DCHECK(success);
61   TensorProto p;
62   tensor.AsProtoTensorContent(&p);
63   return DeterministicProtoHash64(p);
64 }
65 
66 // Do not create large tensors in memory, compute hash based on TensorProto
67 // string representation. Tensors with identical content potentially can have a
68 // different hash code if they are defined with different TensorProto
69 // representations.
FastTensorProtoHash(const TensorProto & tp)70 uint64 FastTensorProtoHash(const TensorProto& tp) {
71   if (TensorByteSize(tp) > kMaxAttrValueTensorByteSize) {
72     return DeterministicProtoHash64(tp);
73   } else {
74     return TensorProtoHash(tp);
75   }
76 }
77 
AreTensorProtosEqual(const TensorProto & lhs,const TensorProto & rhs,bool allow_false_negatives)78 bool AreTensorProtosEqual(const TensorProto& lhs, const TensorProto& rhs,
79                           bool allow_false_negatives) {
80   // A small TensorProto can expand into a giant Tensor.  So we avoid
81   // conversion to an actual Tensor if we can quickly rule out equality
82   // by comparing the Tensor size since different sized Tensors are definitely
83   // different.
84   const int64_t lhs_tensor_bytes = TensorByteSize(lhs);
85   const int64_t rhs_tensor_bytes = TensorByteSize(rhs);
86   if (lhs_tensor_bytes != rhs_tensor_bytes) {
87     return false;
88   }
89 
90   // If the TensorProto representation expands into a much bigger Tensor,
91   // we have a fast-path that first compares the protos.
92   const int64_t lhs_proto_bytes = lhs.ByteSizeLong();
93   const bool large_expansion =
94       (lhs_proto_bytes < 512 && lhs_tensor_bytes > 4096);
95 
96   // If the tensor is very large, we'll only compare the proto representation if
97   // false negatives are allowed. This may miss some equivalent tensors whose
98   // actual tensor values are the same but which are described by different
99   // TensorProtos. This avoids construction of large protos in memory.
100   const bool only_compare_proto =
101       (allow_false_negatives && lhs_tensor_bytes > kMaxAttrValueTensorByteSize);
102   if (large_expansion || only_compare_proto) {
103     if (AreSerializedProtosEqual(lhs, rhs))
104       return true;
105     else if (only_compare_proto)
106       return false;
107   }
108 
109   // Finally, compare them by constructing Tensors and serializing them back.
110   // There are multiple equivalent representations of attr values containing
111   // TensorProtos. Comparing Tensor objects is pretty tricky. This is unsafe
112   // operation, because large tensors can be represented as TensorProto, but
113   // can't be serialized to tensor content.
114   Tensor lhs_t(lhs.dtype());
115   bool success = lhs_t.FromProto(lhs);
116   if (!success) {
117     return false;
118   }
119 
120   Tensor rhs_t(rhs.dtype());
121   success = rhs_t.FromProto(rhs);
122   if (!success) {
123     return false;
124   }
125 
126   TensorProto lhs_tp;
127   lhs_t.AsProtoTensorContent(&lhs_tp);
128 
129   TensorProto rhs_tp;
130   rhs_t.AsProtoTensorContent(&rhs_tp);
131 
132   return AreSerializedProtosEqual(lhs_tp, rhs_tp);
133 }
134 
135 using TensorProtoHasher = std::function<uint64(const TensorProto&)>;
136 
AttrValueHash(const AttrValue & a,const TensorProtoHasher & tensor_hash)137 uint64 AttrValueHash(const AttrValue& a, const TensorProtoHasher& tensor_hash) {
138   if (a.has_tensor()) return tensor_hash(a.tensor());
139 
140   if (a.has_func()) {
141     const NameAttrList& func = a.func();
142     uint64 h = Hash64(func.name());
143     std::map<string, AttrValue> map(func.attr().begin(), func.attr().end());
144     for (const auto& pair : map) {
145       h = Hash64(pair.first.data(), pair.first.size(), h);
146       h = Hash64Combine(AttrValueHash(pair.second, tensor_hash), h);
147     }
148     return h;
149   }
150 
151   // If `a` is not a tensor or func, get a hash of serialized string.
152   return DeterministicProtoHash64(a);
153 }
154 
// Renders `str` as a double-quoted, C-escaped literal for summaries.
// If the escaped form has 80 or more characters, only its first 10 and last
// 10 characters are kept, joined by "...".
string SummarizeString(const string& str) {
  const string escaped = absl::CEscape(str);

  constexpr int kMaxStringSummarySize = 80;
  if (escaped.size() < kMaxStringSummarySize) {
    return strings::StrCat("\"", escaped, "\"");
  }

  // Long string: keep both ends and elide the middle.
  constexpr size_t kKeptChars = 10;
  const StringPiece head(escaped.data(), kKeptChars);
  const StringPiece tail(escaped.data() + escaped.size() - kKeptChars,
                         kKeptChars);
  return strings::StrCat("\"", head, "...", tail, "\"");
}
170 
SummarizeTensor(const TensorProto & tensor_proto)171 string SummarizeTensor(const TensorProto& tensor_proto) {
172   Tensor t;
173   if (!t.FromProto(tensor_proto)) {
174     return strings::StrCat(
175         "<Invalid TensorProto: ", tensor_proto.ShortDebugString(), ">");
176   }
177   return t.DebugString();
178 }
179 
SummarizeFunc(const NameAttrList & func)180 string SummarizeFunc(const NameAttrList& func) {
181   std::vector<string> entries;
182   for (const auto& p : func.attr()) {
183     entries.push_back(
184         strings::StrCat(p.first, "=", SummarizeAttrValue(p.second)));
185   }
186   std::sort(entries.begin(), entries.end());
187   return strings::StrCat(func.name(), "[", absl::StrJoin(entries, ", "), "]");
188 }
189 
// Returns true if the '{' / '<' nesting depth of `to_parse` (text-format
// proto) stays strictly below `limit`. Returns false as soon as an opening
// delimiter brings the depth to `limit`.
//
// This is a conservative pre-check run before the text-format parser, so that
// pathologically deep input is rejected up front. To keep the count honest:
//  * Delimiters inside quoted string literals ('...' or "...") and inside '#'
//    line comments are ignored. Otherwise a string value full of '>' could
//    offset real nesting that follows it.
//  * Depth never drops below zero, so unmatched closing delimiters cannot
//    pre-pay for later opening ones.
// A quoted literal is taken to end at its matching unescaped quote, a newline,
// or a NUL character, whichever comes first. A backslash escapes the next
// character unless that character is a newline or NUL. Ending a literal no
// later than a text-format tokenizer would means real delimiters are never
// skipped. A comment ends at a newline or NUL.
bool ParseAttrValueHelper_TensorNestsUnderLimit(int limit,
                                                const std::string& to_parse) {
  const size_t length = to_parse.size();
  int depth = 0;
  size_t pos = 0;
  while (pos < length) {
    const char c = to_parse[pos++];
    switch (c) {
      case '"':
      case '\'': {
        // Skip the body of a string literal delimited by `c`.
        while (pos < length) {
          const char body = to_parse[pos];
          if (body == '\n' || body == '\0') break;  // Literal ends here.
          ++pos;
          if (body == c) break;  // Closing quote consumed.
          if (body == '\\' && pos < length && to_parse[pos] != '\n' &&
              to_parse[pos] != '\0') {
            ++pos;  // Escaped character (including an escaped quote).
          }
        }
        break;
      }
      case '#':
        // Skip a line comment, leaving the terminating newline in place.
        while (pos < length && to_parse[pos] != '\n' &&
               to_parse[pos] != '\0') {
          ++pos;
        }
        break;
      case '{':
      case '<':
        if (++depth >= limit) return false;
        break;
      case '}':
      case '>':
        if (depth > 0) --depth;
        break;
      default:
        break;
    }
  }
  return true;
}
237 
238 }  // namespace
239 
SummarizeAttrValue(const AttrValue & attr_value)240 string SummarizeAttrValue(const AttrValue& attr_value) {
241   switch (attr_value.value_case()) {
242     case AttrValue::kS:
243       return SummarizeString(attr_value.s());
244     case AttrValue::kI:
245       return strings::StrCat(attr_value.i());
246     case AttrValue::kF:
247       return strings::StrCat(attr_value.f());
248     case AttrValue::kB:
249       return attr_value.b() ? "true" : "false";
250     case AttrValue::kType:
251       return EnumName_DataType(attr_value.type());
252     case AttrValue::kShape:
253       return PartialTensorShape::DebugString(attr_value.shape());
254     case AttrValue::kTensor:
255       return SummarizeTensor(attr_value.tensor());
256     case AttrValue::kList: {
257       std::vector<string> pieces;
258       if (attr_value.list().s_size() > 0) {
259         for (int i = 0; i < attr_value.list().s_size(); ++i) {
260           pieces.push_back(SummarizeString(attr_value.list().s(i)));
261         }
262       } else if (attr_value.list().i_size() > 0) {
263         for (int i = 0; i < attr_value.list().i_size(); ++i) {
264           pieces.push_back(strings::StrCat(attr_value.list().i(i)));
265         }
266       } else if (attr_value.list().f_size() > 0) {
267         for (int i = 0; i < attr_value.list().f_size(); ++i) {
268           pieces.push_back(strings::StrCat(attr_value.list().f(i)));
269         }
270       } else if (attr_value.list().b_size() > 0) {
271         for (int i = 0; i < attr_value.list().b_size(); ++i) {
272           pieces.push_back(attr_value.list().b(i) ? "true" : "false");
273         }
274       } else if (attr_value.list().type_size() > 0) {
275         for (int i = 0; i < attr_value.list().type_size(); ++i) {
276           pieces.push_back(EnumName_DataType(attr_value.list().type(i)));
277         }
278       } else if (attr_value.list().shape_size() > 0) {
279         for (int i = 0; i < attr_value.list().shape_size(); ++i) {
280           pieces.push_back(
281               TensorShape::DebugString(attr_value.list().shape(i)));
282         }
283       } else if (attr_value.list().tensor_size() > 0) {
284         for (int i = 0; i < attr_value.list().tensor_size(); ++i) {
285           pieces.push_back(SummarizeTensor(attr_value.list().tensor(i)));
286         }
287       } else if (attr_value.list().func_size() > 0) {
288         for (int i = 0; i < attr_value.list().func_size(); ++i) {
289           pieces.push_back(SummarizeFunc(attr_value.list().func(i)));
290         }
291       }
292       constexpr int kMaxListSummarySize = 15;
293       if (pieces.size() >= kMaxListSummarySize) {
294         pieces[5] = strings::StrCat(Fingerprint64(
295             absl::StrJoin(pieces.begin() + 5, pieces.end() - 5, ",")));
296         pieces.erase(pieces.begin() + 6, pieces.end() - 5);
297       }
298       return strings::StrCat("[", absl::StrJoin(pieces, ", "), "]");
299     }
300     case AttrValue::kFunc: {
301       return SummarizeFunc(attr_value.func());
302     }
303     case AttrValue::kPlaceholder:
304       return strings::StrCat("$", attr_value.placeholder());
305     case AttrValue::VALUE_NOT_SET:
306       return "<Unknown AttrValue type>";
307   }
308   return "<Unknown AttrValue type>";  // Prevent missing return warning
309 }
310 
AttrValueHasType(const AttrValue & attr_value,StringPiece type)311 Status AttrValueHasType(const AttrValue& attr_value, StringPiece type) {
312   int num_set = 0;
313 
314 #define VALIDATE_FIELD(name, type_string, oneof_case)                         \
315   do {                                                                        \
316     if (attr_value.has_list()) {                                              \
317       if (attr_value.list().name##_size() > 0) {                              \
318         if (type != "list(" type_string ")") {                                \
319           return errors::InvalidArgument(                                     \
320               "AttrValue had value with type 'list(" type_string ")' when '", \
321               type, "' expected");                                            \
322         }                                                                     \
323         ++num_set;                                                            \
324       }                                                                       \
325     } else if (attr_value.value_case() == AttrValue::oneof_case) {            \
326       if (type != type_string) {                                              \
327         return errors::InvalidArgument(                                       \
328             "AttrValue had value with type '" type_string "' when '", type,   \
329             "' expected");                                                    \
330       }                                                                       \
331       ++num_set;                                                              \
332     }                                                                         \
333   } while (false)
334 
335   VALIDATE_FIELD(s, "string", kS);
336   VALIDATE_FIELD(i, "int", kI);
337   VALIDATE_FIELD(f, "float", kF);
338   VALIDATE_FIELD(b, "bool", kB);
339   VALIDATE_FIELD(type, "type", kType);
340   VALIDATE_FIELD(shape, "shape", kShape);
341   VALIDATE_FIELD(tensor, "tensor", kTensor);
342   VALIDATE_FIELD(func, "func", kFunc);
343 
344 #undef VALIDATE_FIELD
345 
346   if (attr_value.value_case() == AttrValue::kPlaceholder) {
347     return errors::InvalidArgument(
348         "AttrValue had value with unexpected type 'placeholder'");
349   }
350 
351   // If the attr type is 'list', we expect attr_value.has_list() to be
352   // true.  However, proto3's attr_value.has_list() can be false when
353   // set to an empty list for GraphDef versions <= 4. So we simply
354   // check if has_list is false and some other field in attr_value is
355   // set to flag the error.  This test can be made more strict once
356   // support for GraphDef versions <= 4 is dropped.
357   if (absl::StartsWith(type, "list(") && !attr_value.has_list()) {
358     if (num_set) {
359       return errors::InvalidArgument(
360           "AttrValue missing value with expected type '", type, "'");
361     } else {
362       // Indicate that we have a list, but an empty one.
363       ++num_set;
364     }
365   }
366 
367   // Okay to have an empty list, but not to be missing a non-list value.
368   if (num_set == 0 && !absl::StartsWith(type, "list(")) {
369     return errors::InvalidArgument(
370         "AttrValue missing value with expected type '", type, "'");
371   }
372 
373   // Ref types and DT_INVALID are illegal, and DataTypes must
374   // be a valid enum type.
375   if (type == "type") {
376     if (!DataType_IsValid(attr_value.type())) {
377       return errors::InvalidArgument("AttrValue has invalid DataType enum: ",
378                                      attr_value.type());
379     }
380     if (IsRefType(attr_value.type())) {
381       return errors::InvalidArgument(
382           "AttrValue must not have reference type value of ",
383           DataTypeString(attr_value.type()));
384     }
385     if (attr_value.type() == DT_INVALID) {
386       return errors::InvalidArgument("AttrValue has invalid DataType");
387     }
388   } else if (type == "list(type)") {
389     for (auto as_int : attr_value.list().type()) {
390       const DataType dtype = static_cast<DataType>(as_int);
391       if (!DataType_IsValid(dtype)) {
392         return errors::InvalidArgument("AttrValue has invalid DataType enum: ",
393                                        as_int);
394       }
395       if (IsRefType(dtype)) {
396         return errors::InvalidArgument(
397             "AttrValue must not have reference type value of ",
398             DataTypeString(dtype));
399       }
400       if (dtype == DT_INVALID) {
401         return errors::InvalidArgument("AttrValue contains invalid DataType");
402       }
403     }
404   }
405 
406   return OkStatus();
407 }
408 
ParseAttrValue(StringPiece type,StringPiece text,AttrValue * out)409 bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out) {
410   // Parse type.
411   string field_name;
412   bool is_list = absl::ConsumePrefix(&type, "list(");
413   if (absl::ConsumePrefix(&type, "string")) {
414     field_name = "s";
415   } else if (absl::ConsumePrefix(&type, "int")) {
416     field_name = "i";
417   } else if (absl::ConsumePrefix(&type, "float")) {
418     field_name = "f";
419   } else if (absl::ConsumePrefix(&type, "bool")) {
420     field_name = "b";
421   } else if (absl::ConsumePrefix(&type, "type")) {
422     field_name = "type";
423   } else if (absl::ConsumePrefix(&type, "shape")) {
424     field_name = "shape";
425   } else if (absl::ConsumePrefix(&type, "tensor")) {
426     field_name = "tensor";
427   } else if (absl::ConsumePrefix(&type, "func")) {
428     field_name = "func";
429   } else if (absl::ConsumePrefix(&type, "placeholder")) {
430     field_name = "placeholder";
431   } else {
432     return false;
433   }
434   if (is_list && !absl::ConsumePrefix(&type, ")")) {
435     return false;
436   }
437 
438   // Construct a valid text proto message to parse.
439   string to_parse;
440   if (is_list) {
441     // TextFormat parser considers "i: 7" to be the same as "i: [7]",
442     // but we only want to allow list values with [].
443     StringPiece cleaned = text;
444     str_util::RemoveLeadingWhitespace(&cleaned);
445     str_util::RemoveTrailingWhitespace(&cleaned);
446     if (cleaned.size() < 2 || cleaned[0] != '[' ||
447         cleaned[cleaned.size() - 1] != ']') {
448       return false;
449     }
450     cleaned.remove_prefix(1);
451     str_util::RemoveLeadingWhitespace(&cleaned);
452     if (cleaned.size() == 1) {
453       // User wrote "[]", so return empty list without invoking the TextFormat
454       // parse which returns an error for "i: []".
455       out->Clear();
456       out->mutable_list();
457       return true;
458     }
459     to_parse = strings::StrCat("list { ", field_name, ": ", text, " }");
460   } else {
461     to_parse = strings::StrCat(field_name, ": ", text);
462   }
463   if (field_name == "tensor") {
464     if (!ParseAttrValueHelper_TensorNestsUnderLimit(kMaxTensorNestDepth,
465                                                     to_parse)) {
466       return false;
467     }
468   }
469   return ProtoParseFromString(to_parse, out);
470 }
471 
SetAttrValue(const AttrValue & value,AttrValue * out)472 void SetAttrValue(const AttrValue& value, AttrValue* out) { *out = value; }
473 
474 #define DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD) \
475   void SetAttrValue(ARG_TYPE value, AttrValue* out) { out->set_##FIELD(value); }
476 
477 #define DEFINE_SET_ATTR_VALUE_LIST(ARG_TYPE, FIELD)                       \
478   void SetAttrValue(ARG_TYPE value, AttrValue* out) {                     \
479     out->mutable_list()->Clear(); /* create list() even if value empty */ \
480     for (const auto& v : value) {                                         \
481       out->mutable_list()->add_##FIELD(v);                                \
482     }                                                                     \
483   }
484 
485 #define DEFINE_SET_ATTR_VALUE_BOTH(ARG_TYPE, FIELD) \
486   DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD)        \
487   DEFINE_SET_ATTR_VALUE_LIST(gtl::ArraySlice<ARG_TYPE>, FIELD)
488 
DEFINE_SET_ATTR_VALUE_ONE(const string &,s)489 DEFINE_SET_ATTR_VALUE_ONE(const string&, s)
490 DEFINE_SET_ATTR_VALUE_LIST(gtl::ArraySlice<string>, s)
491 DEFINE_SET_ATTR_VALUE_BOTH(const char*, s)
492 DEFINE_SET_ATTR_VALUE_BOTH(int64_t, i)
493 DEFINE_SET_ATTR_VALUE_BOTH(int32_t, i)
494 DEFINE_SET_ATTR_VALUE_BOTH(float, f)
495 DEFINE_SET_ATTR_VALUE_BOTH(double, f)
496 DEFINE_SET_ATTR_VALUE_BOTH(bool, b)
497 DEFINE_SET_ATTR_VALUE_LIST(const std::vector<bool>&, b)
498 DEFINE_SET_ATTR_VALUE_LIST(std::initializer_list<bool>, b)
499 DEFINE_SET_ATTR_VALUE_BOTH(DataType, type)
500 
501 void SetAttrValue(const tstring& value, AttrValue* out) {
502   out->set_s(value.data(), value.size());
503 }
504 
SetAttrValue(gtl::ArraySlice<tstring> value,AttrValue * out)505 void SetAttrValue(gtl::ArraySlice<tstring> value, AttrValue* out) {
506   out->mutable_list()->Clear();
507   for (const auto& v : value) {
508     out->mutable_list()->add_s(v.data(), v.size());
509   }
510 }
511 
SetAttrValue(StringPiece value,AttrValue * out)512 void SetAttrValue(StringPiece value, AttrValue* out) {
513   out->set_s(value.data(), value.size());
514 }
515 
SetAttrValue(const gtl::ArraySlice<StringPiece> value,AttrValue * out)516 void SetAttrValue(const gtl::ArraySlice<StringPiece> value, AttrValue* out) {
517   out->mutable_list()->Clear();  // Create list() even if value empty.
518   for (const auto& v : value) {
519     out->mutable_list()->add_s(v.data(), v.size());
520   }
521 }
522 
MoveAttrValue(std::vector<string> && value,AttrValue * out)523 void MoveAttrValue(std::vector<string>&& value, AttrValue* out) {
524   out->mutable_list()->Clear();  // Create list() even if value empty.
525   for (auto& v : value) {
526     out->mutable_list()->add_s(std::move(v));
527   }
528 }
529 
SetAttrValue(const TensorShape & value,AttrValue * out)530 void SetAttrValue(const TensorShape& value, AttrValue* out) {
531   value.AsProto(out->mutable_shape());
532 }
533 
SetAttrValue(const TensorShapeProto & value,AttrValue * out)534 void SetAttrValue(const TensorShapeProto& value, AttrValue* out) {
535   *out->mutable_shape() = value;
536 }
537 
SetAttrValue(const PartialTensorShape & value,AttrValue * out)538 void SetAttrValue(const PartialTensorShape& value, AttrValue* out) {
539   value.AsProto(out->mutable_shape());
540 }
541 
SetAttrValue(const gtl::ArraySlice<TensorShape> value,AttrValue * out)542 void SetAttrValue(const gtl::ArraySlice<TensorShape> value, AttrValue* out) {
543   out->mutable_list()->Clear();  // Create list() even if value empty.
544   for (const auto& v : value) {
545     v.AsProto(out->mutable_list()->add_shape());
546   }
547 }
548 
SetAttrValue(gtl::ArraySlice<TensorShapeProto> value,AttrValue * out)549 void SetAttrValue(gtl::ArraySlice<TensorShapeProto> value, AttrValue* out) {
550   out->mutable_list()->Clear();  // Create list() even if value empty.
551   for (const auto& v : value) {
552     *out->mutable_list()->add_shape() = v;
553   }
554 }
555 
SetAttrValue(const gtl::ArraySlice<PartialTensorShape> value,AttrValue * out)556 void SetAttrValue(const gtl::ArraySlice<PartialTensorShape> value,
557                   AttrValue* out) {
558   out->mutable_list()->Clear();  // Create list() even if value empty.
559   for (const auto& v : value) {
560     v.AsProto(out->mutable_list()->add_shape());
561   }
562 }
563 
SetAttrValue(const Tensor & value,AttrValue * out)564 void SetAttrValue(const Tensor& value, AttrValue* out) {
565   if (value.NumElements() > 1) {
566     value.AsProtoTensorContent(out->mutable_tensor());
567   } else {
568     value.AsProtoField(out->mutable_tensor());
569   }
570 }
571 
SetAttrValue(const gtl::ArraySlice<Tensor> value,AttrValue * out)572 void SetAttrValue(const gtl::ArraySlice<Tensor> value, AttrValue* out) {
573   out->mutable_list()->Clear();  // Create list() even if value empty.
574   for (const auto& v : value) {
575     if (v.NumElements() > 1) {
576       v.AsProtoTensorContent(out->mutable_list()->add_tensor());
577     } else {
578       v.AsProtoField(out->mutable_list()->add_tensor());
579     }
580   }
581 }
582 
SetAttrValue(const TensorProto & value,AttrValue * out)583 void SetAttrValue(const TensorProto& value, AttrValue* out) {
584   *out->mutable_tensor() = value;
585 }
586 
SetAttrValue(const gtl::ArraySlice<TensorProto> value,AttrValue * out)587 void SetAttrValue(const gtl::ArraySlice<TensorProto> value, AttrValue* out) {
588   out->mutable_list()->Clear();  // Create list() even if value empty.
589   for (const auto& v : value) {
590     *out->mutable_list()->add_tensor() = v;
591   }
592 }
593 
SetAttrValue(const NameAttrList & value,AttrValue * out)594 void SetAttrValue(const NameAttrList& value, AttrValue* out) {
595   *out->mutable_func() = value;
596 }
597 
SetAttrValue(gtl::ArraySlice<NameAttrList> value,AttrValue * out)598 void SetAttrValue(gtl::ArraySlice<NameAttrList> value, AttrValue* out) {
599   out->mutable_list()->Clear();  // Create list() even if value empty.
600   for (const auto& v : value) {
601     *out->mutable_list()->add_func() = v;
602   }
603 }
604 
AreAttrValuesEqual(const AttrValue & a,const AttrValue & b,bool allow_false_negatives)605 bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b,
606                         bool allow_false_negatives) {
607   if (a.type() != b.type()) {
608     return false;
609   } else if (a.type() != DT_INVALID && b.type() != DT_INVALID) {
610     return a.type() == b.type();
611   }
612 
613   if (a.has_tensor() != b.has_tensor()) {
614     return false;
615   } else if (a.has_tensor() && b.has_tensor()) {
616     return AreTensorProtosEqual(a.tensor(), b.tensor(), allow_false_negatives);
617   }
618 
619   // `func` field contains a nested AttrValue. Compare such AttrValues
620   // recursively.
621   if (a.has_func() != b.has_func()) {
622     return false;
623   } else if (a.has_func() && b.has_func()) {
624     const NameAttrList& af = a.func();
625     const NameAttrList& bf = b.func();
626     if (af.name() != bf.name()) return false;
627     std::unordered_map<string, AttrValue> am(af.attr().begin(),
628                                              af.attr().end());
629     for (const auto& bm_pair : bf.attr()) {
630       const auto& iter = am.find(bm_pair.first);
631       if (iter == am.end()) return false;
632       if (!AreAttrValuesEqual(iter->second, bm_pair.second,
633                               allow_false_negatives))
634         return false;
635       am.erase(iter);
636     }
637     if (!am.empty()) return false;
638     return true;
639   }
640 
641   // All other fields in AttrValue have deterministic representations.
642   // It is safe to compare their serialized strings.
643   return AreSerializedProtosEqual(a, b);
644 }
645 
AttrValueHash(const AttrValue & a)646 uint64 AttrValueHash(const AttrValue& a) {
647   return AttrValueHash(a, TensorProtoHash);
648 }
649 
FastAttrValueHash(const AttrValue & a)650 uint64 FastAttrValueHash(const AttrValue& a) {
651   return AttrValueHash(a, FastTensorProtoHash);
652 }
653 
HasPlaceHolder(const AttrValue & val)654 bool HasPlaceHolder(const AttrValue& val) {
655   switch (val.value_case()) {
656     case AttrValue::kList: {
657       for (const NameAttrList& func : val.list().func()) {
658         for (const auto& p : func.attr()) {
659           if (HasPlaceHolder(p.second)) {
660             return true;
661           }
662         }
663       }
664       break;
665     }
666     case AttrValue::kFunc:
667       for (const auto& p : val.func().attr()) {
668         if (HasPlaceHolder(p.second)) {
669           return true;
670         }
671       }
672       break;
673     case AttrValue::kPlaceholder:
674       return true;
675     default:
676       break;
677   }
678   return false;
679 }
680 
SubstitutePlaceholders(const SubstituteFunc & substitute,AttrValue * value)681 bool SubstitutePlaceholders(const SubstituteFunc& substitute,
682                             AttrValue* value) {
683   switch (value->value_case()) {
684     case AttrValue::kList: {
685       for (NameAttrList& func : *value->mutable_list()->mutable_func()) {
686         for (auto& p : *func.mutable_attr()) {
687           if (!SubstitutePlaceholders(substitute, &p.second)) {
688             return false;
689           }
690         }
691       }
692       break;
693     }
694     case AttrValue::kFunc:
695       for (auto& p : *(value->mutable_func()->mutable_attr())) {
696         if (!SubstitutePlaceholders(substitute, &p.second)) {
697           return false;
698         }
699       }
700       break;
701     case AttrValue::kPlaceholder:
702       return substitute(value->placeholder(), value);
703     case AttrValue::VALUE_NOT_SET:
704       return false;
705     default:
706       break;
707   }
708   return true;
709 }
710 
711 }  // namespace tensorflow
712