1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/core/framework/attr_value_util.h"

#include <algorithm>
#include <functional>
#include <limits>
#include <map>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/strings/escaping.h"
#include "tensorflow/core/framework/attr_value.pb_text.h"
#include "tensorflow/core/framework/tensor.pb_text.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb_text.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/protobuf.h"
34
35 namespace tensorflow {
36 namespace {
37
38 // Do not construct large tensors to compute their hash or compare for equality.
39 constexpr int kMaxAttrValueTensorByteSize = 32 * 1024 * 1024; // 32mb
40
41 // Return the size of the tensor represented by this TensorProto. If shape is
42 // not fully defined return -1.
TensorByteSize(const TensorProto & t)43 int64 TensorByteSize(const TensorProto& t) {
44 // num_elements returns -1 if shape is not fully defined.
45 int64 num_elems = TensorShape(t.tensor_shape()).num_elements();
46 return num_elems < 0 ? -1 : num_elems * DataTypeSize(t.dtype());
47 }
48
49 // Compute TensorProto hash by creating a Tensor, serializing it as tensor
50 // content, and computing a hash of it's string representation. This is unsafe
51 // operation, because large tensors can be represented as TensorProto, but can't
52 // be serialized to tensor content.
TensorProtoHash(const TensorProto & tp)53 uint64 TensorProtoHash(const TensorProto& tp) {
54 Tensor tensor(tp.dtype());
55 bool success = tensor.FromProto(tp);
56 DCHECK(success);
57 TensorProto p;
58 tensor.AsProtoTensorContent(&p);
59 return DeterministicProtoHash64(p);
60 }
61
62 // Do not create large tensors in memory, compute hash based on TensorProto
63 // string representation. Tensors with identical content potentially can have a
64 // different hash code if they are defined with different TensorProto
65 // representations.
FastTensorProtoHash(const TensorProto & tp)66 uint64 FastTensorProtoHash(const TensorProto& tp) {
67 if (TensorByteSize(tp) > kMaxAttrValueTensorByteSize) {
68 return DeterministicProtoHash64(tp);
69 } else {
70 return TensorProtoHash(tp);
71 }
72 }
73
74 // There are multiple equivalent representations of attr values containing
75 // TensorProtos. Compare them by constructing Tensors and serializing them
76 // back. Comparing Tensor objects is pretty tricky. This is unsafe operation,
77 // because large tensors can be represented as TensorProto, but can't be
78 // serialized to tensor content.
AreTensorProtosEqual(const TensorProto & lhs,const TensorProto & rhs)79 bool AreTensorProtosEqual(const TensorProto& lhs, const TensorProto& rhs) {
80 Tensor lhs_t(lhs.dtype());
81 bool success = lhs_t.FromProto(lhs);
82 DCHECK(success);
83
84 Tensor rhs_t(rhs.dtype());
85 success = rhs_t.FromProto(rhs);
86 DCHECK(success);
87
88 TensorProto lhs_tp;
89 lhs_t.AsProtoTensorContent(&lhs_tp);
90
91 TensorProto rhs_tp;
92 rhs_t.AsProtoTensorContent(&rhs_tp);
93
94 return AreSerializedProtosEqual(lhs_tp, rhs_tp);
95 }
96
97 // Do not construct large tensors in memory, compare equality using TensorProto
98 // string representation. Tensors with identical content potentially can have
99 // different tensor proto representation.
FastAreTensorProtosEqual(const TensorProto & lhs,const TensorProto & rhs)100 bool FastAreTensorProtosEqual(const TensorProto& lhs, const TensorProto& rhs) {
101 // A small TensorProto can expand into a giant Tensor. So we avoid
102 // conversion to an actual Tensor if we can quickly rule out equality
103 // by comparing the Tensor size since different sized Tensors are definitely
104 // different.
105 const int64 lhs_tensor_bytes = TensorByteSize(lhs);
106 const int64 rhs_tensor_bytes = TensorByteSize(rhs);
107 if (lhs_tensor_bytes != rhs_tensor_bytes) {
108 return false;
109 }
110
111 // If the tensor is very large, we'll only compare the proto representation
112 // (even though this may miss some equivalent tensors whose actual tensor
113 // values are the same but which are described by different TensorProtos).
114 if (lhs_tensor_bytes > kMaxAttrValueTensorByteSize) {
115 return AreSerializedProtosEqual(lhs, rhs);
116 }
117
118 // If the TensorProto representation expands into a much bigger Tensor,
119 // we have a fast-path that first compares the protos.
120 const int64 lhs_proto_bytes = lhs.ByteSizeLong();
121 const bool large_expansion =
122 (lhs_proto_bytes < 512 && lhs_tensor_bytes > 4096);
123 if (large_expansion && AreSerializedProtosEqual(lhs, rhs)) {
124 return true;
125 }
126
127 // Fall back to the general code in AreTensorProtosEqual.
128 return AreTensorProtosEqual(lhs, rhs);
129 }
130
131 using TensorProtoHasher = std::function<uint64(const TensorProto&)>;
132
AttrValueHash(const AttrValue & a,const TensorProtoHasher & tensor_hash)133 uint64 AttrValueHash(const AttrValue& a, const TensorProtoHasher& tensor_hash) {
134 if (a.has_tensor()) return tensor_hash(a.tensor());
135
136 if (a.has_func()) {
137 const NameAttrList& func = a.func();
138 uint64 h = Hash64(func.name());
139 std::map<string, AttrValue> map(func.attr().begin(), func.attr().end());
140 for (const auto& pair : map) {
141 h = Hash64(pair.first.data(), pair.first.size(), h);
142 h = Hash64Combine(AttrValueHash(pair.second, tensor_hash), h);
143 }
144 return h;
145 }
146
147 // If `a` is not a tensor or func, get a hash of serialized string.
148 return DeterministicProtoHash64(a);
149 }
150
151 template <typename TensorProtosEquality>
AreAttrValuesEqual(const AttrValue & a,const AttrValue & b,TensorProtosEquality tensor_equality)152 bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b,
153 TensorProtosEquality tensor_equality) {
154 if (a.type() != b.type()) {
155 return false;
156 } else if (a.type() != DT_INVALID && b.type() != DT_INVALID) {
157 return a.type() == b.type();
158 }
159
160 if (a.has_tensor() != b.has_tensor()) {
161 return false;
162 } else if (a.has_tensor() && b.has_tensor()) {
163 return tensor_equality(a.tensor(), b.tensor());
164 }
165
166 // `func` field contains a nested AttrValue. Compare such AttrValues
167 // recursively.
168 if (a.has_func() != b.has_func()) {
169 return false;
170 } else if (a.has_func() && b.has_func()) {
171 const NameAttrList& af = a.func();
172 const NameAttrList& bf = b.func();
173 if (af.name() != bf.name()) return false;
174 std::unordered_map<string, AttrValue> am(af.attr().begin(),
175 af.attr().end());
176 for (const auto& bm_pair : bf.attr()) {
177 const auto& iter = am.find(bm_pair.first);
178 if (iter == am.end()) return false;
179 if (!AreAttrValuesEqual(iter->second, bm_pair.second, tensor_equality))
180 return false;
181 am.erase(iter);
182 }
183 if (!am.empty()) return false;
184 return true;
185 }
186
187 // All other fields in AttrValue have deterministic representations.
188 // It is safe to compare their serialized strings.
189 return AreSerializedProtosEqual(a, b);
190 }
191
// Returns `str` C-escaped and wrapped in double quotes. Escaped strings of 80 or
// more characters are shortened to their first and last 10 characters joined by
// "...".
string SummarizeString(const string& str) {
  const string escaped = absl::CEscape(str);
  constexpr int kMaxStringSummarySize = 80;
  if (escaped.size() < kMaxStringSummarySize) {
    return strings::StrCat("\"", escaped, "\"");
  }

  // Long string: keep only the head and the tail, eliding the middle.
  const StringPiece full(escaped);
  const StringPiece head = full.substr(0, 10);
  const StringPiece tail = full.substr(full.size() - 10);
  return strings::StrCat("\"", head, "...", tail, "\"");
}
207
SummarizeTensor(const TensorProto & tensor_proto)208 string SummarizeTensor(const TensorProto& tensor_proto) {
209 Tensor t;
210 if (!t.FromProto(tensor_proto)) {
211 return strings::StrCat(
212 "<Invalid TensorProto: ", tensor_proto.ShortDebugString(), ">");
213 }
214 return t.DebugString();
215 }
216
SummarizeFunc(const NameAttrList & func)217 string SummarizeFunc(const NameAttrList& func) {
218 std::vector<string> entries;
219 for (auto p : func.attr()) {
220 entries.push_back(
221 strings::StrCat(p.first, "=", SummarizeAttrValue(p.second)));
222 }
223 std::sort(entries.begin(), entries.end());
224 return strings::StrCat(func.name(), "[", absl::StrJoin(entries, ", "), "]");
225 }
226
227 } // namespace
228
SummarizeAttrValue(const AttrValue & attr_value)229 string SummarizeAttrValue(const AttrValue& attr_value) {
230 switch (attr_value.value_case()) {
231 case AttrValue::kS:
232 return SummarizeString(attr_value.s());
233 case AttrValue::kI:
234 return strings::StrCat(attr_value.i());
235 case AttrValue::kF:
236 return strings::StrCat(attr_value.f());
237 case AttrValue::kB:
238 return attr_value.b() ? "true" : "false";
239 case AttrValue::kType:
240 return EnumName_DataType(attr_value.type());
241 case AttrValue::kShape:
242 return PartialTensorShape::DebugString(attr_value.shape());
243 case AttrValue::kTensor:
244 return SummarizeTensor(attr_value.tensor());
245 case AttrValue::kList: {
246 std::vector<string> pieces;
247 if (attr_value.list().s_size() > 0) {
248 for (int i = 0; i < attr_value.list().s_size(); ++i) {
249 pieces.push_back(SummarizeString(attr_value.list().s(i)));
250 }
251 } else if (attr_value.list().i_size() > 0) {
252 for (int i = 0; i < attr_value.list().i_size(); ++i) {
253 pieces.push_back(strings::StrCat(attr_value.list().i(i)));
254 }
255 } else if (attr_value.list().f_size() > 0) {
256 for (int i = 0; i < attr_value.list().f_size(); ++i) {
257 pieces.push_back(strings::StrCat(attr_value.list().f(i)));
258 }
259 } else if (attr_value.list().b_size() > 0) {
260 for (int i = 0; i < attr_value.list().b_size(); ++i) {
261 pieces.push_back(attr_value.list().b(i) ? "true" : "false");
262 }
263 } else if (attr_value.list().type_size() > 0) {
264 for (int i = 0; i < attr_value.list().type_size(); ++i) {
265 pieces.push_back(EnumName_DataType(attr_value.list().type(i)));
266 }
267 } else if (attr_value.list().shape_size() > 0) {
268 for (int i = 0; i < attr_value.list().shape_size(); ++i) {
269 pieces.push_back(
270 TensorShape::DebugString(attr_value.list().shape(i)));
271 }
272 } else if (attr_value.list().tensor_size() > 0) {
273 for (int i = 0; i < attr_value.list().tensor_size(); ++i) {
274 pieces.push_back(SummarizeTensor(attr_value.list().tensor(i)));
275 }
276 } else if (attr_value.list().func_size() > 0) {
277 for (int i = 0; i < attr_value.list().func_size(); ++i) {
278 pieces.push_back(SummarizeFunc(attr_value.list().func(i)));
279 }
280 }
281 constexpr int kMaxListSummarySize = 15;
282 if (pieces.size() >= kMaxListSummarySize) {
283 pieces.erase(pieces.begin() + 5, pieces.begin() + (pieces.size() - 6));
284 pieces[5] = "...";
285 }
286 return strings::StrCat("[", absl::StrJoin(pieces, ", "), "]");
287 }
288 case AttrValue::kFunc: {
289 return SummarizeFunc(attr_value.func());
290 }
291 case AttrValue::kPlaceholder:
292 return strings::StrCat("$", attr_value.placeholder());
293 case AttrValue::VALUE_NOT_SET:
294 return "<Unknown AttrValue type>";
295 }
296 return "<Unknown AttrValue type>"; // Prevent missing return warning
297 }
298
// Returns OK if `attr_value` holds a value compatible with the attr type named
// by `type` (e.g. "int" or "list(string)"), and InvalidArgument otherwise.
// An unset list is accepted as an empty list. Placeholder values are rejected.
// For "type" and "list(type)", every DataType must be a valid enum value, must
// not be a reference type, and must not be DT_INVALID.
Status AttrValueHasType(const AttrValue& attr_value, StringPiece type) {
  // Number of value kinds found set on `attr_value`; used below to tell a
  // missing value apart from an (allowed) empty list.
  int num_set = 0;

// Checks one value kind. If `attr_value` holds a list whose `name` field is
// non-empty, `type` must be "list(<type_string>)". Otherwise, if the oneof is
// set to `oneof_case`, `type` must equal `type_string`. A match increments
// num_set; a mismatch makes the enclosing function return InvalidArgument.
#define VALIDATE_FIELD(name, type_string, oneof_case)                         \
  do {                                                                        \
    if (attr_value.has_list()) {                                              \
      if (attr_value.list().name##_size() > 0) {                              \
        if (type != "list(" type_string ")") {                                \
          return errors::InvalidArgument(                                     \
              "AttrValue had value with type 'list(" type_string ")' when '", \
              type, "' expected");                                            \
        }                                                                     \
        ++num_set;                                                            \
      }                                                                       \
    } else if (attr_value.value_case() == AttrValue::oneof_case) {            \
      if (type != type_string) {                                              \
        return errors::InvalidArgument(                                       \
            "AttrValue had value with type '" type_string "' when '", type,   \
            "' expected");                                                    \
      }                                                                       \
      ++num_set;                                                              \
    }                                                                         \
  } while (false)

  VALIDATE_FIELD(s, "string", kS);
  VALIDATE_FIELD(i, "int", kI);
  VALIDATE_FIELD(f, "float", kF);
  VALIDATE_FIELD(b, "bool", kB);
  VALIDATE_FIELD(type, "type", kType);
  VALIDATE_FIELD(shape, "shape", kShape);
  VALIDATE_FIELD(tensor, "tensor", kTensor);
  VALIDATE_FIELD(func, "func", kFunc);

#undef VALIDATE_FIELD

  // Placeholders are never a valid concrete value for any attr type.
  if (attr_value.value_case() == AttrValue::kPlaceholder) {
    return errors::InvalidArgument(
        "AttrValue had value with unexpected type 'placeholder'");
  }

  // If the attr type is 'list', we expect attr_value.has_list() to be
  // true. However, proto3's attr_value.has_list() can be false when
  // set to an empty list for GraphDef versions <= 4. So we simply
  // check if has_list is false and some other field in attr_value is
  // set to flag the error. This test can be made more strict once
  // support for GraphDef versions <= 4 is dropped.
  if (absl::StartsWith(type, "list(") && !attr_value.has_list()) {
    if (num_set) {
      return errors::InvalidArgument(
          "AttrValue missing value with expected type '", type, "'");
    } else {
      // Indicate that we have a list, but an empty one.
      ++num_set;
    }
  }

  // Okay to have an empty list, but not to be missing a non-list value.
  if (num_set == 0 && !absl::StartsWith(type, "list(")) {
    return errors::InvalidArgument(
        "AttrValue missing value with expected type '", type, "'");
  }

  // Ref types and DT_INVALID are illegal, and DataTypes must
  // be a valid enum type.
  if (type == "type") {
    if (!DataType_IsValid(attr_value.type())) {
      return errors::InvalidArgument("AttrValue has invalid DataType enum: ",
                                     attr_value.type());
    }
    if (IsRefType(attr_value.type())) {
      return errors::InvalidArgument(
          "AttrValue must not have reference type value of ",
          DataTypeString(attr_value.type()));
    }
    if (attr_value.type() == DT_INVALID) {
      return errors::InvalidArgument("AttrValue has invalid DataType");
    }
  } else if (type == "list(type)") {
    // The repeated `type` field yields raw ints, so each element is validated
    // as an enum value before being used as a DataType.
    for (auto as_int : attr_value.list().type()) {
      const DataType dtype = static_cast<DataType>(as_int);
      if (!DataType_IsValid(dtype)) {
        return errors::InvalidArgument("AttrValue has invalid DataType enum: ",
                                       as_int);
      }
      if (IsRefType(dtype)) {
        return errors::InvalidArgument(
            "AttrValue must not have reference type value of ",
            DataTypeString(dtype));
      }
      if (dtype == DT_INVALID) {
        return errors::InvalidArgument("AttrValue contains invalid DataType");
      }
    }
  }

  return Status::OK();
}
396
ParseAttrValue(StringPiece type,StringPiece text,AttrValue * out)397 bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out) {
398 // Parse type.
399 string field_name;
400 bool is_list = absl::ConsumePrefix(&type, "list(");
401 if (absl::ConsumePrefix(&type, "string")) {
402 field_name = "s";
403 } else if (absl::ConsumePrefix(&type, "int")) {
404 field_name = "i";
405 } else if (absl::ConsumePrefix(&type, "float")) {
406 field_name = "f";
407 } else if (absl::ConsumePrefix(&type, "bool")) {
408 field_name = "b";
409 } else if (absl::ConsumePrefix(&type, "type")) {
410 field_name = "type";
411 } else if (absl::ConsumePrefix(&type, "shape")) {
412 field_name = "shape";
413 } else if (absl::ConsumePrefix(&type, "tensor")) {
414 field_name = "tensor";
415 } else if (absl::ConsumePrefix(&type, "func")) {
416 field_name = "func";
417 } else if (absl::ConsumePrefix(&type, "placeholder")) {
418 field_name = "placeholder";
419 } else {
420 return false;
421 }
422 if (is_list && !absl::ConsumePrefix(&type, ")")) {
423 return false;
424 }
425
426 // Construct a valid text proto message to parse.
427 string to_parse;
428 if (is_list) {
429 // TextFormat parser considers "i: 7" to be the same as "i: [7]",
430 // but we only want to allow list values with [].
431 StringPiece cleaned = text;
432 str_util::RemoveLeadingWhitespace(&cleaned);
433 str_util::RemoveTrailingWhitespace(&cleaned);
434 if (cleaned.size() < 2 || cleaned[0] != '[' ||
435 cleaned[cleaned.size() - 1] != ']') {
436 return false;
437 }
438 cleaned.remove_prefix(1);
439 str_util::RemoveLeadingWhitespace(&cleaned);
440 if (cleaned.size() == 1) {
441 // User wrote "[]", so return empty list without invoking the TextFormat
442 // parse which returns an error for "i: []".
443 out->Clear();
444 out->mutable_list();
445 return true;
446 }
447 to_parse = strings::StrCat("list { ", field_name, ": ", text, " }");
448 } else {
449 to_parse = strings::StrCat(field_name, ": ", text);
450 }
451
452 return ProtoParseFromString(to_parse, out);
453 }
454
SetAttrValue(const AttrValue & value,AttrValue * out)455 void SetAttrValue(const AttrValue& value, AttrValue* out) { *out = value; }
456
// Defines SetAttrValue(ARG_TYPE, AttrValue*), which stores `value` in the
// scalar field FIELD of `out` via set_FIELD().
#define DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD) \
  void SetAttrValue(ARG_TYPE value, AttrValue* out) { out->set_##FIELD(value); }

// Defines SetAttrValue(ARG_TYPE, AttrValue*) for an iterable ARG_TYPE. It clears
// `out`'s list and appends every element of `value` to the list's repeated
// field FIELD.
#define DEFINE_SET_ATTR_VALUE_LIST(ARG_TYPE, FIELD)                       \
  void SetAttrValue(ARG_TYPE value, AttrValue* out) {                    \
    out->mutable_list()->Clear(); /* create list() even if value empty */ \
    for (const auto& v : value) {                                         \
      out->mutable_list()->add_##FIELD(v);                                \
    }                                                                     \
  }

// Defines both the scalar overload for ARG_TYPE and the list overload for
// gtl::ArraySlice<ARG_TYPE>.
#define DEFINE_SET_ATTR_VALUE_BOTH(ARG_TYPE, FIELD) \
  DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD)        \
  DEFINE_SET_ATTR_VALUE_LIST(gtl::ArraySlice<ARG_TYPE>, FIELD)

