• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/core/framework/attr_value_util.h"

#include <algorithm>
#include <functional>
#include <limits>
#include <map>
#include <string>
#include <vector>

#include "tensorflow/core/framework/attr_value.pb_text.h"
#include "tensorflow/core/framework/tensor.pb_text.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb_text.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/protobuf.h"
32 
33 namespace tensorflow {
34 namespace {
35 
36 // Do not construct large tensors to compute their hash or compare for equality.
37 constexpr int kMaxAttrValueTensorByteSize = 32 * 1024 * 1024;  // 32mb
38 
39 // Return the size of the tensor represented by this TensorProto. If shape is
40 // not fully defined return -1.
TensorByteSize(const TensorProto & t)41 int64 TensorByteSize(const TensorProto& t) {
42   // num_elements returns -1 if shape is not fully defined.
43   int64 num_elems = TensorShape(t.tensor_shape()).num_elements();
44   return num_elems < 0 ? -1 : num_elems * DataTypeSize(t.dtype());
45 }
46 
47 // Compute TensorProto hash by creating a Tensor, serializing it as tensor
48 // content, and computing a hash of it's string representation. This is unsafe
49 // operation, because large tensors can be represented as TensorProto, but can't
50 // be serialized to tensor content.
TensorProtoHash(const TensorProto & tp)51 uint64 TensorProtoHash(const TensorProto& tp) {
52   Tensor tensor(tp.dtype());
53   bool success = tensor.FromProto(tp);
54   DCHECK(success);
55   TensorProto p;
56   tensor.AsProtoTensorContent(&p);
57   return DeterministicProtoHash64(p);
58 }
59 
60 // Do not create large tensors in memory, compute hash based on TensorProto
61 // string representation. Tensors with identical content potentially can have a
62 // different hash code if they are defined with different TensorProto
63 // representations.
FastTensorProtoHash(const TensorProto & tp)64 uint64 FastTensorProtoHash(const TensorProto& tp) {
65   if (TensorByteSize(tp) > kMaxAttrValueTensorByteSize) {
66     return DeterministicProtoHash64(tp);
67   } else {
68     return TensorProtoHash(tp);
69   }
70 }
71 
72 // There are multiple equivalent representations of attr values containing
73 // TensorProtos. Compare them by constructing Tensors and serializing them
74 // back. Comparing Tensor objects is pretty tricky. This is unsafe operation,
75 // because large tensors can be represented as TensorProto, but can't be
76 // serialized to tensor content.
AreTensorProtosEqual(const TensorProto & lhs,const TensorProto & rhs)77 bool AreTensorProtosEqual(const TensorProto& lhs, const TensorProto& rhs) {
78   Tensor lhs_t(lhs.dtype());
79   bool success = lhs_t.FromProto(lhs);
80   DCHECK(success);
81 
82   Tensor rhs_t(rhs.dtype());
83   success = rhs_t.FromProto(rhs);
84   DCHECK(success);
85 
86   TensorProto lhs_tp;
87   lhs_t.AsProtoTensorContent(&lhs_tp);
88 
89   TensorProto rhs_tp;
90   rhs_t.AsProtoTensorContent(&rhs_tp);
91 
92   return AreSerializedProtosEqual(lhs_tp, rhs_tp);
93 }
94 
95 // Do not construct large tensors in memory, compare equality using TensorProto
96 // string representation. Tensors with identical content potentially can have
97 // different tensor proto representation.
FastAreTensorProtosEqual(const TensorProto & lhs,const TensorProto & rhs)98 bool FastAreTensorProtosEqual(const TensorProto& lhs, const TensorProto& rhs) {
99   if (TensorByteSize(lhs) > kMaxAttrValueTensorByteSize ||
100       TensorByteSize(rhs) > kMaxAttrValueTensorByteSize) {
101     string lhs_str, rhs_str;
102     bool success = lhs.AppendToString(&lhs_str);
103     DCHECK(success);
104     success = rhs.AppendToString(&rhs_str);
105     DCHECK(success);
106 
107     return lhs_str == rhs_str;
108   } else {
109     return AreTensorProtosEqual(lhs, rhs);
110   }
111 }
112 
113 using TensorProtoHasher = std::function<uint64(const TensorProto&)>;
114 using TensorProtosEquality =
115     std::function<bool(const TensorProto&, const TensorProto&)>;
116 
AttrValueHash(const AttrValue & a,const TensorProtoHasher & tensor_hash)117 uint64 AttrValueHash(const AttrValue& a, const TensorProtoHasher& tensor_hash) {
118   if (a.has_tensor()) return tensor_hash(a.tensor());
119 
120   if (a.has_func()) {
121     const NameAttrList& func = a.func();
122     uint64 h = Hash64(func.name());
123     std::map<string, AttrValue> map(func.attr().begin(), func.attr().end());
124     for (const auto& pair : map) {
125       h = Hash64(pair.first.data(), pair.first.size(), h);
126       h = Hash64Combine(AttrValueHash(pair.second, tensor_hash), h);
127     }
128     return h;
129   }
130 
131   // If `a` is not a tensor or func, get a hash of serialized string.
132   return DeterministicProtoHash64(a);
133 }
134 
AreAttrValuesEqual(const AttrValue & a,const AttrValue & b,const TensorProtosEquality & tensor_equality)135 bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b,
136                         const TensorProtosEquality& tensor_equality) {
137   if (a.has_tensor() != b.has_tensor()) {
138     return false;
139   } else if (a.has_tensor() && b.has_tensor()) {
140     return tensor_equality(a.tensor(), b.tensor());
141   }
142 
143   // `func` field contains a nested AttrValue. Compare such AttrValues
144   // recursively.
145   if (a.has_func() != b.has_func()) {
146     return false;
147   } else if (a.has_func() && b.has_func()) {
148     const NameAttrList& af = a.func();
149     const NameAttrList& bf = b.func();
150     if (af.name() != bf.name()) return false;
151     std::unordered_map<string, AttrValue> am(af.attr().begin(),
152                                              af.attr().end());
153     for (const auto& bm_pair : bf.attr()) {
154       const auto& iter = am.find(bm_pair.first);
155       if (iter == am.end()) return false;
156       if (!AreAttrValuesEqual(iter->second, bm_pair.second, tensor_equality))
157         return false;
158       am.erase(iter);
159     }
160     if (!am.empty()) return false;
161     return true;
162   }
163 
164   // All other fields in AttrValue have deterministic representations.
165   // It is safe to compare their serialized strings.
166   return AreSerializedProtosEqual(a, b);
167 }
168 
// Renders `str` C-escaped and wrapped in double quotes. When the escaped form
// is 80 characters or longer, only its first and last 10 characters are kept,
// joined by "...".
string SummarizeString(const string& str) {
  constexpr int kMaxStringSummarySize = 80;
  constexpr size_t kKeptEachSide = 10;

  const string escaped = str_util::CEscape(str);
  if (escaped.size() < kMaxStringSummarySize) {
    return strings::StrCat("\"", escaped, "\"");
  }
  const StringPiece head(escaped.data(), kKeptEachSide);
  const StringPiece tail(escaped.data() + escaped.size() - kKeptEachSide,
                         kKeptEachSide);
  return strings::StrCat("\"", head, "...", tail, "\"");
}
184 
SummarizeTensor(const TensorProto & tensor_proto)185 string SummarizeTensor(const TensorProto& tensor_proto) {
186   Tensor t;
187   if (!t.FromProto(tensor_proto)) {
188     return strings::StrCat(
189         "<Invalid TensorProto: ", ProtoShortDebugString(tensor_proto), ">");
190   }
191   return t.DebugString();
192 }
193 
SummarizeFunc(const NameAttrList & func)194 string SummarizeFunc(const NameAttrList& func) {
195   std::vector<string> entries;
196   for (auto p : func.attr()) {
197     entries.push_back(
198         strings::StrCat(p.first, "=", SummarizeAttrValue(p.second)));
199   }
200   std::sort(entries.begin(), entries.end());
201   return strings::StrCat(func.name(), "[", str_util::Join(entries, ", "), "]");
202 }
203 
204 }  // namespace
205 
SummarizeAttrValue(const AttrValue & attr_value)206 string SummarizeAttrValue(const AttrValue& attr_value) {
207   switch (attr_value.value_case()) {
208     case AttrValue::kS:
209       return SummarizeString(attr_value.s());
210     case AttrValue::kI:
211       return strings::StrCat(attr_value.i());
212     case AttrValue::kF:
213       return strings::StrCat(attr_value.f());
214     case AttrValue::kB:
215       return attr_value.b() ? "true" : "false";
216     case AttrValue::kType:
217       return EnumName_DataType(attr_value.type());
218     case AttrValue::kShape:
219       return PartialTensorShape::DebugString(attr_value.shape());
220     case AttrValue::kTensor:
221       return SummarizeTensor(attr_value.tensor());
222     case AttrValue::kList: {
223       std::vector<string> pieces;
224       if (attr_value.list().s_size() > 0) {
225         for (int i = 0; i < attr_value.list().s_size(); ++i) {
226           pieces.push_back(SummarizeString(attr_value.list().s(i)));
227         }
228       } else if (attr_value.list().i_size() > 0) {
229         for (int i = 0; i < attr_value.list().i_size(); ++i) {
230           pieces.push_back(strings::StrCat(attr_value.list().i(i)));
231         }
232       } else if (attr_value.list().f_size() > 0) {
233         for (int i = 0; i < attr_value.list().f_size(); ++i) {
234           pieces.push_back(strings::StrCat(attr_value.list().f(i)));
235         }
236       } else if (attr_value.list().b_size() > 0) {
237         for (int i = 0; i < attr_value.list().b_size(); ++i) {
238           pieces.push_back(attr_value.list().b(i) ? "true" : "false");
239         }
240       } else if (attr_value.list().type_size() > 0) {
241         for (int i = 0; i < attr_value.list().type_size(); ++i) {
242           pieces.push_back(EnumName_DataType(attr_value.list().type(i)));
243         }
244       } else if (attr_value.list().shape_size() > 0) {
245         for (int i = 0; i < attr_value.list().shape_size(); ++i) {
246           pieces.push_back(
247               TensorShape::DebugString(attr_value.list().shape(i)));
248         }
249       } else if (attr_value.list().tensor_size() > 0) {
250         for (int i = 0; i < attr_value.list().tensor_size(); ++i) {
251           pieces.push_back(SummarizeTensor(attr_value.list().tensor(i)));
252         }
253       } else if (attr_value.list().func_size() > 0) {
254         for (int i = 0; i < attr_value.list().func_size(); ++i) {
255           pieces.push_back(SummarizeFunc(attr_value.list().func(i)));
256         }
257       }
258       constexpr int kMaxListSummarySize = 15;
259       if (pieces.size() >= kMaxListSummarySize) {
260         pieces.erase(pieces.begin() + 5, pieces.begin() + (pieces.size() - 6));
261         pieces[5] = "...";
262       }
263       return strings::StrCat("[", str_util::Join(pieces, ", "), "]");
264     }
265     case AttrValue::kFunc: {
266       return SummarizeFunc(attr_value.func());
267     }
268     case AttrValue::kPlaceholder:
269       return strings::StrCat("$", attr_value.placeholder());
270     case AttrValue::VALUE_NOT_SET:
271       return "<Unknown AttrValue type>";
272   }
273   return "<Unknown AttrValue type>";  // Prevent missing return warning
274 }
275 
// Checks that `attr_value` holds a value of the attr type named by `type`
// (e.g. "int", "list(type)").
//
// Returns InvalidArgument if:
//   * a set value's type string differs from `type`;
//   * the value is a placeholder;
//   * `type` is not a list type and no value is set;
//   * `type` is a list type, `attr_value` holds no list, but some scalar
//     field is set;
//   * for "type" / "list(type)", a DataType is not a valid enum value, is a
//     reference type, or is DT_INVALID.
// An empty list is accepted for any "list(...)" type.
Status AttrValueHasType(const AttrValue& attr_value, StringPiece type) {
  // Count of matching value fields found on `attr_value`.
  int num_set = 0;

// Checks one AttrValue field. If `attr_value` holds a list, it checks whether
// the list's `name` repeated field is non-empty. Otherwise it checks whether
// the value oneof is set to `oneof_case`. On a hit whose type string differs
// from `type`, the enclosing function returns InvalidArgument. Every hit
// increments `num_set`.
#define VALIDATE_FIELD(name, type_string, oneof_case)                         \
  do {                                                                        \
    if (attr_value.has_list()) {                                              \
      if (attr_value.list().name##_size() > 0) {                              \
        if (type != "list(" type_string ")") {                                \
          return errors::InvalidArgument(                                     \
              "AttrValue had value with type 'list(" type_string ")' when '", \
              type, "' expected");                                            \
        }                                                                     \
        ++num_set;                                                            \
      }                                                                       \
    } else if (attr_value.value_case() == AttrValue::oneof_case) {            \
      if (type != type_string) {                                              \
        return errors::InvalidArgument(                                       \
            "AttrValue had value with type '" type_string "' when '", type,   \
            "' expected");                                                    \
      }                                                                       \
      ++num_set;                                                              \
    }                                                                         \
  } while (false)

  VALIDATE_FIELD(s, "string", kS);
  VALIDATE_FIELD(i, "int", kI);
  VALIDATE_FIELD(f, "float", kF);
  VALIDATE_FIELD(b, "bool", kB);
  VALIDATE_FIELD(type, "type", kType);
  VALIDATE_FIELD(shape, "shape", kShape);
  VALIDATE_FIELD(tensor, "tensor", kTensor);
  VALIDATE_FIELD(func, "func", kFunc);

#undef VALIDATE_FIELD

  // Placeholders are rejected regardless of `type`.
  if (attr_value.value_case() == AttrValue::kPlaceholder) {
    return errors::InvalidArgument(
        "AttrValue had value with unexpected type 'placeholder'");
  }

  // If the attr type is 'list', we expect attr_value.has_list() to be
  // true.  However, proto3's attr_value.has_list() can be false when
  // set to an empty list for GraphDef versions <= 4. So we simply
  // check if has_list is false and some other field in attr_value is
  // set to flag the error.  This test can be made more strict once
  // support for GraphDef versions <= 4 is dropped.
  if (str_util::StartsWith(type, "list(") && !attr_value.has_list()) {
    if (num_set) {
      return errors::InvalidArgument(
          "AttrValue missing value with expected type '", type, "'");
    } else {
      // Indicate that we have a list, but an empty one.
      ++num_set;
    }
  }

  // Okay to have an empty list, but not to be missing a non-list value.
  if (num_set == 0 && !str_util::StartsWith(type, "list(")) {
    return errors::InvalidArgument(
        "AttrValue missing value with expected type '", type, "'");
  }

  // Ref types and DT_INVALID are illegal, and DataTypes must
  // be a valid enum type.
  if (type == "type") {
    if (!DataType_IsValid(attr_value.type())) {
      return errors::InvalidArgument("AttrValue has invalid DataType enum: ",
                                     attr_value.type());
    }
    if (IsRefType(attr_value.type())) {
      return errors::InvalidArgument(
          "AttrValue must not have reference type value of ",
          DataTypeString(attr_value.type()));
    }
    if (attr_value.type() == DT_INVALID) {
      return errors::InvalidArgument("AttrValue has invalid DataType");
    }
  } else if (type == "list(type)") {
    // Repeated enum fields are iterated as raw ints, so each element is
    // range-checked before use as a DataType.
    for (auto as_int : attr_value.list().type()) {
      const DataType dtype = static_cast<DataType>(as_int);
      if (!DataType_IsValid(dtype)) {
        return errors::InvalidArgument("AttrValue has invalid DataType enum: ",
                                       as_int);
      }
      if (IsRefType(dtype)) {
        return errors::InvalidArgument(
            "AttrValue must not have reference type value of ",
            DataTypeString(dtype));
      }
      if (dtype == DT_INVALID) {
        return errors::InvalidArgument("AttrValue contains invalid DataType");
      }
    }
  }

  return Status::OK();
}
373 
ParseAttrValue(StringPiece type,StringPiece text,AttrValue * out)374 bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out) {
375   // Parse type.
376   string field_name;
377   bool is_list = str_util::ConsumePrefix(&type, "list(");
378   if (str_util::ConsumePrefix(&type, "string")) {
379     field_name = "s";
380   } else if (str_util::ConsumePrefix(&type, "int")) {
381     field_name = "i";
382   } else if (str_util::ConsumePrefix(&type, "float")) {
383     field_name = "f";
384   } else if (str_util::ConsumePrefix(&type, "bool")) {
385     field_name = "b";
386   } else if (str_util::ConsumePrefix(&type, "type")) {
387     field_name = "type";
388   } else if (str_util::ConsumePrefix(&type, "shape")) {
389     field_name = "shape";
390   } else if (str_util::ConsumePrefix(&type, "tensor")) {
391     field_name = "tensor";
392   } else if (str_util::ConsumePrefix(&type, "func")) {
393     field_name = "func";
394   } else if (str_util::ConsumePrefix(&type, "placeholder")) {
395     field_name = "placeholder";
396   } else {
397     return false;
398   }
399   if (is_list && !str_util::ConsumePrefix(&type, ")")) {
400     return false;
401   }
402 
403   // Construct a valid text proto message to parse.
404   string to_parse;
405   if (is_list) {
406     // TextFormat parser considers "i: 7" to be the same as "i: [7]",
407     // but we only want to allow list values with [].
408     StringPiece cleaned = text;
409     str_util::RemoveLeadingWhitespace(&cleaned);
410     str_util::RemoveTrailingWhitespace(&cleaned);
411     if (cleaned.size() < 2 || cleaned[0] != '[' ||
412         cleaned[cleaned.size() - 1] != ']') {
413       return false;
414     }
415     cleaned.remove_prefix(1);
416     str_util::RemoveLeadingWhitespace(&cleaned);
417     if (cleaned.size() == 1) {
418       // User wrote "[]", so return empty list without invoking the TextFormat
419       // parse which returns an error for "i: []".
420       out->Clear();
421       out->mutable_list();
422       return true;
423     }
424     to_parse = strings::StrCat("list { ", field_name, ": ", text, " }");
425   } else {
426     to_parse = strings::StrCat(field_name, ": ", text);
427   }
428 
429   return ProtoParseFromString(to_parse, out);
430 }
431 
SetAttrValue(const AttrValue & value,AttrValue * out)432 void SetAttrValue(const AttrValue& value, AttrValue* out) { *out = value; }
433 
// Defines SetAttrValue(ARG_TYPE, AttrValue*) storing a scalar into the
// AttrValue field FIELD via its generated set_FIELD() setter.
#define DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD) \
  void SetAttrValue(ARG_TYPE value, AttrValue* out) { out->set_##FIELD(value); }

// Defines SetAttrValue(ARG_TYPE, AttrValue*) for an iterable ARG_TYPE. It
// replaces out->list() with the elements of `value`, appended to the repeated
// field FIELD. Clear() on mutable_list() makes list() present even when
// `value` is empty.
#define DEFINE_SET_ATTR_VALUE_LIST(ARG_TYPE, FIELD)                       \
  void SetAttrValue(ARG_TYPE value, AttrValue* out) {                     \
    out->mutable_list()->Clear(); /* create list() even if value empty */ \
    for (const auto& v : value) {                                         \
      out->mutable_list()->add_##FIELD(v);                                \
    }                                                                     \
  }

// Defines both the scalar overload and the gtl::ArraySlice<ARG_TYPE> list
// overload for the same FIELD.
#define DEFINE_SET_ATTR_VALUE_BOTH(ARG_TYPE, FIELD) \
  DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD)        \
  DEFINE_SET_ATTR_VALUE_LIST(gtl::ArraySlice<ARG_TYPE>, FIELD)

