1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/tensor_util.h"
17
18 #include <vector>
19 #include "tensorflow/core/framework/tensor.h"
20 #include "tensorflow/core/framework/tensor_testutil.h"
21 #include "tensorflow/core/framework/types.h"
22 #include "tensorflow/core/framework/variant.h"
23 #include "tensorflow/core/framework/variant_encode_decode.h"
24 #include "tensorflow/core/framework/variant_tensor_data.h"
25 #include "tensorflow/core/lib/core/status_test_util.h"
26 #include "tensorflow/core/platform/test.h"
27
28 namespace tensorflow {
29 namespace {
30
// Verifies that DeepCopy of a rank-0 (scalar) float tensor produces storage
// that is independent of its source: a write to either tensor is never
// observed through the other, and shape/dtype are carried over unchanged.
TEST(TensorUtil, DeepCopy0d) {
  Tensor original(DT_FLOAT, TensorShape({}));
  original.scalar<float>()() = 10.0;

  // Copy, then mutate only the copy; the source must keep its value.
  Tensor copy = tensor::DeepCopy(original);
  copy.scalar<float>()() = 20.0;
  EXPECT_EQ(10.0, original.scalar<float>()());

  // Mutate only the source; the copy must keep its value.
  original.scalar<float>()() = 30.0;
  EXPECT_EQ(20.0, copy.scalar<float>()());

  // A copy of a copy is just as independent.
  Tensor second_copy = tensor::DeepCopy(copy);
  copy.scalar<float>()() = 40.0;

  // All three tensors now hold distinct values.
  EXPECT_EQ(20.0, second_copy.scalar<float>()());
  EXPECT_EQ(30.0, original.scalar<float>()());
  EXPECT_EQ(40.0, copy.scalar<float>()());

  // Shape and dtype survive every copy.
  for (const Tensor* t : {&original, &copy, &second_copy}) {
    EXPECT_EQ(TensorShape({}), t->shape());
    EXPECT_EQ(DT_FLOAT, t->dtype());
  }
}
67
// Deep-copying a default-constructed (zero-element) Tensor must still produce
// a well-formed tensor: dtype DT_FLOAT, shape {0} and no elements.
TEST(TensorUtil, DeepCopyZeroElements) {
  Tensor empty;
  const Tensor copy = tensor::DeepCopy(empty);

  EXPECT_EQ(TensorShape({0}), copy.shape());
  EXPECT_EQ(DT_FLOAT, copy.dtype());
  EXPECT_EQ(0, copy.NumElements());
}
75
// Verifies DeepCopy on rank-1 tensors: for DT_FLOAT, writes to the source and
// to the copy are independent; for DT_STRING, the copy starts out with the
// source's contents and later writes to the copy leave the source untouched.
TEST(TensorUtil, DeepCopy) {
  Tensor x(DT_FLOAT, TensorShape({1}));
  x.flat<float>()(0) = 10.0;

  // Make y a deep copy of x and then change it.
  Tensor y = tensor::DeepCopy(x);
  y.flat<float>()(0) = 20.0;

  // x doesn't change.
  EXPECT_EQ(10.0, x.flat<float>()(0));

  // Change x.
  x.flat<float>()(0) = 30.0;

  // y doesn't change.
  EXPECT_EQ(20.0, y.flat<float>()(0));

  Tensor z = tensor::DeepCopy(y);

  // Change y.
  y.flat<float>()(0) = 40.0;

  // The final states should all be different.
  EXPECT_EQ(20.0, z.flat<float>()(0));
  EXPECT_EQ(30.0, x.flat<float>()(0));
  EXPECT_EQ(40.0, y.flat<float>()(0));

  // Should have the same shape and type.
  EXPECT_EQ(TensorShape({1}), x.shape());
  EXPECT_EQ(TensorShape({1}), y.shape());
  EXPECT_EQ(TensorShape({1}), z.shape());

  EXPECT_EQ(DT_FLOAT, x.dtype());
  EXPECT_EQ(DT_FLOAT, y.dtype());
  EXPECT_EQ(DT_FLOAT, z.dtype());

  // String deep copy. Check the copy's initial contents as well as
  // independence. Comparing only one element for inequality would not catch
  // a copy that lost or garbled the source values.
  Tensor str1(DT_STRING, TensorShape({2}));
  str1.flat<string>()(0) = "foo1";
  str1.flat<string>()(1) = "foo2";
  Tensor str2 = tensor::DeepCopy(str1);
  EXPECT_EQ(TensorShape({2}), str2.shape());
  EXPECT_EQ(DT_STRING, str2.dtype());
  EXPECT_EQ("foo1", str2.flat<string>()(0));
  EXPECT_EQ("foo2", str2.flat<string>()(1));

  str2.flat<string>()(0) = "bar1";
  str2.flat<string>()(1) = "bar2";
  // Writes to the copy must not reach the source, and vice versa.
  EXPECT_EQ("foo1", str1.flat<string>()(0));
  EXPECT_EQ("foo2", str1.flat<string>()(1));
  EXPECT_EQ("bar1", str2.flat<string>()(0));
  EXPECT_EQ("bar2", str2.flat<string>()(1));
}
121
// DeepCopy of a slice must materialize the sliced elements into fresh
// storage, so writes to the parent buffer made after the copy never reach it.
TEST(TensorUtil, DeepCopySlice) {
  const int kParentSize = 10;
  const int kBegin = 2;
  const int kEnd = 6;
  Tensor parent(DT_INT32, TensorShape({kParentSize}));
  parent.flat<int32>().setConstant(1);

  // `view` refers to elements [kBegin, kEnd) of `parent`'s buffer.
  Tensor view = parent.Slice(kBegin, kEnd);
  // Deep-copy the slice before the parent is modified.
  Tensor copy = tensor::DeepCopy(view);
  parent.flat<int32>().setConstant(2);

  EXPECT_EQ(TensorShape({kParentSize}), parent.shape());
  EXPECT_EQ(TensorShape({kEnd - kBegin}), view.shape());
  EXPECT_EQ(TensorShape({kEnd - kBegin}), copy.shape());
  EXPECT_EQ(DT_INT32, parent.dtype());
  EXPECT_EQ(DT_INT32, view.dtype());
  EXPECT_EQ(DT_INT32, copy.dtype());

  // The parent and the aliasing view see the new value; the copy keeps the
  // old one. The view is read through unaligned_flat, as the original test
  // does for slices.
  for (int i = 0; i < kParentSize; ++i) {
    EXPECT_EQ(2, parent.flat<int32>()(i));
  }
  for (int i = 0; i < kEnd - kBegin; ++i) {
    EXPECT_EQ(2, view.unaligned_flat<int32>()(i));
    EXPECT_EQ(1, copy.flat<int32>()(i));
  }
}
151
// String-typed counterpart of DeepCopySlice: the deep copy of a slice must own
// its strings, so overwriting the parent afterwards leaves the copy intact.
TEST(TensorUtil, DeepCopySliceString) {
  const int kParentSize = 10;
  const int kBegin = 3;
  const int kEnd = 7;
  Tensor parent(DT_STRING, TensorShape({kParentSize}));
  parent.flat<string>().setConstant("hello");

  // `view` refers to elements [kBegin, kEnd) of `parent`'s buffer.
  Tensor view = parent.Slice(kBegin, kEnd);
  // Deep-copy the slice before the parent is modified.
  Tensor copy = tensor::DeepCopy(view);
  parent.flat<string>().setConstant("goodbye");

  EXPECT_EQ(TensorShape({kParentSize}), parent.shape());
  EXPECT_EQ(TensorShape({kEnd - kBegin}), view.shape());
  EXPECT_EQ(TensorShape({kEnd - kBegin}), copy.shape());
  EXPECT_EQ(DT_STRING, parent.dtype());
  EXPECT_EQ(DT_STRING, view.dtype());
  EXPECT_EQ(DT_STRING, copy.dtype());

  // Parent and view observe "goodbye"; the copy still holds "hello".
  for (int i = 0; i < kParentSize; ++i) {
    EXPECT_EQ("goodbye", parent.flat<string>()(i));
  }
  for (int i = 0; i < kEnd - kBegin; ++i) {
    EXPECT_EQ("goodbye", view.unaligned_flat<string>()(i));
    EXPECT_EQ("hello", copy.flat<string>()(i));
  }
}
181
// Deep-copying a slice of a DT_VARIANT tensor must copy the Variant payloads,
// so the copy keeps the original Tensors after the parent is overwritten.
// Each element is checked to hold a Tensor of the expected dtype before it is
// dereferenced. A wrong payload then fails the test instead of crashing the
// binary, because Variant::get<Tensor>() yields nullptr for a non-Tensor
// payload and scalar<T>() CHECK-fails on a dtype mismatch.
TEST(TensorUtil, DeepCopySliceVariant) {
  Tensor x(DT_VARIANT, TensorShape({10}));
  x.flat<Variant>().setConstant(Tensor(42.0f));

  // Slice 'x' -- y still refers to the same buffer.
  Tensor y = x.Slice(3, 7);

  // Do a deep copy of y, which is a slice.
  Tensor z = tensor::DeepCopy(y);

  // Set x to be different.
  x.flat<Variant>().setConstant(Tensor("foo"));

  EXPECT_EQ(TensorShape({10}), x.shape());
  EXPECT_EQ(TensorShape({4}), y.shape());
  EXPECT_EQ(TensorShape({4}), z.shape());
  EXPECT_EQ(DT_VARIANT, x.dtype());
  EXPECT_EQ(DT_VARIANT, y.dtype());
  EXPECT_EQ(DT_VARIANT, z.dtype());

  // Each element of x and y should now be a DT_STRING Tensor containing "foo",
  // but each element of z should be a DT_FLOAT tensor containing 42.0.
  for (int i = 0; i < 10; ++i) {
    const Tensor* x_elem = x.flat<Variant>()(i).get<Tensor>();
    ASSERT_NE(nullptr, x_elem);
    ASSERT_EQ(DT_STRING, x_elem->dtype());
    EXPECT_EQ("foo", x_elem->scalar<string>()());
  }
  for (int i = 0; i < 4; ++i) {
    const Tensor* y_elem = y.unaligned_flat<Variant>()(i).get<Tensor>();
    ASSERT_NE(nullptr, y_elem);
    ASSERT_EQ(DT_STRING, y_elem->dtype());
    EXPECT_EQ("foo", y_elem->scalar<string>()());

    const Tensor* z_elem = z.flat<Variant>()(i).get<Tensor>();
    ASSERT_NE(nullptr, z_elem);
    ASSERT_EQ(DT_FLOAT, z_elem->dtype());
    EXPECT_EQ(42.0, z_elem->scalar<float>()());
  }
}
213
// Concatenating int32 matrices with 1, 4 and 5 rows (two columns each) along
// dimension 0 must yield one {10, 2} matrix whose rows keep input order.
TEST(TensorUtil, Concat) {
  const std::vector<int64> sizes = {1, 4, 5};
  std::vector<Tensor> to_concat;
  int64 total_size = 0;
  for (const int64 rows : sizes) {
    // Global row r of the concatenation holds {2r, 2r + 1}. Before this
    // piece is counted, `total_size` is the global index of its first row.
    Tensor piece(DT_INT32, TensorShape({rows, 2}));
    auto piece_matrix = piece.matrix<int32>();
    for (int64 r = 0; r < rows; ++r) {
      for (int col = 0; col < 2; ++col) {
        piece_matrix(r, col) = static_cast<int32>(2 * (total_size + r) + col);
      }
    }
    to_concat.push_back(piece);
    total_size += rows;
  }

  Tensor concated;
  TF_ASSERT_OK(tensor::Concat(to_concat, &concated));
  ASSERT_EQ(TensorShape({total_size, 2}), concated.shape());
  auto result = concated.matrix<int32>();
  for (int row = 0; row < total_size; ++row) {
    for (int col = 0; col < 2; ++col) {
      EXPECT_EQ(2 * row + col, result(row, col));
    }
  }
}
241
// Splitting a {10, 2} int64 matrix into row blocks of 1, 4 and 5 rows must
// produce three tensors holding the corresponding rows of the input.
TEST(TensorUtil, Split) {
  const int kRows = 10;
  Tensor to_split(DT_INT64, TensorShape({kRows, 2}));
  auto source = to_split.matrix<int64>();
  // Row r of the input holds {2r, 2r + 1}.
  for (int row = 0; row < kRows; ++row) {
    for (int col = 0; col < 2; ++col) {
      source(row, col) = 2 * row + col;
    }
  }

  const std::vector<int64> sizes = {1, 4, 5};
  std::vector<Tensor> splits;
  TF_ASSERT_OK(tensor::Split(to_split, sizes, &splits));
  ASSERT_EQ(sizes.size(), splits.size());

  // `first_row` is the input row at which the current piece starts.
  int64 first_row = 0;
  for (size_t piece = 0; piece < splits.size(); ++piece) {
    const int64 rows = sizes[piece];
    const Tensor& split = splits[piece];

    ASSERT_EQ(TensorShape({rows, 2}), split.shape());
    for (int64 r = 0; r < rows; ++r) {
      for (int col = 0; col < 2; ++col) {
        EXPECT_EQ(2 * (first_row + r) + col, split.matrix<int64>()(r, col));
      }
    }
    first_row += rows;
  }
}
270
// Splitting a {4, 3} string tensor and concatenating the pieces must
// reproduce the original values in storage that does not alias the input.
TEST(TensorUtil, ConcatSplitStrings) {
  const int kNumElements = 4 * 3;
  Tensor source(DT_STRING, TensorShape({4, 3}));
  for (int i = 0; i < kNumElements; ++i) {
    source.flat<string>()(i) = strings::StrCat("foo_", i);
  }

  std::vector<Tensor> pieces;
  TF_ASSERT_OK(tensor::Split(source, {2, 1, 1}, &pieces));
  Tensor round_trip;
  TF_ASSERT_OK(tensor::Concat(pieces, &round_trip));
  ASSERT_EQ(source.shape(), round_trip.shape());
  for (int i = 0; i < kNumElements; ++i) {
    EXPECT_EQ(source.flat<string>()(i), round_trip.flat<string>()(i));
  }

  // Overwrite every element of the round-tripped tensor. If any storage were
  // shared with `source`, the two would still compare equal afterwards.
  for (int i = 0; i < kNumElements; ++i) {
    round_trip.flat<string>()(i) = strings::StrCat("bar_", i);
  }
  for (int i = 0; i < kNumElements; ++i) {
    EXPECT_NE(source.flat<string>()(i), round_trip.flat<string>()(i));
  }
}
294
// CreateTensorProto over string values must emit dtype DT_STRING, the {1, 3}
// shape and one string_val entry per value, in order.
TEST(TensorProtoUtil, CreatesStringTensorProto) {
  std::vector<string> values{"a", "b", "c"};
  std::vector<size_t> shape{1, 3};

  auto proto = tensor::CreateTensorProto(values, shape);

  // Protobuf text format indents each nested message level by two spaces.
  EXPECT_EQ(proto.DebugString(),
            "dtype: DT_STRING\n"
            "tensor_shape {\n"
            "  dim {\n"
            "    size: 1\n"
            "  }\n"
            "  dim {\n"
            "    size: 3\n"
            "  }\n"
            "}\n"
            "string_val: \"a\"\n"
            "string_val: \"b\"\n"
            "string_val: \"c\"\n");
}
315
// CreateTensorProto over int32 values must emit dtype DT_INT32, the {2} shape
// and the values in int_val, in order.
TEST(TensorProtoUtil, CreatesInt32TensorProto) {
  std::vector<int32> values{1, 2};
  std::vector<size_t> shape{2};

  auto proto = tensor::CreateTensorProto(values, shape);

  // Protobuf text format indents each nested message level by two spaces.
  EXPECT_EQ(proto.DebugString(),
            "dtype: DT_INT32\n"
            "tensor_shape {\n"
            "  dim {\n"
            "    size: 2\n"
            "  }\n"
            "}\n"
            "int_val: 1\n"
            "int_val: 2\n");
}
332
// CreateTensorProto over int64 values must emit dtype DT_INT64, the {2} shape
// and the values in int64_val, in order.
TEST(TensorProtoUtil, CreatesInt64TensorProto) {
  std::vector<int64> values{1, 2};
  std::vector<size_t> shape{2};

  auto proto = tensor::CreateTensorProto(values, shape);

  // Protobuf text format indents each nested message level by two spaces.
  EXPECT_EQ(proto.DebugString(),
            "dtype: DT_INT64\n"
            "tensor_shape {\n"
            "  dim {\n"
            "    size: 2\n"
            "  }\n"
            "}\n"
            "int64_val: 1\n"
            "int64_val: 2\n");
}
349
// CreateTensorProto over uint32 values must emit dtype DT_UINT32, the {2}
// shape and the values in uint32_val, in order.
TEST(TensorProtoUtil, CreatesUInt32TensorProto) {
  std::vector<uint32> values{1, 2};
  std::vector<size_t> shape{2};

  auto proto = tensor::CreateTensorProto(values, shape);

  // Protobuf text format indents each nested message level by two spaces.
  EXPECT_EQ(proto.DebugString(),
            "dtype: DT_UINT32\n"
            "tensor_shape {\n"
            "  dim {\n"
            "    size: 2\n"
            "  }\n"
            "}\n"
            "uint32_val: 1\n"
            "uint32_val: 2\n");
}
366
// CreateTensorProto over uint64 values must emit dtype DT_UINT64, the {2}
// shape and the values in uint64_val, in order.
TEST(TensorProtoUtil, CreatesUInt64TensorProto) {
  std::vector<uint64> values{1, 2};
  std::vector<size_t> shape{2};

  auto proto = tensor::CreateTensorProto(values, shape);

  // Protobuf text format indents each nested message level by two spaces.
  EXPECT_EQ(proto.DebugString(),
            "dtype: DT_UINT64\n"
            "tensor_shape {\n"
            "  dim {\n"
            "    size: 2\n"
            "  }\n"
            "}\n"
            "uint64_val: 1\n"
            "uint64_val: 2\n");
}
383
// CreateTensorProto over float values must emit dtype DT_FLOAT, the {2} shape
// and the values in float_val, in order.
TEST(TensorProtoUtil, CreatesFloatTensorProto) {
  // Float literals avoid brace-initializing floats from double constants.
  std::vector<float> values{1.1f, 2.2f};
  std::vector<size_t> shape{2};

  auto proto = tensor::CreateTensorProto(values, shape);

  // Protobuf text format indents each nested message level by two spaces.
  EXPECT_EQ(proto.DebugString(),
            "dtype: DT_FLOAT\n"
            "tensor_shape {\n"
            "  dim {\n"
            "    size: 2\n"
            "  }\n"
            "}\n"
            "float_val: 1.1\n"
            "float_val: 2.2\n");
}
400
// CreateTensorProto over double values must emit dtype DT_DOUBLE, the {2}
// shape and the values in double_val, in order.
TEST(TensorProtoUtil, CreatesDoubleTensorProto) {
  std::vector<double> values{1.1, 2.2};
  std::vector<size_t> shape{2};

  auto proto = tensor::CreateTensorProto(values, shape);

  // Protobuf text format indents each nested message level by two spaces.
  EXPECT_EQ(proto.DebugString(),
            "dtype: DT_DOUBLE\n"
            "tensor_shape {\n"
            "  dim {\n"
            "    size: 2\n"
            "  }\n"
            "}\n"
            "double_val: 1.1\n"
            "double_val: 2.2\n");
}
417
// CreateTensorProto over bool values must emit dtype DT_BOOL, the {2} shape
// and the values in bool_val, in order.
TEST(TensorProtoUtil, CreatesBoolTensorProto) {
  std::vector<bool> values{true, false};
  std::vector<size_t> shape{2};

  auto proto = tensor::CreateTensorProto(values, shape);

  // Protobuf text format indents each nested message level by two spaces.
  EXPECT_EQ(proto.DebugString(),
            "dtype: DT_BOOL\n"
            "tensor_shape {\n"
            "  dim {\n"
            "    size: 2\n"
            "  }\n"
            "}\n"
            "bool_val: true\n"
            "bool_val: false\n");
}
434
// With only 63 (all-zero) elements, the single-argument
// CompressTensorProtoInPlace overload must report no compression (return
// false) for every element type exercised below.
TEST(TensorProtoUtil, CompressTensorProtoInPlaceTooSmall) {
  const int kLength = 63;
  const std::vector<size_t> shape = {kLength};
  TensorProto proto;

  proto = tensor::CreateTensorProto(std::vector<float>(kLength), shape);
  EXPECT_FALSE(tensor::CompressTensorProtoInPlace(&proto));

  proto = tensor::CreateTensorProto(std::vector<int>(kLength), shape);
  EXPECT_FALSE(tensor::CompressTensorProtoInPlace(&proto));

  proto = tensor::CreateTensorProto(std::vector<uint8>(kLength), shape);
  EXPECT_FALSE(tensor::CompressTensorProtoInPlace(&proto));

  proto = tensor::CreateTensorProto(std::vector<bool>(kLength), shape);
  EXPECT_FALSE(tensor::CompressTensorProtoInPlace(&proto));

  proto = tensor::CreateTensorProto(std::vector<Eigen::half>(kLength), shape);
  EXPECT_FALSE(tensor::CompressTensorProtoInPlace(&proto));

  proto = tensor::CreateTensorProto(std::vector<std::complex<float>>(kLength),
                                    shape);
  EXPECT_FALSE(tensor::CompressTensorProtoInPlace(&proto));
}
456
// With 64 elements that are all equal (value-initialized, i.e. zero), the
// single-argument CompressTensorProtoInPlace overload must report success,
// and TensorProtoHelper<T>::NumValues must then see exactly one stored value,
// for every element type exercised below.
TEST(TensorProtoUtil, CompressTensorProtoInPlaceAllEqual) {
  const int kLength = 64;
  const std::vector<size_t> shape = {kLength};
  TensorProto proto;

  proto = tensor::CreateTensorProto(std::vector<float>(kLength), shape);
  EXPECT_TRUE(tensor::CompressTensorProtoInPlace(&proto));
  EXPECT_EQ(tensor::internal::TensorProtoHelper<float>::NumValues(proto), 1);

  proto = tensor::CreateTensorProto(std::vector<int>(kLength), shape);
  EXPECT_TRUE(tensor::CompressTensorProtoInPlace(&proto));
  EXPECT_EQ(tensor::internal::TensorProtoHelper<int>::NumValues(proto), 1);

  proto = tensor::CreateTensorProto(std::vector<uint8>(kLength), shape);
  EXPECT_TRUE(tensor::CompressTensorProtoInPlace(&proto));
  EXPECT_EQ(tensor::internal::TensorProtoHelper<uint8>::NumValues(proto), 1);

  proto = tensor::CreateTensorProto(std::vector<bool>(kLength), shape);
  EXPECT_TRUE(tensor::CompressTensorProtoInPlace(&proto));
  EXPECT_EQ(tensor::internal::TensorProtoHelper<bool>::NumValues(proto), 1);

  proto = tensor::CreateTensorProto(std::vector<Eigen::half>(kLength), shape);
  EXPECT_TRUE(tensor::CompressTensorProtoInPlace(&proto));
  EXPECT_EQ(tensor::internal::TensorProtoHelper<Eigen::half>::NumValues(proto),
            1);

  proto = tensor::CreateTensorProto(std::vector<std::complex<float>>(kLength),
                                    shape);
  EXPECT_TRUE(tensor::CompressTensorProtoInPlace(&proto));
  EXPECT_EQ(
      tensor::internal::TensorProtoHelper<std::complex<float>>::NumValues(
          proto),
      1);
}
496
497 template <typename T>
VectorWithConstantTail(int size,int tail_length)498 std::vector<T> VectorWithConstantTail(int size, int tail_length) {
499 CHECK_LE(tail_length, size);
500 std::vector<T> v(size, T(0));
501 for (int i = 0; i < size - tail_length; ++i) {
502 v[i] = T(i + 1);
503 }
504 return v;
505 }
506
507 template <typename T>
CreateAsProtoTensorContent(int size,int tail_length)508 TensorProto CreateAsProtoTensorContent(int size, int tail_length) {
509 auto values = VectorWithConstantTail<T>(size, tail_length);
510 Tensor tensor(DataTypeToEnum<T>::value, TensorShape({size}));
511 std::copy(values.begin(), values.end(), tensor.flat<T>().data());
512 TensorProto tensor_proto;
513 tensor.AsProtoTensorContent(&tensor_proto);
514 return tensor_proto;
515 }
516
517 template <typename T>
CreateAsProtoField(int size,int tail_length)518 TensorProto CreateAsProtoField(int size, int tail_length) {
519 auto values = VectorWithConstantTail<T>(size, tail_length);
520 Tensor tensor(DataTypeToEnum<T>::value, TensorShape({size}));
521 std::copy(values.begin(), values.end(), tensor.flat<T>().data());
522 TensorProto tensor_proto;
523 tensor.AsProtoField(&tensor_proto);
524 return tensor_proto;
525 }
526
527 template <typename T>
CompareTensorValues(const TensorProto & x,const TensorProto & y)528 void CompareTensorValues(const TensorProto& x, const TensorProto& y) {
529 Tensor x_t;
530 EXPECT_TRUE(x_t.FromProto(x));
531 Tensor y_t;
532 EXPECT_TRUE(y_t.FromProto(y));
533 test::ExpectTensorEqual<T>(x_t, y_t);
534 }
535
536 template <typename T>
ConstantTailTest(int64 length,int64 tail_length,bool as_field)537 void ConstantTailTest(int64 length, int64 tail_length, bool as_field) {
538 using TensorProtoHelper = tensor::internal::TensorProtoHelper<T>;
539 using FieldType = typename TensorProtoHelper::FieldType;
540 const float kMinCompressionRatio = 2.0;
541 const int64 kMinSize = 64;
542 TensorProto tensor_proto =
543 as_field ? CreateAsProtoField<T>(length, tail_length)
544 : CreateAsProtoTensorContent<T>(length, tail_length);
545 TensorProto original_tensor_proto = tensor_proto;
546 int64 original_size =
547 length * (as_field ? (is_complex<T>::value ? 2 : 1) * sizeof(FieldType)
548 : sizeof(T));
549 int64 size_as_tensor_content = length * sizeof(T);
550 int64 size_as_field = std::min(length, (length - tail_length + 1)) *
551 (is_complex<T>::value ? 2 : 1) * sizeof(FieldType);
552 bool will_compress = std::min(size_as_tensor_content, size_as_field) <=
553 static_cast<int64>(original_size / kMinCompressionRatio);
554
555 EXPECT_EQ(tensor::CompressTensorProtoInPlace(kMinSize, kMinCompressionRatio,
556 &tensor_proto),
557 will_compress);
558 if (will_compress) {
559 if (size_as_tensor_content < size_as_field) {
560 EXPECT_EQ(TensorProtoHelper::NumValues(tensor_proto), 0);
561 EXPECT_FALSE(tensor_proto.tensor_content().empty());
562 } else {
563 EXPECT_LE(TensorProtoHelper::NumValues(tensor_proto),
564 (length - tail_length + 1));
565 EXPECT_TRUE(tensor_proto.tensor_content().empty());
566 }
567 }
568 CompareTensorValues<T>(tensor_proto, original_tensor_proto);
569 }
570
// Sweeps ConstantTailTest over both input encodings (AsProtoField and
// AsProtoTensorContent), tail lengths from 0 (no constant tail) up to 64 (an
// all-constant vector), and every element type listed below, for
// 64-element vectors.
TEST(TensorProtoUtil, CompressTensorProtoConstantTail) {
  const int kLength = 64;
  const bool kAsFieldOptions[] = {true, false};
  const int kTailLengths[] = {0, 1, 2, 32, 33, 63, 64};
  for (const bool as_field : kAsFieldOptions) {
    for (const int tail : kTailLengths) {
      ConstantTailTest<float>(kLength, tail, as_field);
      ConstantTailTest<double>(kLength, tail, as_field);
      ConstantTailTest<complex64>(kLength, tail, as_field);
      ConstantTailTest<complex128>(kLength, tail, as_field);
      ConstantTailTest<int32>(kLength, tail, as_field);
      ConstantTailTest<uint32>(kLength, tail, as_field);
      ConstantTailTest<int64>(kLength, tail, as_field);
      ConstantTailTest<uint64>(kLength, tail, as_field);
      ConstantTailTest<int8>(kLength, tail, as_field);
      ConstantTailTest<uint8>(kLength, tail, as_field);
      ConstantTailTest<int16>(kLength, tail, as_field);
      ConstantTailTest<uint16>(kLength, tail, as_field);
      ConstantTailTest<Eigen::half>(kLength, tail, as_field);
      ConstantTailTest<bfloat16>(kLength, tail, as_field);
    }
  }
}
592
593 } // namespace
594 } // namespace tensorflow
595