1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/core/framework/attr_value_util.h"

#include <algorithm>
#include <functional>
#include <limits>
#include <map>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/attr_value.pb_text.h"
#include "tensorflow/core/framework/tensor.pb_text.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb_text.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/protobuf.h"
32
33 namespace tensorflow {
34 namespace {
35
// Largest tensor, in bytes, that the Fast* hash/equality helpers below will
// materialize as a Tensor; bigger TensorProtos are handled through their
// serialized proto form instead.
constexpr int kMaxAttrValueTensorByteSize = 32 << 20;  // 32 MiB
38
39 // Return the size of the tensor represented by this TensorProto. If shape is
40 // not fully defined return -1.
TensorByteSize(const TensorProto & t)41 int64 TensorByteSize(const TensorProto& t) {
42 // num_elements returns -1 if shape is not fully defined.
43 int64 num_elems = TensorShape(t.tensor_shape()).num_elements();
44 return num_elems < 0 ? -1 : num_elems * DataTypeSize(t.dtype());
45 }
46
47 // Compute TensorProto hash by creating a Tensor, serializing it as tensor
48 // content, and computing a hash of it's string representation. This is unsafe
49 // operation, because large tensors can be represented as TensorProto, but can't
50 // be serialized to tensor content.
TensorProtoHash(const TensorProto & tp)51 uint64 TensorProtoHash(const TensorProto& tp) {
52 Tensor tensor(tp.dtype());
53 bool success = tensor.FromProto(tp);
54 DCHECK(success);
55 TensorProto p;
56 tensor.AsProtoTensorContent(&p);
57 return DeterministicProtoHash64(p);
58 }
59
60 // Do not create large tensors in memory, compute hash based on TensorProto
61 // string representation. Tensors with identical content potentially can have a
62 // different hash code if they are defined with different TensorProto
63 // representations.
FastTensorProtoHash(const TensorProto & tp)64 uint64 FastTensorProtoHash(const TensorProto& tp) {
65 if (TensorByteSize(tp) > kMaxAttrValueTensorByteSize) {
66 return DeterministicProtoHash64(tp);
67 } else {
68 return TensorProtoHash(tp);
69 }
70 }
71
72 // There are multiple equivalent representations of attr values containing
73 // TensorProtos. Compare them by constructing Tensors and serializing them
74 // back. Comparing Tensor objects is pretty tricky. This is unsafe operation,
75 // because large tensors can be represented as TensorProto, but can't be
76 // serialized to tensor content.
AreTensorProtosEqual(const TensorProto & lhs,const TensorProto & rhs)77 bool AreTensorProtosEqual(const TensorProto& lhs, const TensorProto& rhs) {
78 Tensor lhs_t(lhs.dtype());
79 bool success = lhs_t.FromProto(lhs);
80 DCHECK(success);
81
82 Tensor rhs_t(rhs.dtype());
83 success = rhs_t.FromProto(rhs);
84 DCHECK(success);
85
86 TensorProto lhs_tp;
87 lhs_t.AsProtoTensorContent(&lhs_tp);
88
89 TensorProto rhs_tp;
90 rhs_t.AsProtoTensorContent(&rhs_tp);
91
92 return AreSerializedProtosEqual(lhs_tp, rhs_tp);
93 }
94
95 // Do not construct large tensors in memory, compare equality using TensorProto
96 // string representation. Tensors with identical content potentially can have
97 // different tensor proto representation.
FastAreTensorProtosEqual(const TensorProto & lhs,const TensorProto & rhs)98 bool FastAreTensorProtosEqual(const TensorProto& lhs, const TensorProto& rhs) {
99 if (TensorByteSize(lhs) > kMaxAttrValueTensorByteSize ||
100 TensorByteSize(rhs) > kMaxAttrValueTensorByteSize) {
101 string lhs_str, rhs_str;
102 bool success = lhs.AppendToString(&lhs_str);
103 DCHECK(success);
104 success = rhs.AppendToString(&rhs_str);
105 DCHECK(success);
106
107 return lhs_str == rhs_str;
108 } else {
109 return AreTensorProtosEqual(lhs, rhs);
110 }
111 }
112
113 using TensorProtoHasher = std::function<uint64(const TensorProto&)>;
114 using TensorProtosEquality =
115 std::function<bool(const TensorProto&, const TensorProto&)>;
116
AttrValueHash(const AttrValue & a,const TensorProtoHasher & tensor_hash)117 uint64 AttrValueHash(const AttrValue& a, const TensorProtoHasher& tensor_hash) {
118 if (a.has_tensor()) return tensor_hash(a.tensor());
119
120 if (a.has_func()) {
121 const NameAttrList& func = a.func();
122 uint64 h = Hash64(func.name());
123 std::map<string, AttrValue> map(func.attr().begin(), func.attr().end());
124 for (const auto& pair : map) {
125 h = Hash64(pair.first.data(), pair.first.size(), h);
126 h = Hash64Combine(AttrValueHash(pair.second, tensor_hash), h);
127 }
128 return h;
129 }
130
131 // If `a` is not a tensor or func, get a hash of serialized string.
132 return DeterministicProtoHash64(a);
133 }
134
AreAttrValuesEqual(const AttrValue & a,const AttrValue & b,const TensorProtosEquality & tensor_equality)135 bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b,
136 const TensorProtosEquality& tensor_equality) {
137 if (a.has_tensor() != b.has_tensor()) {
138 return false;
139 } else if (a.has_tensor() && b.has_tensor()) {
140 return tensor_equality(a.tensor(), b.tensor());
141 }
142
143 // `func` field contains a nested AttrValue. Compare such AttrValues
144 // recursively.
145 if (a.has_func() != b.has_func()) {
146 return false;
147 } else if (a.has_func() && b.has_func()) {
148 const NameAttrList& af = a.func();
149 const NameAttrList& bf = b.func();
150 if (af.name() != bf.name()) return false;
151 std::unordered_map<string, AttrValue> am(af.attr().begin(),
152 af.attr().end());
153 for (const auto& bm_pair : bf.attr()) {
154 const auto& iter = am.find(bm_pair.first);
155 if (iter == am.end()) return false;
156 if (!AreAttrValuesEqual(iter->second, bm_pair.second, tensor_equality))
157 return false;
158 am.erase(iter);
159 }
160 if (!am.empty()) return false;
161 return true;
162 }
163
164 // All other fields in AttrValue have deterministic representations.
165 // It is safe to compare their serialized strings.
166 return AreSerializedProtosEqual(a, b);
167 }
168
// Renders `str` as a double-quoted, C-escaped string for display. An escaped
// form of 80 or more characters is abbreviated to its first 10 and last 10
// characters joined by "...".
string SummarizeString(const string& str) {
  const string escaped = str_util::CEscape(str);

  constexpr int kMaxStringSummarySize = 80;
  if (escaped.size() < kMaxStringSummarySize) {
    return strings::StrCat("\"", escaped, "\"");
  }

  const StringPiece whole(escaped);
  const StringPiece head = whole.substr(0, 10);
  const StringPiece tail = whole.substr(whole.size() - 10);
  return strings::StrCat("\"", head, "...", tail, "\"");
}
184
SummarizeTensor(const TensorProto & tensor_proto)185 string SummarizeTensor(const TensorProto& tensor_proto) {
186 Tensor t;
187 if (!t.FromProto(tensor_proto)) {
188 return strings::StrCat(
189 "<Invalid TensorProto: ", ProtoShortDebugString(tensor_proto), ">");
190 }
191 return t.DebugString();
192 }
193
SummarizeFunc(const NameAttrList & func)194 string SummarizeFunc(const NameAttrList& func) {
195 std::vector<string> entries;
196 for (auto p : func.attr()) {
197 entries.push_back(
198 strings::StrCat(p.first, "=", SummarizeAttrValue(p.second)));
199 }
200 std::sort(entries.begin(), entries.end());
201 return strings::StrCat(func.name(), "[", str_util::Join(entries, ", "), "]");
202 }
203
204 } // namespace
205
SummarizeAttrValue(const AttrValue & attr_value)206 string SummarizeAttrValue(const AttrValue& attr_value) {
207 switch (attr_value.value_case()) {
208 case AttrValue::kS:
209 return SummarizeString(attr_value.s());
210 case AttrValue::kI:
211 return strings::StrCat(attr_value.i());
212 case AttrValue::kF:
213 return strings::StrCat(attr_value.f());
214 case AttrValue::kB:
215 return attr_value.b() ? "true" : "false";
216 case AttrValue::kType:
217 return EnumName_DataType(attr_value.type());
218 case AttrValue::kShape:
219 return PartialTensorShape::DebugString(attr_value.shape());
220 case AttrValue::kTensor:
221 return SummarizeTensor(attr_value.tensor());
222 case AttrValue::kList: {
223 std::vector<string> pieces;
224 if (attr_value.list().s_size() > 0) {
225 for (int i = 0; i < attr_value.list().s_size(); ++i) {
226 pieces.push_back(SummarizeString(attr_value.list().s(i)));
227 }
228 } else if (attr_value.list().i_size() > 0) {
229 for (int i = 0; i < attr_value.list().i_size(); ++i) {
230 pieces.push_back(strings::StrCat(attr_value.list().i(i)));
231 }
232 } else if (attr_value.list().f_size() > 0) {
233 for (int i = 0; i < attr_value.list().f_size(); ++i) {
234 pieces.push_back(strings::StrCat(attr_value.list().f(i)));
235 }
236 } else if (attr_value.list().b_size() > 0) {
237 for (int i = 0; i < attr_value.list().b_size(); ++i) {
238 pieces.push_back(attr_value.list().b(i) ? "true" : "false");
239 }
240 } else if (attr_value.list().type_size() > 0) {
241 for (int i = 0; i < attr_value.list().type_size(); ++i) {
242 pieces.push_back(EnumName_DataType(attr_value.list().type(i)));
243 }
244 } else if (attr_value.list().shape_size() > 0) {
245 for (int i = 0; i < attr_value.list().shape_size(); ++i) {
246 pieces.push_back(
247 TensorShape::DebugString(attr_value.list().shape(i)));
248 }
249 } else if (attr_value.list().tensor_size() > 0) {
250 for (int i = 0; i < attr_value.list().tensor_size(); ++i) {
251 pieces.push_back(SummarizeTensor(attr_value.list().tensor(i)));
252 }
253 } else if (attr_value.list().func_size() > 0) {
254 for (int i = 0; i < attr_value.list().func_size(); ++i) {
255 pieces.push_back(SummarizeFunc(attr_value.list().func(i)));
256 }
257 }
258 constexpr int kMaxListSummarySize = 15;
259 if (pieces.size() >= kMaxListSummarySize) {
260 pieces.erase(pieces.begin() + 5, pieces.begin() + (pieces.size() - 6));
261 pieces[5] = "...";
262 }
263 return strings::StrCat("[", str_util::Join(pieces, ", "), "]");
264 }
265 case AttrValue::kFunc: {
266 return SummarizeFunc(attr_value.func());
267 }
268 case AttrValue::kPlaceholder:
269 return strings::StrCat("$", attr_value.placeholder());
270 case AttrValue::VALUE_NOT_SET:
271 return "<Unknown AttrValue type>";
272 }
273 return "<Unknown AttrValue type>"; // Prevent missing return warning
274 }
275
// Checks that `attr_value` holds a value of the attr type spelled `type`
// (e.g. "int", "list(string)").
//
// Returns InvalidArgument when:
//   - a set field (or non-empty list field) has a different type than `type`;
//   - the value is a placeholder;
//   - `type` is not a list type and no matching value is set;
//   - for "type" / "list(type)", a DataType is not a valid enum value, is a
//     reference type, or is DT_INVALID.
// An empty list is accepted for any "list(...)" type.
Status AttrValueHasType(const AttrValue& attr_value, StringPiece type) {
  // Count of fields found set whose type matches `type`.
  int num_set = 0;

// VALIDATE_FIELD(name, type_string, oneof_case) checks one AttrValue field.
// If attr_value holds a list, a non-empty list().<name> requires `type` to be
// "list(<type_string>)". Otherwise, if the oneof case is `oneof_case`, `type`
// must equal <type_string>. A mismatch returns InvalidArgument from
// AttrValueHasType; a match increments num_set.
#define VALIDATE_FIELD(name, type_string, oneof_case)                         \
  do {                                                                        \
    if (attr_value.has_list()) {                                              \
      if (attr_value.list().name##_size() > 0) {                              \
        if (type != "list(" type_string ")") {                                \
          return errors::InvalidArgument(                                     \
              "AttrValue had value with type 'list(" type_string ")' when '", \
              type, "' expected");                                            \
        }                                                                     \
        ++num_set;                                                            \
      }                                                                       \
    } else if (attr_value.value_case() == AttrValue::oneof_case) {            \
      if (type != type_string) {                                              \
        return errors::InvalidArgument(                                       \
            "AttrValue had value with type '" type_string "' when '", type,   \
            "' expected");                                                    \
      }                                                                       \
      ++num_set;                                                              \
    }                                                                         \
  } while (false)

  VALIDATE_FIELD(s, "string", kS);
  VALIDATE_FIELD(i, "int", kI);
  VALIDATE_FIELD(f, "float", kF);
  VALIDATE_FIELD(b, "bool", kB);
  VALIDATE_FIELD(type, "type", kType);
  VALIDATE_FIELD(shape, "shape", kShape);
  VALIDATE_FIELD(tensor, "tensor", kTensor);
  VALIDATE_FIELD(func, "func", kFunc);

#undef VALIDATE_FIELD

  // A placeholder never satisfies a concrete attr type.
  if (attr_value.value_case() == AttrValue::kPlaceholder) {
    return errors::InvalidArgument(
        "AttrValue had value with unexpected type 'placeholder'");
  }

  // If the attr type is 'list', we expect attr_value.has_list() to be
  // true. However, proto3's attr_value.has_list() can be false when
  // set to an empty list for GraphDef versions <= 4. So we simply
  // check if has_list is false and some other field in attr_value is
  // set to flag the error. This test can be made more strict once
  // support for GraphDef versions <= 4 is dropped.
  // NOTE(review): when `type` starts with "list(" and has_list() is false,
  // any set scalar field already returned an error from VALIDATE_FIELD (its
  // type_string never starts with "list("), so num_set appears to always be
  // 0 here — the first branch looks defensive only.
  if (str_util::StartsWith(type, "list(") && !attr_value.has_list()) {
    if (num_set) {
      return errors::InvalidArgument(
          "AttrValue missing value with expected type '", type, "'");
    } else {
      // Indicate that we have a list, but an empty one.
      ++num_set;
    }
  }

  // Okay to have an empty list, but not to be missing a non-list value.
  if (num_set == 0 && !str_util::StartsWith(type, "list(")) {
    return errors::InvalidArgument(
        "AttrValue missing value with expected type '", type, "'");
  }

  // Ref types and DT_INVALID are illegal, and DataTypes must
  // be a valid enum type.
  if (type == "type") {
    if (!DataType_IsValid(attr_value.type())) {
      return errors::InvalidArgument("AttrValue has invalid DataType enum: ",
                                     attr_value.type());
    }
    if (IsRefType(attr_value.type())) {
      return errors::InvalidArgument(
          "AttrValue must not have reference type value of ",
          DataTypeString(attr_value.type()));
    }
    if (attr_value.type() == DT_INVALID) {
      return errors::InvalidArgument("AttrValue has invalid DataType");
    }
  } else if (type == "list(type)") {
    // The repeated enum field yields raw ints; validate each one.
    for (auto as_int : attr_value.list().type()) {
      const DataType dtype = static_cast<DataType>(as_int);
      if (!DataType_IsValid(dtype)) {
        return errors::InvalidArgument("AttrValue has invalid DataType enum: ",
                                       as_int);
      }
      if (IsRefType(dtype)) {
        return errors::InvalidArgument(
            "AttrValue must not have reference type value of ",
            DataTypeString(dtype));
      }
      if (dtype == DT_INVALID) {
        return errors::InvalidArgument("AttrValue contains invalid DataType");
      }
    }
  }

  return Status::OK();
}
373
ParseAttrValue(StringPiece type,StringPiece text,AttrValue * out)374 bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out) {
375 // Parse type.
376 string field_name;
377 bool is_list = str_util::ConsumePrefix(&type, "list(");
378 if (str_util::ConsumePrefix(&type, "string")) {
379 field_name = "s";
380 } else if (str_util::ConsumePrefix(&type, "int")) {
381 field_name = "i";
382 } else if (str_util::ConsumePrefix(&type, "float")) {
383 field_name = "f";
384 } else if (str_util::ConsumePrefix(&type, "bool")) {
385 field_name = "b";
386 } else if (str_util::ConsumePrefix(&type, "type")) {
387 field_name = "type";
388 } else if (str_util::ConsumePrefix(&type, "shape")) {
389 field_name = "shape";
390 } else if (str_util::ConsumePrefix(&type, "tensor")) {
391 field_name = "tensor";
392 } else if (str_util::ConsumePrefix(&type, "func")) {
393 field_name = "func";
394 } else if (str_util::ConsumePrefix(&type, "placeholder")) {
395 field_name = "placeholder";
396 } else {
397 return false;
398 }
399 if (is_list && !str_util::ConsumePrefix(&type, ")")) {
400 return false;
401 }
402
403 // Construct a valid text proto message to parse.
404 string to_parse;
405 if (is_list) {
406 // TextFormat parser considers "i: 7" to be the same as "i: [7]",
407 // but we only want to allow list values with [].
408 StringPiece cleaned = text;
409 str_util::RemoveLeadingWhitespace(&cleaned);
410 str_util::RemoveTrailingWhitespace(&cleaned);
411 if (cleaned.size() < 2 || cleaned[0] != '[' ||
412 cleaned[cleaned.size() - 1] != ']') {
413 return false;
414 }
415 cleaned.remove_prefix(1);
416 str_util::RemoveLeadingWhitespace(&cleaned);
417 if (cleaned.size() == 1) {
418 // User wrote "[]", so return empty list without invoking the TextFormat
419 // parse which returns an error for "i: []".
420 out->Clear();
421 out->mutable_list();
422 return true;
423 }
424 to_parse = strings::StrCat("list { ", field_name, ": ", text, " }");
425 } else {
426 to_parse = strings::StrCat(field_name, ": ", text);
427 }
428
429 return ProtoParseFromString(to_parse, out);
430 }
431
SetAttrValue(const AttrValue & value,AttrValue * out)432 void SetAttrValue(const AttrValue& value, AttrValue* out) { *out = value; }
433
434 #define DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD) \
435 void SetAttrValue(ARG_TYPE value, AttrValue* out) { out->set_##FIELD(value); }
436
437 #define DEFINE_SET_ATTR_VALUE_LIST(ARG_TYPE, FIELD) \
438 void SetAttrValue(ARG_TYPE value, AttrValue* out) { \
439 out->mutable_list()->Clear(); /* create list() even if value empty */ \
440 for (const auto& v : value) { \
441 out->mutable_list()->add_##FIELD(v); \
442 } \
443 }
444
445 #define DEFINE_SET_ATTR_VALUE_BOTH(ARG_TYPE, FIELD) \
446 DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD) \
447 DEFINE_SET_ATTR_VALUE_LIST(gtl::ArraySlice<ARG_TYPE>, FIELD)
448
DEFINE_SET_ATTR_VALUE_ONE(const string &,s)449 DEFINE_SET_ATTR_VALUE_ONE(const string&, s)
450 DEFINE_SET_ATTR_VALUE_LIST(gtl::ArraySlice<string>, s)
451 DEFINE_SET_ATTR_VALUE_BOTH(const char*, s)
452 DEFINE_SET_ATTR_VALUE_BOTH(int64, i)
453 DEFINE_SET_ATTR_VALUE_BOTH(int32, i)
454 DEFINE_SET_ATTR_VALUE_BOTH(float, f)
455 DEFINE_SET_ATTR_VALUE_BOTH(double, f)
456 DEFINE_SET_ATTR_VALUE_BOTH(bool, b)
457 DEFINE_SET_ATTR_VALUE_LIST(const std::vector<bool>&, b)
458 DEFINE_SET_ATTR_VALUE_LIST(std::initializer_list<bool>, b)
459 DEFINE_SET_ATTR_VALUE_BOTH(DataType, type)
460
461 void SetAttrValue(StringPiece value, AttrValue* out) {
462 out->set_s(value.data(), value.size());
463 }
464
SetAttrValue(const gtl::ArraySlice<StringPiece> value,AttrValue * out)465 void SetAttrValue(const gtl::ArraySlice<StringPiece> value, AttrValue* out) {
466 out->mutable_list()->Clear(); // Create list() even if value empty.
467 for (const auto& v : value) {
468 out->mutable_list()->add_s(v.data(), v.size());
469 }
470 }
471
SetAttrValue(const TensorShape & value,AttrValue * out)472 void SetAttrValue(const TensorShape& value, AttrValue* out) {
473 value.AsProto(out->mutable_shape());
474 }
475
SetAttrValue(const TensorShapeProto & value,AttrValue * out)476 void SetAttrValue(const TensorShapeProto& value, AttrValue* out) {
477 *out->mutable_shape() = value;
478 }
479
SetAttrValue(const PartialTensorShape & value,AttrValue * out)480 void SetAttrValue(const PartialTensorShape& value, AttrValue* out) {
481 value.AsProto(out->mutable_shape());
482 }
483
SetAttrValue(const gtl::ArraySlice<TensorShape> value,AttrValue * out)484 void SetAttrValue(const gtl::ArraySlice<TensorShape> value, AttrValue* out) {
485 out->mutable_list()->Clear(); // Create list() even if value empty.
486 for (const auto& v : value) {
487 v.AsProto(out->mutable_list()->add_shape());
488 }
489 }
490
SetAttrValue(gtl::ArraySlice<TensorShapeProto> value,AttrValue * out)491 void SetAttrValue(gtl::ArraySlice<TensorShapeProto> value, AttrValue* out) {
492 out->mutable_list()->Clear(); // Create list() even if value empty.
493 for (const auto& v : value) {
494 *out->mutable_list()->add_shape() = v;
495 }
496 }
497
SetAttrValue(const gtl::ArraySlice<PartialTensorShape> value,AttrValue * out)498 void SetAttrValue(const gtl::ArraySlice<PartialTensorShape> value,
499 AttrValue* out) {
500 out->mutable_list()->Clear(); // Create list() even if value empty.
501 for (const auto& v : value) {
502 v.AsProto(out->mutable_list()->add_shape());
503 }
504 }
505
SetAttrValue(const Tensor & value,AttrValue * out)506 void SetAttrValue(const Tensor& value, AttrValue* out) {
507 if (value.NumElements() > 1) {
508 value.AsProtoTensorContent(out->mutable_tensor());
509 } else {
510 value.AsProtoField(out->mutable_tensor());
511 }
512 }
513
SetAttrValue(const gtl::ArraySlice<Tensor> value,AttrValue * out)514 void SetAttrValue(const gtl::ArraySlice<Tensor> value, AttrValue* out) {
515 out->mutable_list()->Clear(); // Create list() even if value empty.
516 for (const auto& v : value) {
517 if (v.NumElements() > 1) {
518 v.AsProtoTensorContent(out->mutable_list()->add_tensor());
519 } else {
520 v.AsProtoField(out->mutable_list()->add_tensor());
521 }
522 }
523 }
524
SetAttrValue(const TensorProto & value,AttrValue * out)525 void SetAttrValue(const TensorProto& value, AttrValue* out) {
526 *out->mutable_tensor() = value;
527 }
528
SetAttrValue(const gtl::ArraySlice<TensorProto> value,AttrValue * out)529 void SetAttrValue(const gtl::ArraySlice<TensorProto> value, AttrValue* out) {
530 out->mutable_list()->Clear(); // Create list() even if value empty.
531 for (const auto& v : value) {
532 *out->mutable_list()->add_tensor() = v;
533 }
534 }
535
SetAttrValue(const NameAttrList & value,AttrValue * out)536 void SetAttrValue(const NameAttrList& value, AttrValue* out) {
537 *out->mutable_func() = value;
538 }
539
SetAttrValue(gtl::ArraySlice<NameAttrList> value,AttrValue * out)540 void SetAttrValue(gtl::ArraySlice<NameAttrList> value, AttrValue* out) {
541 out->mutable_list()->Clear(); // Create list() even if value empty.
542 for (const auto& v : value) {
543 *out->mutable_list()->add_func() = v;
544 }
545 }
546
AreAttrValuesEqual(const AttrValue & a,const AttrValue & b)547 bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b) {
548 return AreAttrValuesEqual(a, b, AreTensorProtosEqual);
549 }
550
AttrValueHash(const AttrValue & a)551 uint64 AttrValueHash(const AttrValue& a) {
552 return AttrValueHash(a, TensorProtoHash);
553 }
554
FastAreAttrValuesEqual(const AttrValue & a,const AttrValue & b)555 bool FastAreAttrValuesEqual(const AttrValue& a, const AttrValue& b) {
556 return AreAttrValuesEqual(a, b, FastAreTensorProtosEqual);
557 }
558
FastAttrValueHash(const AttrValue & a)559 uint64 FastAttrValueHash(const AttrValue& a) {
560 return AttrValueHash(a, FastTensorProtoHash);
561 }
562
HasPlaceHolder(const AttrValue & val)563 bool HasPlaceHolder(const AttrValue& val) {
564 switch (val.value_case()) {
565 case AttrValue::kList: {
566 for (const NameAttrList& func : val.list().func()) {
567 for (const auto& p : func.attr()) {
568 if (HasPlaceHolder(p.second)) {
569 return true;
570 }
571 }
572 }
573 break;
574 }
575 case AttrValue::kFunc:
576 for (const auto& p : val.func().attr()) {
577 if (HasPlaceHolder(p.second)) {
578 return true;
579 }
580 }
581 break;
582 case AttrValue::kPlaceholder:
583 return true;
584 default:
585 break;
586 }
587 return false;
588 }
589
SubstitutePlaceholders(const SubstituteFunc & substitute,AttrValue * value)590 bool SubstitutePlaceholders(const SubstituteFunc& substitute,
591 AttrValue* value) {
592 switch (value->value_case()) {
593 case AttrValue::kList: {
594 for (NameAttrList& func : *value->mutable_list()->mutable_func()) {
595 for (auto& p : *func.mutable_attr()) {
596 if (!SubstitutePlaceholders(substitute, &p.second)) {
597 return false;
598 }
599 }
600 }
601 break;
602 }
603 case AttrValue::kFunc:
604 for (auto& p : *(value->mutable_func()->mutable_attr())) {
605 if (!SubstitutePlaceholders(substitute, &p.second)) {
606 return false;
607 }
608 }
609 break;
610 case AttrValue::kPlaceholder:
611 return substitute(value->placeholder(), value);
612 case AttrValue::VALUE_NOT_SET:
613 return false;
614 default:
615 break;
616 }
617 return true;
618 }
619
620 } // namespace tensorflow
621