1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/core/framework/attr_value_util.h"

#include <algorithm>
#include <limits>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/strings/escaping.h"
#include "tensorflow/core/framework/attr_value.pb_text.h"
#include "tensorflow/core/framework/tensor.pb_text.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb_text.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/protobuf.h"
35
36 namespace tensorflow {
37 namespace {
38
// Size threshold (in bytes) above which tensors are not materialized in memory
// just to hash them or compare them for equality; callers switch to the
// TensorProto encoding instead.
constexpr int kMaxAttrValueTensorByteSize = 32 << 20;  // 32 MiB

// Maximum `{`/`<` nesting depth accepted in the text of a tensor attr, so that
// deeply nested input cannot exhaust memory/stack in the text-proto parser.
constexpr int kMaxTensorNestDepth = 100;
44
45 // Return the size of the tensor represented by this TensorProto. If shape is
46 // not fully defined return -1.
TensorByteSize(const TensorProto & t)47 int64_t TensorByteSize(const TensorProto& t) {
48 // num_elements returns -1 if shape is not fully defined.
49 int64_t num_elems = PartialTensorShape(t.tensor_shape()).num_elements();
50 return num_elems < 0 ? -1 : num_elems * DataTypeSize(t.dtype());
51 }
52
53 // Compute TensorProto hash by creating a Tensor, serializing it as tensor
54 // content, and computing a hash of it's string representation. This is unsafe
55 // operation, because large tensors can be represented as TensorProto, but can't
56 // be serialized to tensor content.
TensorProtoHash(const TensorProto & tp)57 uint64 TensorProtoHash(const TensorProto& tp) {
58 Tensor tensor(tp.dtype());
59 bool success = tensor.FromProto(tp);
60 DCHECK(success);
61 TensorProto p;
62 tensor.AsProtoTensorContent(&p);
63 return DeterministicProtoHash64(p);
64 }
65
66 // Do not create large tensors in memory, compute hash based on TensorProto
67 // string representation. Tensors with identical content potentially can have a
68 // different hash code if they are defined with different TensorProto
69 // representations.
FastTensorProtoHash(const TensorProto & tp)70 uint64 FastTensorProtoHash(const TensorProto& tp) {
71 if (TensorByteSize(tp) > kMaxAttrValueTensorByteSize) {
72 return DeterministicProtoHash64(tp);
73 } else {
74 return TensorProtoHash(tp);
75 }
76 }
77
AreTensorProtosEqual(const TensorProto & lhs,const TensorProto & rhs,bool allow_false_negatives)78 bool AreTensorProtosEqual(const TensorProto& lhs, const TensorProto& rhs,
79 bool allow_false_negatives) {
80 // A small TensorProto can expand into a giant Tensor. So we avoid
81 // conversion to an actual Tensor if we can quickly rule out equality
82 // by comparing the Tensor size since different sized Tensors are definitely
83 // different.
84 const int64_t lhs_tensor_bytes = TensorByteSize(lhs);
85 const int64_t rhs_tensor_bytes = TensorByteSize(rhs);
86 if (lhs_tensor_bytes != rhs_tensor_bytes) {
87 return false;
88 }
89
90 // If the TensorProto representation expands into a much bigger Tensor,
91 // we have a fast-path that first compares the protos.
92 const int64_t lhs_proto_bytes = lhs.ByteSizeLong();
93 const bool large_expansion =
94 (lhs_proto_bytes < 512 && lhs_tensor_bytes > 4096);
95
96 // If the tensor is very large, we'll only compare the proto representation if
97 // false negatives are allowed. This may miss some equivalent tensors whose
98 // actual tensor values are the same but which are described by different
99 // TensorProtos. This avoids construction of large protos in memory.
100 const bool only_compare_proto =
101 (allow_false_negatives && lhs_tensor_bytes > kMaxAttrValueTensorByteSize);
102 if (large_expansion || only_compare_proto) {
103 if (AreSerializedProtosEqual(lhs, rhs))
104 return true;
105 else if (only_compare_proto)
106 return false;
107 }
108
109 // Finally, compare them by constructing Tensors and serializing them back.
110 // There are multiple equivalent representations of attr values containing
111 // TensorProtos. Comparing Tensor objects is pretty tricky. This is unsafe
112 // operation, because large tensors can be represented as TensorProto, but
113 // can't be serialized to tensor content.
114 Tensor lhs_t(lhs.dtype());
115 bool success = lhs_t.FromProto(lhs);
116 if (!success) {
117 return false;
118 }
119
120 Tensor rhs_t(rhs.dtype());
121 success = rhs_t.FromProto(rhs);
122 if (!success) {
123 return false;
124 }
125
126 TensorProto lhs_tp;
127 lhs_t.AsProtoTensorContent(&lhs_tp);
128
129 TensorProto rhs_tp;
130 rhs_t.AsProtoTensorContent(&rhs_tp);
131
132 return AreSerializedProtosEqual(lhs_tp, rhs_tp);
133 }
134
135 using TensorProtoHasher = std::function<uint64(const TensorProto&)>;
136
AttrValueHash(const AttrValue & a,const TensorProtoHasher & tensor_hash)137 uint64 AttrValueHash(const AttrValue& a, const TensorProtoHasher& tensor_hash) {
138 if (a.has_tensor()) return tensor_hash(a.tensor());
139
140 if (a.has_func()) {
141 const NameAttrList& func = a.func();
142 uint64 h = Hash64(func.name());
143 std::map<string, AttrValue> map(func.attr().begin(), func.attr().end());
144 for (const auto& pair : map) {
145 h = Hash64(pair.first.data(), pair.first.size(), h);
146 h = Hash64Combine(AttrValueHash(pair.second, tensor_hash), h);
147 }
148 return h;
149 }
150
151 // If `a` is not a tensor or func, get a hash of serialized string.
152 return DeterministicProtoHash64(a);
153 }
154
// Returns `str` C-escaped and wrapped in double quotes. If the escaped text is
// 80 characters or longer, only its first and last 10 characters are kept,
// joined by "...".
string SummarizeString(const string& str) {
  const string escaped = absl::CEscape(str);

  constexpr int kMaxStringSummarySize = 80;
  if (escaped.size() < kMaxStringSummarySize) {
    return strings::StrCat("\"", escaped, "\"");
  }

  constexpr size_t kKeptChars = 10;
  const StringPiece view(escaped);
  return strings::StrCat("\"", view.substr(0, kKeptChars), "...",
                         view.substr(view.size() - kKeptChars), "\"");
}
170
SummarizeTensor(const TensorProto & tensor_proto)171 string SummarizeTensor(const TensorProto& tensor_proto) {
172 Tensor t;
173 if (!t.FromProto(tensor_proto)) {
174 return strings::StrCat(
175 "<Invalid TensorProto: ", tensor_proto.ShortDebugString(), ">");
176 }
177 return t.DebugString();
178 }
179
SummarizeFunc(const NameAttrList & func)180 string SummarizeFunc(const NameAttrList& func) {
181 std::vector<string> entries;
182 for (const auto& p : func.attr()) {
183 entries.push_back(
184 strings::StrCat(p.first, "=", SummarizeAttrValue(p.second)));
185 }
186 std::sort(entries.begin(), entries.end());
187 return strings::StrCat(func.name(), "[", absl::StrJoin(entries, ", "), "]");
188 }
189
// Returns true if the nesting depth of `{`/`<` delimiters in `to_parse` stays
// strictly below `limit`. Returns false as soon as `limit` open delimiters are
// simultaneously unclosed. Each `{` or `<` opens a level and each `}` or `>`
// closes one. A closer with no open level is ignored, so stray closers cannot
// "pre-pay" for deeper nesting later on. The scan is single-pass and does not
// skip delimiters inside quoted strings, so it may over-count; it is a coarse
// guard in front of the text-proto parser.
bool ParseAttrValueHelper_TensorNestsUnderLimit(int limit,
                                                const std::string& to_parse) {
  int depth = 0;
  for (const char c : to_parse) {
    if (c == '{' || c == '<') {
      if (++depth >= limit) return false;
    } else if ((c == '}' || c == '>') && depth > 0) {
      --depth;
    }
  }
  return true;
}
237
238 } // namespace
239
SummarizeAttrValue(const AttrValue & attr_value)240 string SummarizeAttrValue(const AttrValue& attr_value) {
241 switch (attr_value.value_case()) {
242 case AttrValue::kS:
243 return SummarizeString(attr_value.s());
244 case AttrValue::kI:
245 return strings::StrCat(attr_value.i());
246 case AttrValue::kF:
247 return strings::StrCat(attr_value.f());
248 case AttrValue::kB:
249 return attr_value.b() ? "true" : "false";
250 case AttrValue::kType:
251 return EnumName_DataType(attr_value.type());
252 case AttrValue::kShape:
253 return PartialTensorShape::DebugString(attr_value.shape());
254 case AttrValue::kTensor:
255 return SummarizeTensor(attr_value.tensor());
256 case AttrValue::kList: {
257 std::vector<string> pieces;
258 if (attr_value.list().s_size() > 0) {
259 for (int i = 0; i < attr_value.list().s_size(); ++i) {
260 pieces.push_back(SummarizeString(attr_value.list().s(i)));
261 }
262 } else if (attr_value.list().i_size() > 0) {
263 for (int i = 0; i < attr_value.list().i_size(); ++i) {
264 pieces.push_back(strings::StrCat(attr_value.list().i(i)));
265 }
266 } else if (attr_value.list().f_size() > 0) {
267 for (int i = 0; i < attr_value.list().f_size(); ++i) {
268 pieces.push_back(strings::StrCat(attr_value.list().f(i)));
269 }
270 } else if (attr_value.list().b_size() > 0) {
271 for (int i = 0; i < attr_value.list().b_size(); ++i) {
272 pieces.push_back(attr_value.list().b(i) ? "true" : "false");
273 }
274 } else if (attr_value.list().type_size() > 0) {
275 for (int i = 0; i < attr_value.list().type_size(); ++i) {
276 pieces.push_back(EnumName_DataType(attr_value.list().type(i)));
277 }
278 } else if (attr_value.list().shape_size() > 0) {
279 for (int i = 0; i < attr_value.list().shape_size(); ++i) {
280 pieces.push_back(
281 TensorShape::DebugString(attr_value.list().shape(i)));
282 }
283 } else if (attr_value.list().tensor_size() > 0) {
284 for (int i = 0; i < attr_value.list().tensor_size(); ++i) {
285 pieces.push_back(SummarizeTensor(attr_value.list().tensor(i)));
286 }
287 } else if (attr_value.list().func_size() > 0) {
288 for (int i = 0; i < attr_value.list().func_size(); ++i) {
289 pieces.push_back(SummarizeFunc(attr_value.list().func(i)));
290 }
291 }
292 constexpr int kMaxListSummarySize = 15;
293 if (pieces.size() >= kMaxListSummarySize) {
294 pieces[5] = strings::StrCat(Fingerprint64(
295 absl::StrJoin(pieces.begin() + 5, pieces.end() - 5, ",")));
296 pieces.erase(pieces.begin() + 6, pieces.end() - 5);
297 }
298 return strings::StrCat("[", absl::StrJoin(pieces, ", "), "]");
299 }
300 case AttrValue::kFunc: {
301 return SummarizeFunc(attr_value.func());
302 }
303 case AttrValue::kPlaceholder:
304 return strings::StrCat("$", attr_value.placeholder());
305 case AttrValue::VALUE_NOT_SET:
306 return "<Unknown AttrValue type>";
307 }
308 return "<Unknown AttrValue type>"; // Prevent missing return warning
309 }
310
// Checks that `attr_value` holds a value compatible with the attr type named by
// `type` (for example "int", "type" or "list(float)").
//
// Returns InvalidArgument describing the first mismatch found:
//   - a set field whose type differs from `type`;
//   - a placeholder value;
//   - a non-list `type` with no value set;
//   - for "type"/"list(type)", an invalid, reference, or DT_INVALID DataType.
// Returns OkStatus() otherwise. An empty list is accepted for any "list(...)"
// type.
Status AttrValueHasType(const AttrValue& attr_value, StringPiece type) {
  // Number of value fields found to be set. When attr_value holds a list, each
  // non-empty repeated field counts once.
  int num_set = 0;

  // VALIDATE_FIELD(name, type_string, oneof_case) checks one AttrValue field.
  // If attr_value holds a list, a non-empty list().name must match the type
  // "list(type_string)". Otherwise, if the oneof case is `oneof_case`, `type`
  // must equal `type_string`. A match increments num_set. A mismatch returns
  // InvalidArgument from AttrValueHasType itself.
#define VALIDATE_FIELD(name, type_string, oneof_case)                         \
  do {                                                                        \
    if (attr_value.has_list()) {                                              \
      if (attr_value.list().name##_size() > 0) {                              \
        if (type != "list(" type_string ")") {                                \
          return errors::InvalidArgument(                                     \
              "AttrValue had value with type 'list(" type_string ")' when '", \
              type, "' expected");                                            \
        }                                                                     \
        ++num_set;                                                            \
      }                                                                       \
    } else if (attr_value.value_case() == AttrValue::oneof_case) {            \
      if (type != type_string) {                                              \
        return errors::InvalidArgument(                                       \
            "AttrValue had value with type '" type_string "' when '", type,   \
            "' expected");                                                    \
      }                                                                       \
      ++num_set;                                                              \
    }                                                                         \
  } while (false)

  // Fields are checked in this order, so the first mismatching field
  // determines the error message.
  VALIDATE_FIELD(s, "string", kS);
  VALIDATE_FIELD(i, "int", kI);
  VALIDATE_FIELD(f, "float", kF);
  VALIDATE_FIELD(b, "bool", kB);
  VALIDATE_FIELD(type, "type", kType);
  VALIDATE_FIELD(shape, "shape", kShape);
  VALIDATE_FIELD(tensor, "tensor", kTensor);
  VALIDATE_FIELD(func, "func", kFunc);

#undef VALIDATE_FIELD

  // Placeholders are never an acceptable concrete value here.
  if (attr_value.value_case() == AttrValue::kPlaceholder) {
    return errors::InvalidArgument(
        "AttrValue had value with unexpected type 'placeholder'");
  }

  // If the attr type is 'list', we expect attr_value.has_list() to be
  // true. However, proto3's attr_value.has_list() can be false when
  // set to an empty list for GraphDef versions <= 4. So we simply
  // check if has_list is false and some other field in attr_value is
  // set to flag the error. This test can be made more strict once
  // support for GraphDef versions <= 4 is dropped.
  if (absl::StartsWith(type, "list(") && !attr_value.has_list()) {
    if (num_set) {
      return errors::InvalidArgument(
          "AttrValue missing value with expected type '", type, "'");
    } else {
      // Indicate that we have a list, but an empty one.
      ++num_set;
    }
  }

  // Okay to have an empty list, but not to be missing a non-list value.
  if (num_set == 0 && !absl::StartsWith(type, "list(")) {
    return errors::InvalidArgument(
        "AttrValue missing value with expected type '", type, "'");
  }

  // Ref types and DT_INVALID are illegal, and DataTypes must
  // be a valid enum type.
  if (type == "type") {
    if (!DataType_IsValid(attr_value.type())) {
      return errors::InvalidArgument("AttrValue has invalid DataType enum: ",
                                     attr_value.type());
    }
    if (IsRefType(attr_value.type())) {
      return errors::InvalidArgument(
          "AttrValue must not have reference type value of ",
          DataTypeString(attr_value.type()));
    }
    if (attr_value.type() == DT_INVALID) {
      return errors::InvalidArgument("AttrValue has invalid DataType");
    }
  } else if (type == "list(type)") {
    // Same checks as above, applied to every element of the list.
    for (auto as_int : attr_value.list().type()) {
      const DataType dtype = static_cast<DataType>(as_int);
      if (!DataType_IsValid(dtype)) {
        return errors::InvalidArgument("AttrValue has invalid DataType enum: ",
                                       as_int);
      }
      if (IsRefType(dtype)) {
        return errors::InvalidArgument(
            "AttrValue must not have reference type value of ",
            DataTypeString(dtype));
      }
      if (dtype == DT_INVALID) {
        return errors::InvalidArgument("AttrValue contains invalid DataType");
      }
    }
  }

  return OkStatus();
}
408
ParseAttrValue(StringPiece type,StringPiece text,AttrValue * out)409 bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out) {
410 // Parse type.
411 string field_name;
412 bool is_list = absl::ConsumePrefix(&type, "list(");
413 if (absl::ConsumePrefix(&type, "string")) {
414 field_name = "s";
415 } else if (absl::ConsumePrefix(&type, "int")) {
416 field_name = "i";
417 } else if (absl::ConsumePrefix(&type, "float")) {
418 field_name = "f";
419 } else if (absl::ConsumePrefix(&type, "bool")) {
420 field_name = "b";
421 } else if (absl::ConsumePrefix(&type, "type")) {
422 field_name = "type";
423 } else if (absl::ConsumePrefix(&type, "shape")) {
424 field_name = "shape";
425 } else if (absl::ConsumePrefix(&type, "tensor")) {
426 field_name = "tensor";
427 } else if (absl::ConsumePrefix(&type, "func")) {
428 field_name = "func";
429 } else if (absl::ConsumePrefix(&type, "placeholder")) {
430 field_name = "placeholder";
431 } else {
432 return false;
433 }
434 if (is_list && !absl::ConsumePrefix(&type, ")")) {
435 return false;
436 }
437
438 // Construct a valid text proto message to parse.
439 string to_parse;
440 if (is_list) {
441 // TextFormat parser considers "i: 7" to be the same as "i: [7]",
442 // but we only want to allow list values with [].
443 StringPiece cleaned = text;
444 str_util::RemoveLeadingWhitespace(&cleaned);
445 str_util::RemoveTrailingWhitespace(&cleaned);
446 if (cleaned.size() < 2 || cleaned[0] != '[' ||
447 cleaned[cleaned.size() - 1] != ']') {
448 return false;
449 }
450 cleaned.remove_prefix(1);
451 str_util::RemoveLeadingWhitespace(&cleaned);
452 if (cleaned.size() == 1) {
453 // User wrote "[]", so return empty list without invoking the TextFormat
454 // parse which returns an error for "i: []".
455 out->Clear();
456 out->mutable_list();
457 return true;
458 }
459 to_parse = strings::StrCat("list { ", field_name, ": ", text, " }");
460 } else {
461 to_parse = strings::StrCat(field_name, ": ", text);
462 }
463 if (field_name == "tensor") {
464 if (!ParseAttrValueHelper_TensorNestsUnderLimit(kMaxTensorNestDepth,
465 to_parse)) {
466 return false;
467 }
468 }
469 return ProtoParseFromString(to_parse, out);
470 }
471
SetAttrValue(const AttrValue & value,AttrValue * out)472 void SetAttrValue(const AttrValue& value, AttrValue* out) { *out = value; }
473
474 #define DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD) \
475 void SetAttrValue(ARG_TYPE value, AttrValue* out) { out->set_##FIELD(value); }
476
477 #define DEFINE_SET_ATTR_VALUE_LIST(ARG_TYPE, FIELD) \
478 void SetAttrValue(ARG_TYPE value, AttrValue* out) { \
479 out->mutable_list()->Clear(); /* create list() even if value empty */ \
480 for (const auto& v : value) { \
481 out->mutable_list()->add_##FIELD(v); \
482 } \
483 }
484
485 #define DEFINE_SET_ATTR_VALUE_BOTH(ARG_TYPE, FIELD) \
486 DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD) \
487 DEFINE_SET_ATTR_VALUE_LIST(gtl::ArraySlice<ARG_TYPE>, FIELD)
488
DEFINE_SET_ATTR_VALUE_ONE(const string &,s)489 DEFINE_SET_ATTR_VALUE_ONE(const string&, s)
490 DEFINE_SET_ATTR_VALUE_LIST(gtl::ArraySlice<string>, s)
491 DEFINE_SET_ATTR_VALUE_BOTH(const char*, s)
492 DEFINE_SET_ATTR_VALUE_BOTH(int64_t, i)
493 DEFINE_SET_ATTR_VALUE_BOTH(int32_t, i)
494 DEFINE_SET_ATTR_VALUE_BOTH(float, f)
495 DEFINE_SET_ATTR_VALUE_BOTH(double, f)
496 DEFINE_SET_ATTR_VALUE_BOTH(bool, b)
497 DEFINE_SET_ATTR_VALUE_LIST(const std::vector<bool>&, b)
498 DEFINE_SET_ATTR_VALUE_LIST(std::initializer_list<bool>, b)
499 DEFINE_SET_ATTR_VALUE_BOTH(DataType, type)
500
501 void SetAttrValue(const tstring& value, AttrValue* out) {
502 out->set_s(value.data(), value.size());
503 }
504
SetAttrValue(gtl::ArraySlice<tstring> value,AttrValue * out)505 void SetAttrValue(gtl::ArraySlice<tstring> value, AttrValue* out) {
506 out->mutable_list()->Clear();
507 for (const auto& v : value) {
508 out->mutable_list()->add_s(v.data(), v.size());
509 }
510 }
511
SetAttrValue(StringPiece value,AttrValue * out)512 void SetAttrValue(StringPiece value, AttrValue* out) {
513 out->set_s(value.data(), value.size());
514 }
515
SetAttrValue(const gtl::ArraySlice<StringPiece> value,AttrValue * out)516 void SetAttrValue(const gtl::ArraySlice<StringPiece> value, AttrValue* out) {
517 out->mutable_list()->Clear(); // Create list() even if value empty.
518 for (const auto& v : value) {
519 out->mutable_list()->add_s(v.data(), v.size());
520 }
521 }
522
MoveAttrValue(std::vector<string> && value,AttrValue * out)523 void MoveAttrValue(std::vector<string>&& value, AttrValue* out) {
524 out->mutable_list()->Clear(); // Create list() even if value empty.
525 for (auto& v : value) {
526 out->mutable_list()->add_s(std::move(v));
527 }
528 }
529
SetAttrValue(const TensorShape & value,AttrValue * out)530 void SetAttrValue(const TensorShape& value, AttrValue* out) {
531 value.AsProto(out->mutable_shape());
532 }
533
SetAttrValue(const TensorShapeProto & value,AttrValue * out)534 void SetAttrValue(const TensorShapeProto& value, AttrValue* out) {
535 *out->mutable_shape() = value;
536 }
537
SetAttrValue(const PartialTensorShape & value,AttrValue * out)538 void SetAttrValue(const PartialTensorShape& value, AttrValue* out) {
539 value.AsProto(out->mutable_shape());
540 }
541
SetAttrValue(const gtl::ArraySlice<TensorShape> value,AttrValue * out)542 void SetAttrValue(const gtl::ArraySlice<TensorShape> value, AttrValue* out) {
543 out->mutable_list()->Clear(); // Create list() even if value empty.
544 for (const auto& v : value) {
545 v.AsProto(out->mutable_list()->add_shape());
546 }
547 }
548
SetAttrValue(gtl::ArraySlice<TensorShapeProto> value,AttrValue * out)549 void SetAttrValue(gtl::ArraySlice<TensorShapeProto> value, AttrValue* out) {
550 out->mutable_list()->Clear(); // Create list() even if value empty.
551 for (const auto& v : value) {
552 *out->mutable_list()->add_shape() = v;
553 }
554 }
555
SetAttrValue(const gtl::ArraySlice<PartialTensorShape> value,AttrValue * out)556 void SetAttrValue(const gtl::ArraySlice<PartialTensorShape> value,
557 AttrValue* out) {
558 out->mutable_list()->Clear(); // Create list() even if value empty.
559 for (const auto& v : value) {
560 v.AsProto(out->mutable_list()->add_shape());
561 }
562 }
563
SetAttrValue(const Tensor & value,AttrValue * out)564 void SetAttrValue(const Tensor& value, AttrValue* out) {
565 if (value.NumElements() > 1) {
566 value.AsProtoTensorContent(out->mutable_tensor());
567 } else {
568 value.AsProtoField(out->mutable_tensor());
569 }
570 }
571
SetAttrValue(const gtl::ArraySlice<Tensor> value,AttrValue * out)572 void SetAttrValue(const gtl::ArraySlice<Tensor> value, AttrValue* out) {
573 out->mutable_list()->Clear(); // Create list() even if value empty.
574 for (const auto& v : value) {
575 if (v.NumElements() > 1) {
576 v.AsProtoTensorContent(out->mutable_list()->add_tensor());
577 } else {
578 v.AsProtoField(out->mutable_list()->add_tensor());
579 }
580 }
581 }
582
SetAttrValue(const TensorProto & value,AttrValue * out)583 void SetAttrValue(const TensorProto& value, AttrValue* out) {
584 *out->mutable_tensor() = value;
585 }
586
SetAttrValue(const gtl::ArraySlice<TensorProto> value,AttrValue * out)587 void SetAttrValue(const gtl::ArraySlice<TensorProto> value, AttrValue* out) {
588 out->mutable_list()->Clear(); // Create list() even if value empty.
589 for (const auto& v : value) {
590 *out->mutable_list()->add_tensor() = v;
591 }
592 }
593
SetAttrValue(const NameAttrList & value,AttrValue * out)594 void SetAttrValue(const NameAttrList& value, AttrValue* out) {
595 *out->mutable_func() = value;
596 }
597
SetAttrValue(gtl::ArraySlice<NameAttrList> value,AttrValue * out)598 void SetAttrValue(gtl::ArraySlice<NameAttrList> value, AttrValue* out) {
599 out->mutable_list()->Clear(); // Create list() even if value empty.
600 for (const auto& v : value) {
601 *out->mutable_list()->add_func() = v;
602 }
603 }
604
AreAttrValuesEqual(const AttrValue & a,const AttrValue & b,bool allow_false_negatives)605 bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b,
606 bool allow_false_negatives) {
607 if (a.type() != b.type()) {
608 return false;
609 } else if (a.type() != DT_INVALID && b.type() != DT_INVALID) {
610 return a.type() == b.type();
611 }
612
613 if (a.has_tensor() != b.has_tensor()) {
614 return false;
615 } else if (a.has_tensor() && b.has_tensor()) {
616 return AreTensorProtosEqual(a.tensor(), b.tensor(), allow_false_negatives);
617 }
618
619 // `func` field contains a nested AttrValue. Compare such AttrValues
620 // recursively.
621 if (a.has_func() != b.has_func()) {
622 return false;
623 } else if (a.has_func() && b.has_func()) {
624 const NameAttrList& af = a.func();
625 const NameAttrList& bf = b.func();
626 if (af.name() != bf.name()) return false;
627 std::unordered_map<string, AttrValue> am(af.attr().begin(),
628 af.attr().end());
629 for (const auto& bm_pair : bf.attr()) {
630 const auto& iter = am.find(bm_pair.first);
631 if (iter == am.end()) return false;
632 if (!AreAttrValuesEqual(iter->second, bm_pair.second,
633 allow_false_negatives))
634 return false;
635 am.erase(iter);
636 }
637 if (!am.empty()) return false;
638 return true;
639 }
640
641 // All other fields in AttrValue have deterministic representations.
642 // It is safe to compare their serialized strings.
643 return AreSerializedProtosEqual(a, b);
644 }
645
AttrValueHash(const AttrValue & a)646 uint64 AttrValueHash(const AttrValue& a) {
647 return AttrValueHash(a, TensorProtoHash);
648 }
649
FastAttrValueHash(const AttrValue & a)650 uint64 FastAttrValueHash(const AttrValue& a) {
651 return AttrValueHash(a, FastTensorProtoHash);
652 }
653
HasPlaceHolder(const AttrValue & val)654 bool HasPlaceHolder(const AttrValue& val) {
655 switch (val.value_case()) {
656 case AttrValue::kList: {
657 for (const NameAttrList& func : val.list().func()) {
658 for (const auto& p : func.attr()) {
659 if (HasPlaceHolder(p.second)) {
660 return true;
661 }
662 }
663 }
664 break;
665 }
666 case AttrValue::kFunc:
667 for (const auto& p : val.func().attr()) {
668 if (HasPlaceHolder(p.second)) {
669 return true;
670 }
671 }
672 break;
673 case AttrValue::kPlaceholder:
674 return true;
675 default:
676 break;
677 }
678 return false;
679 }
680
SubstitutePlaceholders(const SubstituteFunc & substitute,AttrValue * value)681 bool SubstitutePlaceholders(const SubstituteFunc& substitute,
682 AttrValue* value) {
683 switch (value->value_case()) {
684 case AttrValue::kList: {
685 for (NameAttrList& func : *value->mutable_list()->mutable_func()) {
686 for (auto& p : *func.mutable_attr()) {
687 if (!SubstitutePlaceholders(substitute, &p.second)) {
688 return false;
689 }
690 }
691 }
692 break;
693 }
694 case AttrValue::kFunc:
695 for (auto& p : *(value->mutable_func()->mutable_attr())) {
696 if (!SubstitutePlaceholders(substitute, &p.second)) {
697 return false;
698 }
699 }
700 break;
701 case AttrValue::kPlaceholder:
702 return substitute(value->placeholder(), value);
703 case AttrValue::VALUE_NOT_SET:
704 return false;
705 default:
706 break;
707 }
708 return true;
709 }
710
711 } // namespace tensorflow
712