1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/tensor_util.h"
17
18 #include <vector>
19 #include "tensorflow/core/framework/tensor.h"
20 #include "tensorflow/core/framework/tensor_testutil.h"
21 #include "tensorflow/core/framework/types.h"
22 #include "tensorflow/core/framework/variant.h"
23 #include "tensorflow/core/framework/variant_encode_decode.h"
24 #include "tensorflow/core/framework/variant_tensor_data.h"
25 #include "tensorflow/core/lib/core/status_test_util.h"
26 #include "tensorflow/core/platform/test.h"
27
28 namespace tensorflow {
29 namespace {
30
// A deep copy of a rank-0 (scalar) tensor must own its own buffer: writes to
// the source never show up in the copy and vice versa. Shape and dtype carry
// over unchanged.
TEST(TensorUtil, DeepCopy0d) {
  Tensor original(DT_FLOAT, TensorShape({}));
  original.scalar<float>()() = 10.0;

  // Mutate only the copy; the original keeps its value.
  Tensor first_copy = tensor::DeepCopy(original);
  first_copy.scalar<float>()() = 20.0;
  EXPECT_EQ(10.0, original.scalar<float>()());

  // Mutate only the original; the copy keeps its value.
  original.scalar<float>()() = 30.0;
  EXPECT_EQ(20.0, first_copy.scalar<float>()());

  // A copy of a copy is equally independent.
  Tensor second_copy = tensor::DeepCopy(first_copy);
  first_copy.scalar<float>()() = 40.0;

  EXPECT_EQ(20.0, second_copy.scalar<float>()());
  EXPECT_EQ(30.0, original.scalar<float>()());
  EXPECT_EQ(40.0, first_copy.scalar<float>()());

  // All three tensors report the same scalar shape and float dtype.
  for (const Tensor* t : {&original, &first_copy, &second_copy}) {
    EXPECT_EQ(TensorShape({}), t->shape());
    EXPECT_EQ(DT_FLOAT, t->dtype());
  }
}
67
// Deep-copying a default-constructed Tensor must yield an empty DT_FLOAT
// tensor of shape [0].
TEST(TensorUtil, DeepCopyZeroElements) {
  Tensor empty_source;
  const Tensor copied = tensor::DeepCopy(empty_source);
  EXPECT_EQ(0, copied.NumElements());
  EXPECT_EQ(DT_FLOAT, copied.dtype());
  EXPECT_EQ(TensorShape({0}), copied.shape());
}
75
// tensor::DeepCopy of a rank-1 tensor must produce an independent buffer,
// both for a POD element type (DT_FLOAT) and for DT_STRING, whose elements
// own heap storage.
TEST(TensorUtil, DeepCopy) {
  Tensor x(DT_FLOAT, TensorShape({1}));
  x.flat<float>()(0) = 10.0;

  // Make y a deep copy of x and then change it.
  Tensor y = tensor::DeepCopy(x);
  y.flat<float>()(0) = 20.0;

  // x doesn't change.
  EXPECT_EQ(10.0, x.flat<float>()(0));

  // Change x.
  x.flat<float>()(0) = 30.0;

  // y doesn't change.
  EXPECT_EQ(20.0, y.flat<float>()(0));

  Tensor z = tensor::DeepCopy(y);

  // Change y.
  y.flat<float>()(0) = 40.0;

  // The final states should all be different.
  EXPECT_EQ(20.0, z.flat<float>()(0));
  EXPECT_EQ(30.0, x.flat<float>()(0));
  EXPECT_EQ(40.0, y.flat<float>()(0));

  // Should have the same shape and type.
  EXPECT_EQ(TensorShape({1}), x.shape());
  EXPECT_EQ(TensorShape({1}), y.shape());
  EXPECT_EQ(TensorShape({1}), z.shape());

  EXPECT_EQ(DT_FLOAT, x.dtype());
  EXPECT_EQ(DT_FLOAT, y.dtype());
  EXPECT_EQ(DT_FLOAT, z.dtype());

  // String deep copy: every element must be copied, and writes to the copy
  // must leave the source untouched. Checking only that one element differs
  // would not catch a copy that aliases or clobbers the source.
  Tensor str1(DT_STRING, TensorShape({2}));
  str1.flat<tstring>()(0) = "foo1";
  str1.flat<tstring>()(1) = "foo2";
  Tensor str2 = tensor::DeepCopy(str1);
  EXPECT_EQ(str1.shape(), str2.shape());
  EXPECT_EQ(DT_STRING, str2.dtype());

  // The copy starts out with the source's contents.
  EXPECT_EQ("foo1", str2.flat<tstring>()(0));
  EXPECT_EQ("foo2", str2.flat<tstring>()(1));

  str2.flat<tstring>()(0) = "bar1";
  str2.flat<tstring>()(1) = "bar2";

  // Writing the copy leaves the source unchanged.
  EXPECT_EQ("foo1", str1.flat<tstring>()(0));
  EXPECT_EQ("foo2", str1.flat<tstring>()(1));
  EXPECT_EQ("bar1", str2.flat<tstring>()(0));
  EXPECT_EQ("bar2", str2.flat<tstring>()(1));
}
121
// Deep-copying a slice must copy only the sliced elements into a fresh
// buffer. The slice itself keeps sharing the parent's buffer.
TEST(TensorUtil, DeepCopySlice) {
  Tensor parent(DT_INT32, TensorShape({10}));
  parent.flat<int32>().setConstant(1);

  // Slice(2, 6) is a 4-element view into `parent`'s buffer.
  Tensor view = parent.Slice(2, 6);
  Tensor snapshot = tensor::DeepCopy(view);

  // Overwrite the parent: `view` should observe this, `snapshot` should not.
  parent.flat<int32>().setConstant(2);

  EXPECT_EQ(TensorShape({10}), parent.shape());
  EXPECT_EQ(TensorShape({4}), view.shape());
  EXPECT_EQ(TensorShape({4}), snapshot.shape());
  EXPECT_EQ(DT_INT32, parent.dtype());
  EXPECT_EQ(DT_INT32, view.dtype());
  EXPECT_EQ(DT_INT32, snapshot.dtype());

  auto parent_vals = parent.flat<int32>();
  for (int idx = 0; idx < 10; ++idx) {
    EXPECT_EQ(2, parent_vals(idx));
  }

  // The slice is read through unaligned_flat because a slice need not keep
  // the parent's alignment.
  auto view_vals = view.unaligned_flat<int32>();
  auto snapshot_vals = snapshot.flat<int32>();
  for (int idx = 0; idx < 4; ++idx) {
    EXPECT_EQ(2, view_vals(idx));
    EXPECT_EQ(1, snapshot_vals(idx));
  }
}
151
// String variant of DeepCopySlice: a deep copy of a DT_STRING slice must keep
// the strings it copied even after the parent tensor is overwritten.
TEST(TensorUtil, DeepCopySliceString) {
  Tensor parent(DT_STRING, TensorShape({10}));
  parent.flat<tstring>().setConstant("hello");

  // Slice(3, 7) is a 4-element view into `parent`'s buffer.
  Tensor view = parent.Slice(3, 7);
  Tensor snapshot = tensor::DeepCopy(view);

  // Overwrite the parent: `view` should observe this, `snapshot` should not.
  parent.flat<tstring>().setConstant("goodbye");

  EXPECT_EQ(TensorShape({10}), parent.shape());
  EXPECT_EQ(TensorShape({4}), view.shape());
  EXPECT_EQ(TensorShape({4}), snapshot.shape());
  EXPECT_EQ(DT_STRING, parent.dtype());
  EXPECT_EQ(DT_STRING, view.dtype());
  EXPECT_EQ(DT_STRING, snapshot.dtype());

  auto parent_vals = parent.flat<tstring>();
  for (int idx = 0; idx < 10; ++idx) {
    EXPECT_EQ("goodbye", parent_vals(idx));
  }

  auto view_vals = view.unaligned_flat<tstring>();
  auto snapshot_vals = snapshot.flat<tstring>();
  for (int idx = 0; idx < 4; ++idx) {
    EXPECT_EQ("goodbye", view_vals(idx));
    EXPECT_EQ("hello", snapshot_vals(idx));
  }
}
181
// Variant variant of DeepCopySlice. Each element holds a Tensor.
// `parent` is filled with scalar DT_FLOAT 42.0 tensors, the slice is
// deep-copied, and then `parent` is refilled with scalar DT_STRING "foo"
// tensors.
TEST(TensorUtil, DeepCopySliceVariant) {
  Tensor parent(DT_VARIANT, TensorShape({10}));
  parent.flat<Variant>().setConstant(Tensor(42.0f));

  // Slice(3, 7) is a 4-element view into `parent`'s buffer.
  Tensor view = parent.Slice(3, 7);
  Tensor snapshot = tensor::DeepCopy(view);

  // Overwrite the parent: `view` should observe this, `snapshot` should not.
  parent.flat<Variant>().setConstant(Tensor("foo"));

  EXPECT_EQ(TensorShape({10}), parent.shape());
  EXPECT_EQ(TensorShape({4}), view.shape());
  EXPECT_EQ(TensorShape({4}), snapshot.shape());
  EXPECT_EQ(DT_VARIANT, parent.dtype());
  EXPECT_EQ(DT_VARIANT, view.dtype());
  EXPECT_EQ(DT_VARIANT, snapshot.dtype());

  auto parent_vals = parent.flat<Variant>();
  for (int idx = 0; idx < 10; ++idx) {
    const Tensor* inner = parent_vals(idx).get<Tensor>();
    EXPECT_EQ("foo", inner->scalar<tstring>()());
  }

  auto view_vals = view.unaligned_flat<Variant>();
  auto snapshot_vals = snapshot.flat<Variant>();
  for (int idx = 0; idx < 4; ++idx) {
    const Tensor* view_inner = view_vals(idx).get<Tensor>();
    EXPECT_EQ("foo", view_inner->scalar<tstring>()());
    const Tensor* snapshot_inner = snapshot_vals(idx).get<Tensor>();
    EXPECT_EQ(42.0, snapshot_inner->scalar<float>()());
  }
}
214
// Concatenates int32 matrices with 1, 4 and 5 rows along dimension 0 and
// checks that the result holds every row in order. Element (r, c) of the full
// matrix is 2 * r + c, so each value encodes its own position.
TEST(TensorUtil, Concat) {
  const std::vector<int64> sizes = {1, 4, 5};
  std::vector<Tensor> to_concat;
  to_concat.reserve(sizes.size());

  // Running row count. It is also the global row offset of the next piece,
  // so it is kept as int64 to match `sizes` instead of narrowing into int.
  int64 total_size = 0;
  for (const int64 size : sizes) {
    Tensor tensor(DT_INT32, TensorShape({size, 2}));
    auto piece = tensor.matrix<int32>();
    for (int64 i = 0; i < size; ++i) {
      for (int j = 0; j < 2; ++j) {
        piece(i, j) = static_cast<int32>(2 * (total_size + i) + j);
      }
    }
    to_concat.push_back(tensor);
    total_size += size;
  }

  Tensor concated;
  TF_ASSERT_OK(tensor::Concat(to_concat, &concated));
  ASSERT_EQ(TensorShape({total_size, 2}), concated.shape());
  auto result = concated.matrix<int32>();
  for (int64 i = 0; i < total_size; ++i) {
    for (int j = 0; j < 2; ++j) {
      EXPECT_EQ(2 * i + j, result(i, j));
    }
  }
}
242
// Splits a 10x2 int64 matrix into row blocks of 1, 4 and 5 rows. Element
// (r, c) of the source is 2 * r + c, so each split block can be checked
// against its global row offset.
TEST(TensorUtil, Split) {
  Tensor to_split(DT_INT64, TensorShape({10, 2}));
  auto source = to_split.matrix<int64>();
  for (int i = 0; i < 10; ++i) {
    for (int j = 0; j < 2; ++j) {
      source(i, j) = 2 * i + j;
    }
  }

  const std::vector<int64> sizes = {1, 4, 5};
  std::vector<Tensor> splits;
  TF_ASSERT_OK(tensor::Split(to_split, sizes, &splits));
  ASSERT_EQ(sizes.size(), splits.size());

  // Global row offset of the current block. It is int64 to match `sizes`
  // instead of narrowing into int.
  int64 offset = 0;
  for (size_t entry = 0; entry < splits.size(); ++entry) {
    const int64 size = sizes[entry];
    const Tensor& split = splits[entry];

    ASSERT_EQ(TensorShape({size, 2}), split.shape());
    auto piece = split.matrix<int64>();
    for (int64 i = 0; i < size; ++i) {
      for (int j = 0; j < 2; ++j) {
        EXPECT_EQ(2 * (offset + i) + j, piece(i, j));
      }
    }

    offset += size;
  }
}
271
// Splitting a 4x3 string tensor into row blocks of 2, 1 and 1 and then
// concatenating the pieces must reproduce the original tensor. The result must
// own its strings rather than share them with the input.
TEST(TensorUtil, ConcatSplitStrings) {
  const int kNumElements = 4 * 3;
  Tensor x(DT_STRING, TensorShape({4, 3}));
  auto x_vals = x.flat<tstring>();
  for (int i = 0; i < kNumElements; ++i) {
    x_vals(i) = strings::StrCat("foo_", i);
  }

  std::vector<Tensor> pieces;
  TF_ASSERT_OK(tensor::Split(x, {2, 1, 1}, &pieces));
  Tensor x_round_tripped;
  TF_ASSERT_OK(tensor::Concat(pieces, &x_round_tripped));
  ASSERT_EQ(x.shape(), x_round_tripped.shape());

  auto round_vals = x_round_tripped.flat<tstring>();
  for (int i = 0; i < kNumElements; ++i) {
    EXPECT_EQ(x_vals(i), round_vals(i));
  }

  // Overwrite every round-tripped element; `x` must not change.
  for (int i = 0; i < kNumElements; ++i) {
    round_vals(i) = strings::StrCat("bar_", i);
  }
  for (int i = 0; i < kNumElements; ++i) {
    EXPECT_NE(x_vals(i), round_vals(i));
  }
}
295
// CreateTensorProto on string values must emit a DT_STRING proto with shape
// [1, 3] and one `string_val` per element. Protobuf's text format indents two
// spaces per nesting level, so the nested `dim` and `size` lines below carry
// 2- and 4-space indentation.
TEST(TensorProtoUtil, CreatesStringTensorProto) {
  const std::vector<string> values{"a", "b", "c"};
  const std::vector<size_t> shape{1, 3};

  const TensorProto proto = tensor::CreateTensorProto(values, shape);

  const char* const expected =
      "dtype: DT_STRING\n"
      "tensor_shape {\n"
      "  dim {\n"
      "    size: 1\n"
      "  }\n"
      "  dim {\n"
      "    size: 3\n"
      "  }\n"
      "}\n"
      "string_val: \"a\"\n"
      "string_val: \"b\"\n"
      "string_val: \"c\"\n";
  EXPECT_EQ(proto.DebugString(), expected);
}
316
// CreateTensorProto on int32 values must emit a DT_INT32 proto with the values
// in `int_val`. Nested text-format lines use 2-space indentation per level.
TEST(TensorProtoUtil, CreatesInt32TensorProto) {
  const std::vector<int32> values{1, 2};
  const std::vector<size_t> shape{2};

  const TensorProto proto = tensor::CreateTensorProto(values, shape);

  const char* const expected =
      "dtype: DT_INT32\n"
      "tensor_shape {\n"
      "  dim {\n"
      "    size: 2\n"
      "  }\n"
      "}\n"
      "int_val: 1\n"
      "int_val: 2\n";
  EXPECT_EQ(proto.DebugString(), expected);
}
333
// CreateTensorProto on int64 values must emit a DT_INT64 proto with the values
// in `int64_val`. Nested text-format lines use 2-space indentation per level.
TEST(TensorProtoUtil, CreatesInt64TensorProto) {
  const std::vector<int64> values{1, 2};
  const std::vector<size_t> shape{2};

  const TensorProto proto = tensor::CreateTensorProto(values, shape);

  const char* const expected =
      "dtype: DT_INT64\n"
      "tensor_shape {\n"
      "  dim {\n"
      "    size: 2\n"
      "  }\n"
      "}\n"
      "int64_val: 1\n"
      "int64_val: 2\n";
  EXPECT_EQ(proto.DebugString(), expected);
}
350
// CreateTensorProto on uint32 values must emit a DT_UINT32 proto with the
// values in `uint32_val`. Nested text-format lines use 2-space indentation per
// level.
TEST(TensorProtoUtil, CreatesUInt32TensorProto) {
  const std::vector<uint32> values{1, 2};
  const std::vector<size_t> shape{2};

  const TensorProto proto = tensor::CreateTensorProto(values, shape);

  const char* const expected =
      "dtype: DT_UINT32\n"
      "tensor_shape {\n"
      "  dim {\n"
      "    size: 2\n"
      "  }\n"
      "}\n"
      "uint32_val: 1\n"
      "uint32_val: 2\n";
  EXPECT_EQ(proto.DebugString(), expected);
}
367
// CreateTensorProto on uint64 values must emit a DT_UINT64 proto with the
// values in `uint64_val`. Nested text-format lines use 2-space indentation per
// level.
TEST(TensorProtoUtil, CreatesUInt64TensorProto) {
  const std::vector<uint64> values{1, 2};
  const std::vector<size_t> shape{2};

  const TensorProto proto = tensor::CreateTensorProto(values, shape);

  const char* const expected =
      "dtype: DT_UINT64\n"
      "tensor_shape {\n"
      "  dim {\n"
      "    size: 2\n"
      "  }\n"
      "}\n"
      "uint64_val: 1\n"
      "uint64_val: 2\n";
  EXPECT_EQ(proto.DebugString(), expected);
}
384
// CreateTensorProto on float values must emit a DT_FLOAT proto with the values
// in `float_val`. Nested text-format lines use 2-space indentation per level.
TEST(TensorProtoUtil, CreatesFloatTensorProto) {
  const std::vector<float> values{1.1, 2.2};
  const std::vector<size_t> shape{2};

  const TensorProto proto = tensor::CreateTensorProto(values, shape);

  const char* const expected =
      "dtype: DT_FLOAT\n"
      "tensor_shape {\n"
      "  dim {\n"
      "    size: 2\n"
      "  }\n"
      "}\n"
      "float_val: 1.1\n"
      "float_val: 2.2\n";
  EXPECT_EQ(proto.DebugString(), expected);
}
401
// CreateTensorProto on double values must emit a DT_DOUBLE proto with the
// values in `double_val`. Nested text-format lines use 2-space indentation per
// level.
TEST(TensorProtoUtil, CreatesDoubleTensorProto) {
  const std::vector<double> values{1.1, 2.2};
  const std::vector<size_t> shape{2};

  const TensorProto proto = tensor::CreateTensorProto(values, shape);

  const char* const expected =
      "dtype: DT_DOUBLE\n"
      "tensor_shape {\n"
      "  dim {\n"
      "    size: 2\n"
      "  }\n"
      "}\n"
      "double_val: 1.1\n"
      "double_val: 2.2\n";
  EXPECT_EQ(proto.DebugString(), expected);
}
418
// CreateTensorProto on bool values must emit a DT_BOOL proto with the values
// in `bool_val`. Nested text-format lines use 2-space indentation per level.
TEST(TensorProtoUtil, CreatesBoolTensorProto) {
  const std::vector<bool> values{true, false};
  const std::vector<size_t> shape{2};

  const TensorProto proto = tensor::CreateTensorProto(values, shape);

  const char* const expected =
      "dtype: DT_BOOL\n"
      "tensor_shape {\n"
      "  dim {\n"
      "    size: 2\n"
      "  }\n"
      "}\n"
      "bool_val: true\n"
      "bool_val: false\n";
  EXPECT_EQ(proto.DebugString(), expected);
}
435
// A 63-element proto must be left uncompressed for every dtype below, even
// though all of its values are equal.
// NOTE(review): presumably 63 falls just under the default minimum size used
// by the single-argument CompressTensorProtoInPlace overload. ConstantTailTest
// passes kMinSize = 64 explicitly. Confirm against tensor_util.h.
TEST(TensorProtoUtil, CompressTensorProtoInPlaceTooSmall) {
  const int kLength = 63;

  // Runs compression on a private copy and reports whether it happened.
  auto compresses = [](TensorProto proto) {
    return tensor::CompressTensorProtoInPlace(&proto);
  };

  EXPECT_FALSE(compresses(
      tensor::CreateTensorProto(std::vector<float>(kLength), {kLength})));
  EXPECT_FALSE(compresses(
      tensor::CreateTensorProto(std::vector<int>(kLength), {kLength})));
  EXPECT_FALSE(compresses(
      tensor::CreateTensorProto(std::vector<uint8>(kLength), {kLength})));
  EXPECT_FALSE(compresses(
      tensor::CreateTensorProto(std::vector<bool>(kLength), {kLength})));
  EXPECT_FALSE(compresses(tensor::CreateTensorProto(
      std::vector<Eigen::half>(kLength), {kLength})));
  EXPECT_FALSE(compresses(tensor::CreateTensorProto(
      std::vector<std::complex<float>>(kLength), {kLength})));
}
457
// A 64-element proto whose elements are all value-initialized must compress,
// and must end up storing exactly one value, for every dtype below.
TEST(TensorProtoUtil, CompressTensorProtoInPlaceAllEqual) {
  using tensor::internal::TensorProtoHelper;
  const int kLength = 64;
  TensorProto proto;

  proto = tensor::CreateTensorProto(std::vector<float>(kLength), {kLength});
  EXPECT_TRUE(tensor::CompressTensorProtoInPlace(&proto));
  EXPECT_EQ(1, TensorProtoHelper<float>::NumValues(proto));

  proto = tensor::CreateTensorProto(std::vector<int>(kLength), {kLength});
  EXPECT_TRUE(tensor::CompressTensorProtoInPlace(&proto));
  EXPECT_EQ(1, TensorProtoHelper<int>::NumValues(proto));

  proto = tensor::CreateTensorProto(std::vector<uint8>(kLength), {kLength});
  EXPECT_TRUE(tensor::CompressTensorProtoInPlace(&proto));
  EXPECT_EQ(1, TensorProtoHelper<uint8>::NumValues(proto));

  proto = tensor::CreateTensorProto(std::vector<bool>(kLength), {kLength});
  EXPECT_TRUE(tensor::CompressTensorProtoInPlace(&proto));
  EXPECT_EQ(1, TensorProtoHelper<bool>::NumValues(proto));

  proto =
      tensor::CreateTensorProto(std::vector<Eigen::half>(kLength), {kLength});
  EXPECT_TRUE(tensor::CompressTensorProtoInPlace(&proto));
  EXPECT_EQ(1, TensorProtoHelper<Eigen::half>::NumValues(proto));

  proto = tensor::CreateTensorProto(
      std::vector<std::complex<float>>(kLength), {kLength});
  EXPECT_TRUE(tensor::CompressTensorProtoInPlace(&proto));
  EXPECT_EQ(1, TensorProtoHelper<std::complex<float>>::NumValues(proto));
}
497
498 template <typename T>
VectorWithConstantTail(int size,int tail_length,std::vector<T> * v)499 void VectorWithConstantTail(int size, int tail_length, std::vector<T>* v) {
500 CHECK_LE(tail_length, size);
501 v->clear();
502 for (int i = 0; i < size; ++i) {
503 T vi = (i >= size - tail_length) ? T() : T(i);
504 v->push_back(vi);
505 }
506 }
507
508 template <>
VectorWithConstantTail(int size,int tail_length,std::vector<std::complex<float>> * v)509 void VectorWithConstantTail(int size, int tail_length,
510 std::vector<std::complex<float>>* v) {
511 CHECK_LE(tail_length, size);
512 v->clear();
513 for (int i = 0; i < size; ++i) {
514 std::complex<float> vi(
515 0.0f, (i >= (size - tail_length)) ? 0.f : static_cast<float>(i));
516 v->push_back(vi);
517 }
518 }
519
520 template <typename T>
CreateAsProtoTensorContent(int size,int tail_length)521 TensorProto CreateAsProtoTensorContent(int size, int tail_length) {
522 std::vector<T> values;
523 VectorWithConstantTail<T>(size, tail_length, &values);
524 Tensor tensor(DataTypeToEnum<T>::value, TensorShape({size}));
525 std::copy(values.begin(), values.end(), tensor.flat<T>().data());
526 TensorProto tensor_proto;
527 tensor.AsProtoTensorContent(&tensor_proto);
528 return tensor_proto;
529 }
530
531 template <typename T>
CreateAsProtoField(int size,int tail_length)532 TensorProto CreateAsProtoField(int size, int tail_length) {
533 std::vector<T> values;
534 VectorWithConstantTail<T>(size, tail_length, &values);
535 Tensor tensor(DataTypeToEnum<T>::value, TensorShape({size}));
536 std::copy(values.begin(), values.end(), tensor.flat<T>().data());
537 TensorProto tensor_proto;
538 tensor.AsProtoField(&tensor_proto);
539 return tensor_proto;
540 }
541
542 template <typename T>
CompareTensorValues(const TensorProto & x,const TensorProto & y)543 void CompareTensorValues(const TensorProto& x, const TensorProto& y) {
544 Tensor x_t;
545 EXPECT_TRUE(x_t.FromProto(x));
546 Tensor y_t;
547 EXPECT_TRUE(y_t.FromProto(y));
548 test::ExpectTensorEqual<T>(x_t, y_t);
549 }
550
// Builds a `length`-element proto of type T whose last `tail_length` values
// are equal (see VectorWithConstantTail). It is serialized via AsProtoField
// when `as_field` is true and via AsProtoTensorContent otherwise. The test
// then:
//   1. predicts from a byte-size model whether
//      CompressTensorProtoInPlace(kMinSize, kMinCompressionRatio, ...) should
//      compress it,
//   2. checks that the prediction matches the function's return value,
//   3. when compressed, checks which representation was chosen, and
//   4. checks that the decoded values are unchanged.
// NOTE(review): the size model below is the test's expectation of the
// implementation's heuristic; confirm against tensor_util.cc if they diverge.
template <typename T>
void ConstantTailTest(int64 length, int64 tail_length, bool as_field) {
  using TensorProtoHelper = tensor::internal::TensorProtoHelper<T>;
  using FieldType = typename TensorProtoHelper::FieldType;
  const float kMinCompressionRatio = 2.0;
  // Passed through as CompressTensorProtoInPlace's minimum-size argument.
  const int64 kMinSize = 64;
  TensorProto tensor_proto =
      as_field ? CreateAsProtoField<T>(length, tail_length)
               : CreateAsProtoTensorContent<T>(length, tail_length);
  // Kept unmodified so the compressed proto can be compared against it.
  TensorProto original_tensor_proto = tensor_proto;
  // Modeled payload size before compression. The repeated-field encoding
  // stores each element as FieldType (two fields per element for complex T).
  // The tensor_content encoding stores sizeof(T) bytes per element.
  int64 original_size =
      length * (as_field ? (is_complex<T>::value ? 2 : 1) * sizeof(FieldType)
                         : sizeof(T));
  // Candidate encoding 1: all elements as raw bytes.
  int64 size_as_tensor_content = length * sizeof(T);
  // Candidate encoding 2: repeated field holding the non-constant prefix plus
  // one copy of the tail value (hence the +1), capped at `length`.
  int64 size_as_field = std::min(length, (length - tail_length + 1)) *
                        (is_complex<T>::value ? 2 : 1) * sizeof(FieldType);
  // Compression is expected iff the smaller candidate is at most
  // original_size / kMinCompressionRatio (truncated to an integer).
  bool will_compress = std::min(size_as_tensor_content, size_as_field) <=
                       static_cast<int64>(original_size / kMinCompressionRatio);

  EXPECT_EQ(tensor::CompressTensorProtoInPlace(kMinSize, kMinCompressionRatio,
                                               &tensor_proto),
            will_compress);
  if (will_compress) {
    if (size_as_tensor_content < size_as_field) {
      // tensor_content was strictly smaller: no repeated values remain and
      // the bytes are in tensor_content.
      EXPECT_EQ(TensorProtoHelper::NumValues(tensor_proto), 0);
      EXPECT_FALSE(tensor_proto.tensor_content().empty());
    } else {
      // The repeated field won (ties included): at most prefix + 1 values
      // remain and tensor_content is empty.
      EXPECT_LE(TensorProtoHelper::NumValues(tensor_proto),
                (length - tail_length + 1));
      EXPECT_TRUE(tensor_proto.tensor_content().empty());
    }
  }
  // Whether or not compression happened, the decoded values must not change.
  CompareTensorValues<T>(tensor_proto, original_tensor_proto);
}
585
// Runs ConstantTailTest on 64-element inputs. It covers both serialization
// paths (repeated field and tensor_content), tail lengths from none through a
// few, about half, and all of the elements, and every numeric dtype listed
// below.
TEST(TensorProtoUtil, CompressTensorProtoConstantTail) {
  const int kLength = 64;
  const bool kAsFieldModes[] = {true, false};
  const int kTailLengths[] = {0, 1, 2, 32, 33, 63, 64};
  for (const bool as_field : kAsFieldModes) {
    for (const int tail : kTailLengths) {
      // Floating point and complex types.
      ConstantTailTest<float>(kLength, tail, as_field);
      ConstantTailTest<double>(kLength, tail, as_field);
      ConstantTailTest<complex64>(kLength, tail, as_field);
      ConstantTailTest<complex128>(kLength, tail, as_field);
      // Integer types.
      ConstantTailTest<int32>(kLength, tail, as_field);
      ConstantTailTest<uint32>(kLength, tail, as_field);
      ConstantTailTest<int64>(kLength, tail, as_field);
      ConstantTailTest<uint64>(kLength, tail, as_field);
      ConstantTailTest<int8>(kLength, tail, as_field);
      ConstantTailTest<uint8>(kLength, tail, as_field);
      ConstantTailTest<int16>(kLength, tail, as_field);
      ConstantTailTest<uint16>(kLength, tail, as_field);
      // 16-bit floating point types.
      ConstantTailTest<Eigen::half>(kLength, tail, as_field);
      ConstantTailTest<bfloat16>(kLength, tail, as_field);
    }
  }
}
607
608 } // namespace
609 } // namespace tensorflow
610