DEFINE_SET_ATTR_VALUE_ONE(const string&, s)
DEFINE_SET_ATTR_VALUE_LIST(gtl::ArraySlice<string>, s)
DEFINE_SET_ATTR_VALUE_BOTH(const char*, s)
DEFINE_SET_ATTR_VALUE_BOTH(int64, i)
DEFINE_SET_ATTR_VALUE_BOTH(int32, i)
DEFINE_SET_ATTR_VALUE_BOTH(float, f)
DEFINE_SET_ATTR_VALUE_BOTH(double, f)
DEFINE_SET_ATTR_VALUE_BOTH(bool, b)
// Extra bool list overloads. Presumably these exist because std::vector<bool> is
// bit-packed and cannot be viewed as a gtl::ArraySlice<bool> -- confirm.
DEFINE_SET_ATTR_VALUE_LIST(const std::vector<bool>&, b)
DEFINE_SET_ATTR_VALUE_LIST(std::initializer_list<bool>, b)
DEFINE_SET_ATTR_VALUE_BOTH(DataType, type)
483
484 void SetAttrValue(const tstring& value, AttrValue* out) {
485 out->set_s(value.data(), value.size());
486 }
487
SetAttrValue(gtl::ArraySlice<tstring> value,AttrValue * out)488 void SetAttrValue(gtl::ArraySlice<tstring> value, AttrValue* out) {
489 out->mutable_list()->Clear();
490 for (const auto& v : value) {
491 out->mutable_list()->add_s(v.data(), v.size());
492 }
493 }
494
SetAttrValue(StringPiece value,AttrValue * out)495 void SetAttrValue(StringPiece value, AttrValue* out) {
496 out->set_s(value.data(), value.size());
497 }
498
SetAttrValue(const gtl::ArraySlice<StringPiece> value,AttrValue * out)499 void SetAttrValue(const gtl::ArraySlice<StringPiece> value, AttrValue* out) {
500 out->mutable_list()->Clear(); // Create list() even if value empty.
501 for (const auto& v : value) {
502 out->mutable_list()->add_s(v.data(), v.size());
503 }
504 }
505
MoveAttrValue(std::vector<string> && value,AttrValue * out)506 void MoveAttrValue(std::vector<string>&& value, AttrValue* out) {
507 out->mutable_list()->Clear(); // Create list() even if value empty.
508 for (auto& v : value) {
509 out->mutable_list()->add_s(std::move(v));
510 }
511 }
512
SetAttrValue(const TensorShape & value,AttrValue * out)513 void SetAttrValue(const TensorShape& value, AttrValue* out) {
514 value.AsProto(out->mutable_shape());
515 }
516
SetAttrValue(const TensorShapeProto & value,AttrValue * out)517 void SetAttrValue(const TensorShapeProto& value, AttrValue* out) {
518 *out->mutable_shape() = value;
519 }
520
SetAttrValue(const PartialTensorShape & value,AttrValue * out)521 void SetAttrValue(const PartialTensorShape& value, AttrValue* out) {
522 value.AsProto(out->mutable_shape());
523 }
524
SetAttrValue(const gtl::ArraySlice<TensorShape> value,AttrValue * out)525 void SetAttrValue(const gtl::ArraySlice<TensorShape> value, AttrValue* out) {
526 out->mutable_list()->Clear(); // Create list() even if value empty.
527 for (const auto& v : value) {
528 v.AsProto(out->mutable_list()->add_shape());
529 }
530 }
531
SetAttrValue(gtl::ArraySlice<TensorShapeProto> value,AttrValue * out)532 void SetAttrValue(gtl::ArraySlice<TensorShapeProto> value, AttrValue* out) {
533 out->mutable_list()->Clear(); // Create list() even if value empty.
534 for (const auto& v : value) {
535 *out->mutable_list()->add_shape() = v;
536 }
537 }
538
SetAttrValue(const gtl::ArraySlice<PartialTensorShape> value,AttrValue * out)539 void SetAttrValue(const gtl::ArraySlice<PartialTensorShape> value,
540 AttrValue* out) {
541 out->mutable_list()->Clear(); // Create list() even if value empty.
542 for (const auto& v : value) {
543 v.AsProto(out->mutable_list()->add_shape());
544 }
545 }
546
SetAttrValue(const Tensor & value,AttrValue * out)547 void SetAttrValue(const Tensor& value, AttrValue* out) {
548 if (value.NumElements() > 1) {
549 value.AsProtoTensorContent(out->mutable_tensor());
550 } else {
551 value.AsProtoField(out->mutable_tensor());
552 }
553 }
554
SetAttrValue(const gtl::ArraySlice<Tensor> value,AttrValue * out)555 void SetAttrValue(const gtl::ArraySlice<Tensor> value, AttrValue* out) {
556 out->mutable_list()->Clear(); // Create list() even if value empty.
557 for (const auto& v : value) {
558 if (v.NumElements() > 1) {
559 v.AsProtoTensorContent(out->mutable_list()->add_tensor());
560 } else {
561 v.AsProtoField(out->mutable_list()->add_tensor());
562 }
563 }
564 }
565
SetAttrValue(const TensorProto & value,AttrValue * out)566 void SetAttrValue(const TensorProto& value, AttrValue* out) {
567 *out->mutable_tensor() = value;
568 }
569
SetAttrValue(const gtl::ArraySlice<TensorProto> value,AttrValue * out)570 void SetAttrValue(const gtl::ArraySlice<TensorProto> value, AttrValue* out) {
571 out->mutable_list()->Clear(); // Create list() even if value empty.
572 for (const auto& v : value) {
573 *out->mutable_list()->add_tensor() = v;
574 }
575 }
576
SetAttrValue(const NameAttrList & value,AttrValue * out)577 void SetAttrValue(const NameAttrList& value, AttrValue* out) {
578 *out->mutable_func() = value;
579 }
580
SetAttrValue(gtl::ArraySlice<NameAttrList> value,AttrValue * out)581 void SetAttrValue(gtl::ArraySlice<NameAttrList> value, AttrValue* out) {
582 out->mutable_list()->Clear(); // Create list() even if value empty.
583 for (const auto& v : value) {
584 *out->mutable_list()->add_func() = v;
585 }
586 }
587
AreAttrValuesEqual(const AttrValue & a,const AttrValue & b)588 bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b) {
589 return AreAttrValuesEqual(a, b, AreTensorProtosEqual);
590 }
591
AttrValueHash(const AttrValue & a)592 uint64 AttrValueHash(const AttrValue& a) {
593 return AttrValueHash(a, TensorProtoHash);
594 }
595
FastAreAttrValuesEqual(const AttrValue & a,const AttrValue & b)596 bool FastAreAttrValuesEqual(const AttrValue& a, const AttrValue& b) {
597 return AreAttrValuesEqual(a, b, FastAreTensorProtosEqual);
598 }
599
FastAttrValueHash(const AttrValue & a)600 uint64 FastAttrValueHash(const AttrValue& a) {
601 return AttrValueHash(a, FastTensorProtoHash);
602 }
603
HasPlaceHolder(const AttrValue & val)604 bool HasPlaceHolder(const AttrValue& val) {
605 switch (val.value_case()) {
606 case AttrValue::kList: {
607 for (const NameAttrList& func : val.list().func()) {
608 for (const auto& p : func.attr()) {
609 if (HasPlaceHolder(p.second)) {
610 return true;
611 }
612 }
613 }
614 break;
615 }
616 case AttrValue::kFunc:
617 for (const auto& p : val.func().attr()) {
618 if (HasPlaceHolder(p.second)) {
619 return true;
620 }
621 }
622 break;
623 case AttrValue::kPlaceholder:
624 return true;
625 default:
626 break;
627 }
628 return false;
629 }
630
SubstitutePlaceholders(const SubstituteFunc & substitute,AttrValue * value)631 bool SubstitutePlaceholders(const SubstituteFunc& substitute,
632 AttrValue* value) {
633 switch (value->value_case()) {
634 case AttrValue::kList: {
635 for (NameAttrList& func : *value->mutable_list()->mutable_func()) {
636 for (auto& p : *func.mutable_attr()) {
637 if (!SubstitutePlaceholders(substitute, &p.second)) {
638 return false;
639 }
640 }
641 }
642 break;
643 }
644 case AttrValue::kFunc:
645 for (auto& p : *(value->mutable_func()->mutable_attr())) {
646 if (!SubstitutePlaceholders(substitute, &p.second)) {
647 return false;
648 }
649 }
650 break;
651 case AttrValue::kPlaceholder:
652 return substitute(value->placeholder(), value);
653 case AttrValue::VALUE_NOT_SET:
654 return false;
655 default:
656 break;
657 }
658 return true;
659 }
660
661 } // namespace tensorflow
662