// Scalar and list overloads for the simple AttrValue fields. bool lists also
// get std::vector<bool> and std::initializer_list<bool> overloads.
DEFINE_SET_ATTR_VALUE_ONE(const string&, s)
DEFINE_SET_ATTR_VALUE_LIST(gtl::ArraySlice<string>, s)
DEFINE_SET_ATTR_VALUE_BOTH(const char*, s)
DEFINE_SET_ATTR_VALUE_BOTH(int64, i)
DEFINE_SET_ATTR_VALUE_BOTH(int32, i)
DEFINE_SET_ATTR_VALUE_BOTH(float, f)
DEFINE_SET_ATTR_VALUE_BOTH(double, f)
DEFINE_SET_ATTR_VALUE_BOTH(bool, b)
DEFINE_SET_ATTR_VALUE_LIST(const std::vector<bool>&, b)
DEFINE_SET_ATTR_VALUE_LIST(std::initializer_list<bool>, b)
DEFINE_SET_ATTR_VALUE_BOTH(DataType, type)
460 
461 void SetAttrValue(StringPiece value, AttrValue* out) {
462   out->set_s(value.data(), value.size());
463 }
464 
SetAttrValue(const gtl::ArraySlice<StringPiece> value,AttrValue * out)465 void SetAttrValue(const gtl::ArraySlice<StringPiece> value, AttrValue* out) {
466   out->mutable_list()->Clear();  // Create list() even if value empty.
467   for (const auto& v : value) {
468     out->mutable_list()->add_s(v.data(), v.size());
469   }
470 }
471 
SetAttrValue(const TensorShape & value,AttrValue * out)472 void SetAttrValue(const TensorShape& value, AttrValue* out) {
473   value.AsProto(out->mutable_shape());
474 }
475 
SetAttrValue(const TensorShapeProto & value,AttrValue * out)476 void SetAttrValue(const TensorShapeProto& value, AttrValue* out) {
477   *out->mutable_shape() = value;
478 }
479 
SetAttrValue(const PartialTensorShape & value,AttrValue * out)480 void SetAttrValue(const PartialTensorShape& value, AttrValue* out) {
481   value.AsProto(out->mutable_shape());
482 }
483 
SetAttrValue(const gtl::ArraySlice<TensorShape> value,AttrValue * out)484 void SetAttrValue(const gtl::ArraySlice<TensorShape> value, AttrValue* out) {
485   out->mutable_list()->Clear();  // Create list() even if value empty.
486   for (const auto& v : value) {
487     v.AsProto(out->mutable_list()->add_shape());
488   }
489 }
490 
SetAttrValue(gtl::ArraySlice<TensorShapeProto> value,AttrValue * out)491 void SetAttrValue(gtl::ArraySlice<TensorShapeProto> value, AttrValue* out) {
492   out->mutable_list()->Clear();  // Create list() even if value empty.
493   for (const auto& v : value) {
494     *out->mutable_list()->add_shape() = v;
495   }
496 }
497 
SetAttrValue(const gtl::ArraySlice<PartialTensorShape> value,AttrValue * out)498 void SetAttrValue(const gtl::ArraySlice<PartialTensorShape> value,
499                   AttrValue* out) {
500   out->mutable_list()->Clear();  // Create list() even if value empty.
501   for (const auto& v : value) {
502     v.AsProto(out->mutable_list()->add_shape());
503   }
504 }
505 
SetAttrValue(const Tensor & value,AttrValue * out)506 void SetAttrValue(const Tensor& value, AttrValue* out) {
507   if (value.NumElements() > 1) {
508     value.AsProtoTensorContent(out->mutable_tensor());
509   } else {
510     value.AsProtoField(out->mutable_tensor());
511   }
512 }
513 
SetAttrValue(const gtl::ArraySlice<Tensor> value,AttrValue * out)514 void SetAttrValue(const gtl::ArraySlice<Tensor> value, AttrValue* out) {
515   out->mutable_list()->Clear();  // Create list() even if value empty.
516   for (const auto& v : value) {
517     if (v.NumElements() > 1) {
518       v.AsProtoTensorContent(out->mutable_list()->add_tensor());
519     } else {
520       v.AsProtoField(out->mutable_list()->add_tensor());
521     }
522   }
523 }
524 
SetAttrValue(const TensorProto & value,AttrValue * out)525 void SetAttrValue(const TensorProto& value, AttrValue* out) {
526   *out->mutable_tensor() = value;
527 }
528 
SetAttrValue(const gtl::ArraySlice<TensorProto> value,AttrValue * out)529 void SetAttrValue(const gtl::ArraySlice<TensorProto> value, AttrValue* out) {
530   out->mutable_list()->Clear();  // Create list() even if value empty.
531   for (const auto& v : value) {
532     *out->mutable_list()->add_tensor() = v;
533   }
534 }
535 
SetAttrValue(const NameAttrList & value,AttrValue * out)536 void SetAttrValue(const NameAttrList& value, AttrValue* out) {
537   *out->mutable_func() = value;
538 }
539 
SetAttrValue(gtl::ArraySlice<NameAttrList> value,AttrValue * out)540 void SetAttrValue(gtl::ArraySlice<NameAttrList> value, AttrValue* out) {
541   out->mutable_list()->Clear();  // Create list() even if value empty.
542   for (const auto& v : value) {
543     *out->mutable_list()->add_func() = v;
544   }
545 }
546 
AreAttrValuesEqual(const AttrValue & a,const AttrValue & b)547 bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b) {
548   return AreAttrValuesEqual(a, b, AreTensorProtosEqual);
549 }
550 
AttrValueHash(const AttrValue & a)551 uint64 AttrValueHash(const AttrValue& a) {
552   return AttrValueHash(a, TensorProtoHash);
553 }
554 
FastAreAttrValuesEqual(const AttrValue & a,const AttrValue & b)555 bool FastAreAttrValuesEqual(const AttrValue& a, const AttrValue& b) {
556   return AreAttrValuesEqual(a, b, FastAreTensorProtosEqual);
557 }
558 
FastAttrValueHash(const AttrValue & a)559 uint64 FastAttrValueHash(const AttrValue& a) {
560   return AttrValueHash(a, FastTensorProtoHash);
561 }
562 
HasPlaceHolder(const AttrValue & val)563 bool HasPlaceHolder(const AttrValue& val) {
564   switch (val.value_case()) {
565     case AttrValue::kList: {
566       for (const NameAttrList& func : val.list().func()) {
567         for (const auto& p : func.attr()) {
568           if (HasPlaceHolder(p.second)) {
569             return true;
570           }
571         }
572       }
573       break;
574     }
575     case AttrValue::kFunc:
576       for (const auto& p : val.func().attr()) {
577         if (HasPlaceHolder(p.second)) {
578           return true;
579         }
580       }
581       break;
582     case AttrValue::kPlaceholder:
583       return true;
584     default:
585       break;
586   }
587   return false;
588 }
589 
SubstitutePlaceholders(const SubstituteFunc & substitute,AttrValue * value)590 bool SubstitutePlaceholders(const SubstituteFunc& substitute,
591                             AttrValue* value) {
592   switch (value->value_case()) {
593     case AttrValue::kList: {
594       for (NameAttrList& func : *value->mutable_list()->mutable_func()) {
595         for (auto& p : *func.mutable_attr()) {
596           if (!SubstitutePlaceholders(substitute, &p.second)) {
597             return false;
598           }
599         }
600       }
601       break;
602     }
603     case AttrValue::kFunc:
604       for (auto& p : *(value->mutable_func()->mutable_attr())) {
605         if (!SubstitutePlaceholders(substitute, &p.second)) {
606           return false;
607         }
608       }
609       break;
610     case AttrValue::kPlaceholder:
611       return substitute(value->placeholder(), value);
612     case AttrValue::VALUE_NOT_SET:
613       return false;
614     default:
615       break;
616   }
617   return true;
618 }
619 
620 }  // namespace tensorflow
621