• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/tensor_util.h"
17 
18 #include <vector>
19 #include "tensorflow/core/framework/tensor.h"
20 #include "tensorflow/core/framework/tensor_testutil.h"
21 #include "tensorflow/core/framework/types.h"
22 #include "tensorflow/core/framework/variant.h"
23 #include "tensorflow/core/framework/variant_encode_decode.h"
24 #include "tensorflow/core/framework/variant_tensor_data.h"
25 #include "tensorflow/core/lib/core/status_test_util.h"
26 #include "tensorflow/core/platform/test.h"
27 
28 namespace tensorflow {
29 namespace {
30 
// tensor::DeepCopy on a rank-0 (scalar) float tensor must produce a tensor
// with its own storage: writes to either side never show up in the other,
// and the shape and dtype are carried over unchanged.
TEST(TensorUtil, DeepCopy0d) {
  Tensor src(DT_FLOAT, TensorShape({}));
  src.scalar<float>()() = 10.0;

  // Copy once, then write only to the copy.
  Tensor copy1 = tensor::DeepCopy(src);
  copy1.scalar<float>()() = 20.0;
  EXPECT_EQ(10.0, src.scalar<float>()());  // The source is untouched.

  // Write only to the source.
  src.scalar<float>()() = 30.0;
  EXPECT_EQ(20.0, copy1.scalar<float>()());  // The copy is untouched.

  // Copy the copy, then write to the intermediate tensor.
  Tensor copy2 = tensor::DeepCopy(copy1);
  copy1.scalar<float>()() = 40.0;

  // Every tensor ends up holding its own distinct value.
  EXPECT_EQ(20.0, copy2.scalar<float>()());
  EXPECT_EQ(30.0, src.scalar<float>()());
  EXPECT_EQ(40.0, copy1.scalar<float>()());

  // Shape and dtype are preserved by every copy.
  for (const Tensor* t : {&src, &copy1, &copy2}) {
    EXPECT_EQ(TensorShape({}), t->shape());
  }
  for (const Tensor* t : {&src, &copy1, &copy2}) {
    EXPECT_EQ(DT_FLOAT, t->dtype());
  }
}
67 
// Deep-copying a default-constructed Tensor must succeed. The result is
// expected to be an empty rank-1 DT_FLOAT tensor.
TEST(TensorUtil, DeepCopyZeroElements) {
  Tensor empty_source;
  const Tensor copied = tensor::DeepCopy(empty_source);

  EXPECT_EQ(TensorShape({0}), copied.shape());
  EXPECT_EQ(DT_FLOAT, copied.dtype());
  EXPECT_EQ(0, copied.NumElements());
}
75 
// Exercises tensor::DeepCopy on rank-1 tensors:
//  * a one-element float tensor, whose copies must not alias each other's
//    storage and must keep the original shape and dtype;
//  * a two-element string tensor, whose copy must start with the same
//    contents but own its strings, so mutating every element of the copy
//    leaves every element of the source intact.
TEST(TensorUtil, DeepCopy) {
  Tensor x(DT_FLOAT, TensorShape({1}));
  x.flat<float>()(0) = 10.0;

  // Make y a deep copy of x and then change it.
  Tensor y = tensor::DeepCopy(x);
  y.flat<float>()(0) = 20.0;

  // x doesn't change.
  EXPECT_EQ(10.0, x.flat<float>()(0));

  // Change x.
  x.flat<float>()(0) = 30.0;

  // y doesn't change.
  EXPECT_EQ(20.0, y.flat<float>()(0));

  Tensor z = tensor::DeepCopy(y);

  // Change y.
  y.flat<float>()(0) = 40.0;

  // The final states should all be different.
  EXPECT_EQ(20.0, z.flat<float>()(0));
  EXPECT_EQ(30.0, x.flat<float>()(0));
  EXPECT_EQ(40.0, y.flat<float>()(0));

  // Should have the same shape and type.
  EXPECT_EQ(TensorShape({1}), x.shape());
  EXPECT_EQ(TensorShape({1}), y.shape());
  EXPECT_EQ(TensorShape({1}), z.shape());

  EXPECT_EQ(DT_FLOAT, x.dtype());
  EXPECT_EQ(DT_FLOAT, y.dtype());
  EXPECT_EQ(DT_FLOAT, z.dtype());

  // String deep copy.
  Tensor str1(DT_STRING, TensorShape({2}));
  str1.flat<tstring>()(0) = "foo1";
  str1.flat<tstring>()(1) = "foo2";
  Tensor str2 = tensor::DeepCopy(str1);

  // The copy starts out with the same metadata and contents.
  EXPECT_EQ(str1.shape(), str2.shape());
  EXPECT_EQ(DT_STRING, str2.dtype());
  EXPECT_EQ("foo1", str2.flat<tstring>()(0));
  EXPECT_EQ("foo2", str2.flat<tstring>()(1));

  str2.flat<tstring>()(0) = "bar1";
  str2.flat<tstring>()(1) = "bar2";

  // Checking only one element (or only inequality) would miss a copy that
  // shares some of its string buffers, so verify every element on both sides.
  EXPECT_EQ("foo1", str1.flat<tstring>()(0));
  EXPECT_EQ("foo2", str1.flat<tstring>()(1));
  EXPECT_EQ("bar1", str2.flat<tstring>()(0));
  EXPECT_EQ("bar2", str2.flat<tstring>()(1));
}
121 
// Deep-copying a slice must materialize just the sliced range into fresh
// storage. Later writes to the parent are then visible through the shallow
// slice but not through the deep copy.
TEST(TensorUtil, DeepCopySlice) {
  constexpr int kParentSize = 10;
  constexpr int kSliceSize = 4;

  Tensor parent(DT_INT32, TensorShape({10}));
  parent.flat<int32>().setConstant(1);

  // Elements [2, 6) of `parent`; this view shares the parent's buffer.
  Tensor view = parent.Slice(2, 6);
  // An independent copy of the view's four elements.
  Tensor owned = tensor::DeepCopy(view);

  // Overwrite the parent after the copy has been taken.
  parent.flat<int32>().setConstant(2);

  EXPECT_EQ(TensorShape({10}), parent.shape());
  EXPECT_EQ(TensorShape({4}), view.shape());
  EXPECT_EQ(TensorShape({4}), owned.shape());
  EXPECT_EQ(DT_INT32, parent.dtype());
  EXPECT_EQ(DT_INT32, view.dtype());
  EXPECT_EQ(DT_INT32, owned.dtype());

  // The parent and the view both read 2; the deep copy still reads 1.
  for (int idx = 0; idx < kParentSize; ++idx) {
    EXPECT_EQ(2, parent.flat<int32>()(idx));
  }
  for (int idx = 0; idx < kSliceSize; ++idx) {
    EXPECT_EQ(2, view.unaligned_flat<int32>()(idx));
    EXPECT_EQ(1, owned.flat<int32>()(idx));
  }
}
151 
// String version of DeepCopySlice. The deep copy of a slice must own its
// strings, so overwriting the parent is visible only through the shallow
// slice.
TEST(TensorUtil, DeepCopySliceString) {
  constexpr int kParentSize = 10;
  constexpr int kSliceSize = 4;

  Tensor parent(DT_STRING, TensorShape({10}));
  parent.flat<tstring>().setConstant("hello");

  // Elements [3, 7) of `parent`; this view shares the parent's buffer.
  Tensor view = parent.Slice(3, 7);
  // An independent copy of the view's four elements.
  Tensor owned = tensor::DeepCopy(view);

  // Overwrite the parent after the copy has been taken.
  parent.flat<tstring>().setConstant("goodbye");

  EXPECT_EQ(TensorShape({10}), parent.shape());
  EXPECT_EQ(TensorShape({4}), view.shape());
  EXPECT_EQ(TensorShape({4}), owned.shape());
  EXPECT_EQ(DT_STRING, parent.dtype());
  EXPECT_EQ(DT_STRING, view.dtype());
  EXPECT_EQ(DT_STRING, owned.dtype());

  // The parent and the view both read "goodbye"; the copy still reads "hello".
  for (int idx = 0; idx < kParentSize; ++idx) {
    EXPECT_EQ("goodbye", parent.flat<tstring>()(idx));
  }
  for (int idx = 0; idx < kSliceSize; ++idx) {
    EXPECT_EQ("goodbye", view.unaligned_flat<tstring>()(idx));
    EXPECT_EQ("hello", owned.flat<tstring>()(idx));
  }
}
181 
// Deep-copying a slice of a DT_VARIANT tensor must copy the Variant payloads.
// After the parent is overwritten, the shallow slice sees the new payloads
// while the deep copy keeps the old ones.
TEST(TensorUtil, DeepCopySliceVariant) {
  Tensor x(DT_VARIANT, TensorShape({10}));
  x.flat<Variant>().setConstant(Tensor(42.0f));

  // Slice 'x' -- y still refers to the same buffer.
  Tensor y = x.Slice(3, 7);

  // Do a deep copy of y, which is a slice.
  Tensor z = tensor::DeepCopy(y);

  // Set x to be different.
  x.flat<Variant>().setConstant(Tensor("foo"));

  EXPECT_EQ(TensorShape({10}), x.shape());
  EXPECT_EQ(TensorShape({4}), y.shape());
  EXPECT_EQ(TensorShape({4}), z.shape());
  EXPECT_EQ(DT_VARIANT, x.dtype());
  EXPECT_EQ(DT_VARIANT, y.dtype());
  EXPECT_EQ(DT_VARIANT, z.dtype());

  // Each element of x and y should now be a DT_STRING Tensor containing "foo",
  // but each element of z should be a DT_FLOAT tensor containing 42.0.
  // Variant::get<Tensor>() returns a pointer (presumably null when the
  // payload is not a Tensor -- see variant.h). Checking it before the
  // dereference turns a wrong payload into a reported failure, not a crash.
  for (int i = 0; i < 10; ++i) {
    const Tensor* x_elem = x.flat<Variant>()(i).get<Tensor>();
    ASSERT_NE(nullptr, x_elem) << "x[" << i << "] does not hold a Tensor";
    EXPECT_EQ("foo", x_elem->scalar<tstring>()());
  }
  for (int i = 0; i < 4; ++i) {
    const Tensor* y_elem = y.unaligned_flat<Variant>()(i).get<Tensor>();
    ASSERT_NE(nullptr, y_elem) << "y[" << i << "] does not hold a Tensor";
    EXPECT_EQ("foo", y_elem->scalar<tstring>()());

    const Tensor* z_elem = z.flat<Variant>()(i).get<Tensor>();
    ASSERT_NE(nullptr, z_elem) << "z[" << i << "] does not hold a Tensor";
    EXPECT_EQ(42.0, z_elem->scalar<float>()());
  }
}
214 
// Concatenates three int32 matrices (1, 4 and 5 rows, 2 columns each) and
// checks that tensor::Concat stacks them along dimension 0 in order.
// Element (r, c) of the full matrix is filled with 2 * r + c, where r is the
// global row index, so the result must reproduce that pattern on every row.
TEST(TensorUtil, Concat) {
  const std::vector<int64> sizes = {1, 4, 5};
  std::vector<Tensor> to_concat;
  // Running row count; it is also the global row index of the next piece's
  // first row. It is int64 to match `sizes`: an int accumulator would narrow
  // on every addition.
  int64 total_size = 0;
  for (const int64 size : sizes) {
    Tensor tensor(DT_INT32, TensorShape({size, 2}));
    auto piece = tensor.matrix<int32>();
    for (int64 i = 0; i < size; ++i) {
      for (int j = 0; j < 2; ++j) {
        piece(i, j) = static_cast<int32>(2 * (total_size + i) + j);
      }
    }
    to_concat.push_back(tensor);
    total_size += size;
  }

  Tensor concated;
  TF_ASSERT_OK(tensor::Concat(to_concat, &concated));
  ASSERT_EQ(TensorShape({total_size, 2}), concated.shape());
  auto result = concated.matrix<int32>();
  for (int64 i = 0; i < total_size; ++i) {
    for (int j = 0; j < 2; ++j) {
      EXPECT_EQ(2 * i + j, result(i, j));
    }
  }
}
242 
// Splits a 10x2 int64 matrix into row blocks of 1, 4 and 5 rows, then checks
// each piece's shape and contents. Element (r, c) of the input holds
// 2 * r + c, so each piece must reproduce that pattern at its row offset.
TEST(TensorUtil, Split) {
  Tensor to_split(DT_INT64, TensorShape({10, 2}));
  auto input = to_split.matrix<int64>();
  for (int64 i = 0; i < 10; ++i) {
    for (int64 j = 0; j < 2; ++j) {
      input(i, j) = 2 * i + j;
    }
  }

  const std::vector<int64> sizes = {1, 4, 5};
  std::vector<Tensor> splits;
  TF_ASSERT_OK(tensor::Split(to_split, sizes, &splits));
  ASSERT_EQ(sizes.size(), splits.size());

  // Row offset of the current piece within `to_split`. It is int64 to match
  // the element type of `sizes`; an int accumulator would narrow on each add.
  int64 offset = 0;
  for (size_t entry = 0; entry < splits.size(); ++entry) {
    const int64 size = sizes[entry];
    const Tensor& split = splits[entry];

    ASSERT_EQ(TensorShape({size, 2}), split.shape());
    const auto piece = split.matrix<int64>();
    for (int64 i = 0; i < size; ++i) {
      for (int64 j = 0; j < 2; ++j) {
        EXPECT_EQ(2 * (offset + i) + j, piece(i, j));
      }
    }

    offset += size;
  }
}
271 
TEST(TensorUtil,ConcatSplitStrings)272 TEST(TensorUtil, ConcatSplitStrings) {
273   Tensor x(DT_STRING, TensorShape({4, 3}));
274   for (int i = 0; i < 4 * 3; ++i) {
275     x.flat<tstring>()(i) = strings::StrCat("foo_", i);
276   }
277 
278   std::vector<Tensor> split;
279   TF_ASSERT_OK(tensor::Split(x, {2, 1, 1}, &split));
280   Tensor x_round_tripped;
281   TF_ASSERT_OK(tensor::Concat(split, &x_round_tripped));
282   ASSERT_EQ(x.shape(), x_round_tripped.shape());
283   for (int i = 0; i < 4 * 3; ++i) {
284     EXPECT_EQ(x.flat<tstring>()(i), x_round_tripped.flat<tstring>()(i));
285   }
286 
287   // Ensure that no memory is being shared between 'x' and 'x_round_tripped'.
288   for (int i = 0; i < 4 * 3; ++i) {
289     x_round_tripped.flat<tstring>()(i) = strings::StrCat("bar_", i);
290   }
291   for (int i = 0; i < 4 * 3; ++i) {
292     EXPECT_NE(x.flat<tstring>()(i), x_round_tripped.flat<tstring>()(i));
293   }
294 }
295 
// CreateTensorProto over string values should produce a DT_STRING proto with
// the requested 1x3 shape and the values in `string_val`.
TEST(TensorProtoUtil, CreatesStringTensorProto) {
  const std::vector<string> values{"a", "b", "c"};
  const std::vector<size_t> shape{1, 3};
  const TensorProto proto = tensor::CreateTensorProto(values, shape);

  const char kExpected[] =
      "dtype: DT_STRING\n"
      "tensor_shape {\n"
      "  dim {\n"
      "    size: 1\n"
      "  }\n"
      "  dim {\n"
      "    size: 3\n"
      "  }\n"
      "}\n"
      "string_val: \"a\"\n"
      "string_val: \"b\"\n"
      "string_val: \"c\"\n";
  EXPECT_EQ(proto.DebugString(), kExpected);
}
316 
// int32 values should produce a DT_INT32 proto carrying them in `int_val`.
TEST(TensorProtoUtil, CreatesInt32TensorProto) {
  const std::vector<int32> values{1, 2};
  const std::vector<size_t> shape{2};
  const TensorProto proto = tensor::CreateTensorProto(values, shape);

  const char kExpected[] =
      "dtype: DT_INT32\n"
      "tensor_shape {\n"
      "  dim {\n"
      "    size: 2\n"
      "  }\n"
      "}\n"
      "int_val: 1\n"
      "int_val: 2\n";
  EXPECT_EQ(proto.DebugString(), kExpected);
}
333 
// int64 values should produce a DT_INT64 proto carrying them in `int64_val`.
TEST(TensorProtoUtil, CreatesInt64TensorProto) {
  const std::vector<int64> values{1, 2};
  const std::vector<size_t> shape{2};
  const TensorProto proto = tensor::CreateTensorProto(values, shape);

  const char kExpected[] =
      "dtype: DT_INT64\n"
      "tensor_shape {\n"
      "  dim {\n"
      "    size: 2\n"
      "  }\n"
      "}\n"
      "int64_val: 1\n"
      "int64_val: 2\n";
  EXPECT_EQ(proto.DebugString(), kExpected);
}
350 
// uint32 values should produce a DT_UINT32 proto carrying them in
// `uint32_val`.
TEST(TensorProtoUtil, CreatesUInt32TensorProto) {
  const std::vector<uint32> values{1, 2};
  const std::vector<size_t> shape{2};
  const TensorProto proto = tensor::CreateTensorProto(values, shape);

  const char kExpected[] =
      "dtype: DT_UINT32\n"
      "tensor_shape {\n"
      "  dim {\n"
      "    size: 2\n"
      "  }\n"
      "}\n"
      "uint32_val: 1\n"
      "uint32_val: 2\n";
  EXPECT_EQ(proto.DebugString(), kExpected);
}
367 
// uint64 values should produce a DT_UINT64 proto carrying them in
// `uint64_val`.
TEST(TensorProtoUtil, CreatesUInt64TensorProto) {
  const std::vector<uint64> values{1, 2};
  const std::vector<size_t> shape{2};
  const TensorProto proto = tensor::CreateTensorProto(values, shape);

  const char kExpected[] =
      "dtype: DT_UINT64\n"
      "tensor_shape {\n"
      "  dim {\n"
      "    size: 2\n"
      "  }\n"
      "}\n"
      "uint64_val: 1\n"
      "uint64_val: 2\n";
  EXPECT_EQ(proto.DebugString(), kExpected);
}
384 
// float values should produce a DT_FLOAT proto carrying them in `float_val`.
TEST(TensorProtoUtil, CreatesFloatTensorProto) {
  const std::vector<float> values{1.1, 2.2};
  const std::vector<size_t> shape{2};
  const TensorProto proto = tensor::CreateTensorProto(values, shape);

  const char kExpected[] =
      "dtype: DT_FLOAT\n"
      "tensor_shape {\n"
      "  dim {\n"
      "    size: 2\n"
      "  }\n"
      "}\n"
      "float_val: 1.1\n"
      "float_val: 2.2\n";
  EXPECT_EQ(proto.DebugString(), kExpected);
}
401 
// double values should produce a DT_DOUBLE proto carrying them in
// `double_val`.
TEST(TensorProtoUtil, CreatesDoubleTensorProto) {
  const std::vector<double> values{1.1, 2.2};
  const std::vector<size_t> shape{2};
  const TensorProto proto = tensor::CreateTensorProto(values, shape);

  const char kExpected[] =
      "dtype: DT_DOUBLE\n"
      "tensor_shape {\n"
      "  dim {\n"
      "    size: 2\n"
      "  }\n"
      "}\n"
      "double_val: 1.1\n"
      "double_val: 2.2\n";
  EXPECT_EQ(proto.DebugString(), kExpected);
}
418 
// bool values should produce a DT_BOOL proto carrying them in `bool_val`.
TEST(TensorProtoUtil, CreatesBoolTensorProto) {
  const std::vector<bool> values{true, false};
  const std::vector<size_t> shape{2};
  const TensorProto proto = tensor::CreateTensorProto(values, shape);

  const char kExpected[] =
      "dtype: DT_BOOL\n"
      "tensor_shape {\n"
      "  dim {\n"
      "    size: 2\n"
      "  }\n"
      "}\n"
      "bool_val: true\n"
      "bool_val: false\n";
  EXPECT_EQ(proto.DebugString(), kExpected);
}
435 
// For each element type below, CompressTensorProtoInPlace with its default
// thresholds is expected to decline a proto of 63 value-initialized elements,
// i.e. to return false.
TEST(TensorProtoUtil, CompressTensorProtoInPlaceTooSmall) {
  const int kLength = 63;
  const std::vector<size_t> shape = {kLength};

  // Runs in-place compression on a scratch copy of `proto` and returns the
  // reported result.
  auto compresses = [](TensorProto proto) {
    return tensor::CompressTensorProtoInPlace(&proto);
  };

  EXPECT_FALSE(compresses(
      tensor::CreateTensorProto(std::vector<float>(kLength), shape)));
  EXPECT_FALSE(compresses(
      tensor::CreateTensorProto(std::vector<int>(kLength), shape)));
  EXPECT_FALSE(compresses(
      tensor::CreateTensorProto(std::vector<uint8>(kLength), shape)));
  EXPECT_FALSE(compresses(
      tensor::CreateTensorProto(std::vector<bool>(kLength), shape)));
  EXPECT_FALSE(compresses(
      tensor::CreateTensorProto(std::vector<Eigen::half>(kLength), shape)));
  EXPECT_FALSE(compresses(tensor::CreateTensorProto(
      std::vector<std::complex<float>>(kLength), shape)));
}
457 
// For each element type below, a proto of 64 value-initialized (all-equal)
// elements is expected to compress with the default thresholds. Afterwards
// TensorProtoHelper<T>::NumValues must report exactly one stored value.
TEST(TensorProtoUtil, CompressTensorProtoInPlaceAllEqual) {
  const int kLength = 64;
  const std::vector<size_t> shape = {kLength};

  {
    TensorProto proto =
        tensor::CreateTensorProto(std::vector<float>(kLength), shape);
    EXPECT_TRUE(tensor::CompressTensorProtoInPlace(&proto));
    EXPECT_EQ(tensor::internal::TensorProtoHelper<float>::NumValues(proto), 1);
  }
  {
    TensorProto proto =
        tensor::CreateTensorProto(std::vector<int>(kLength), shape);
    EXPECT_TRUE(tensor::CompressTensorProtoInPlace(&proto));
    EXPECT_EQ(tensor::internal::TensorProtoHelper<int>::NumValues(proto), 1);
  }
  {
    TensorProto proto =
        tensor::CreateTensorProto(std::vector<uint8>(kLength), shape);
    EXPECT_TRUE(tensor::CompressTensorProtoInPlace(&proto));
    EXPECT_EQ(tensor::internal::TensorProtoHelper<uint8>::NumValues(proto), 1);
  }
  {
    TensorProto proto =
        tensor::CreateTensorProto(std::vector<bool>(kLength), shape);
    EXPECT_TRUE(tensor::CompressTensorProtoInPlace(&proto));
    EXPECT_EQ(tensor::internal::TensorProtoHelper<bool>::NumValues(proto), 1);
  }
  {
    TensorProto proto =
        tensor::CreateTensorProto(std::vector<Eigen::half>(kLength), shape);
    EXPECT_TRUE(tensor::CompressTensorProtoInPlace(&proto));
    EXPECT_EQ(
        tensor::internal::TensorProtoHelper<Eigen::half>::NumValues(proto), 1);
  }
  {
    TensorProto proto = tensor::CreateTensorProto(
        std::vector<std::complex<float>>(kLength), shape);
    EXPECT_TRUE(tensor::CompressTensorProtoInPlace(&proto));
    EXPECT_EQ(
        tensor::internal::TensorProtoHelper<std::complex<float>>::NumValues(
            proto),
        1);
  }
}
497 
498 template <typename T>
VectorWithConstantTail(int size,int tail_length,std::vector<T> * v)499 void VectorWithConstantTail(int size, int tail_length, std::vector<T>* v) {
500   CHECK_LE(tail_length, size);
501   v->clear();
502   for (int i = 0; i < size; ++i) {
503     T vi = (i >= size - tail_length) ? T() : T(i);
504     v->push_back(vi);
505   }
506 }
507 
508 template <>
VectorWithConstantTail(int size,int tail_length,std::vector<std::complex<float>> * v)509 void VectorWithConstantTail(int size, int tail_length,
510                             std::vector<std::complex<float>>* v) {
511   CHECK_LE(tail_length, size);
512   v->clear();
513   for (int i = 0; i < size; ++i) {
514     std::complex<float> vi(
515         0.0f, (i >= (size - tail_length)) ? 0.f : static_cast<float>(i));
516     v->push_back(vi);
517   }
518 }
519 
520 template <typename T>
CreateAsProtoTensorContent(int size,int tail_length)521 TensorProto CreateAsProtoTensorContent(int size, int tail_length) {
522   std::vector<T> values;
523   VectorWithConstantTail<T>(size, tail_length, &values);
524   Tensor tensor(DataTypeToEnum<T>::value, TensorShape({size}));
525   std::copy(values.begin(), values.end(), tensor.flat<T>().data());
526   TensorProto tensor_proto;
527   tensor.AsProtoTensorContent(&tensor_proto);
528   return tensor_proto;
529 }
530 
531 template <typename T>
CreateAsProtoField(int size,int tail_length)532 TensorProto CreateAsProtoField(int size, int tail_length) {
533   std::vector<T> values;
534   VectorWithConstantTail<T>(size, tail_length, &values);
535   Tensor tensor(DataTypeToEnum<T>::value, TensorShape({size}));
536   std::copy(values.begin(), values.end(), tensor.flat<T>().data());
537   TensorProto tensor_proto;
538   tensor.AsProtoField(&tensor_proto);
539   return tensor_proto;
540 }
541 
542 template <typename T>
CompareTensorValues(const TensorProto & x,const TensorProto & y)543 void CompareTensorValues(const TensorProto& x, const TensorProto& y) {
544   Tensor x_t;
545   EXPECT_TRUE(x_t.FromProto(x));
546   Tensor y_t;
547   EXPECT_TRUE(y_t.FromProto(y));
548   test::ExpectTensorEqual<T>(x_t, y_t);
549 }
550 
551 template <typename T>
ConstantTailTest(int64 length,int64 tail_length,bool as_field)552 void ConstantTailTest(int64 length, int64 tail_length, bool as_field) {
553   using TensorProtoHelper = tensor::internal::TensorProtoHelper<T>;
554   using FieldType = typename TensorProtoHelper::FieldType;
555   const float kMinCompressionRatio = 2.0;
556   const int64 kMinSize = 64;
557   TensorProto tensor_proto =
558       as_field ? CreateAsProtoField<T>(length, tail_length)
559                : CreateAsProtoTensorContent<T>(length, tail_length);
560   TensorProto original_tensor_proto = tensor_proto;
561   int64 original_size =
562       length * (as_field ? (is_complex<T>::value ? 2 : 1) * sizeof(FieldType)
563                          : sizeof(T));
564   int64 size_as_tensor_content = length * sizeof(T);
565   int64 size_as_field = std::min(length, (length - tail_length + 1)) *
566                         (is_complex<T>::value ? 2 : 1) * sizeof(FieldType);
567   bool will_compress = std::min(size_as_tensor_content, size_as_field) <=
568                        static_cast<int64>(original_size / kMinCompressionRatio);
569 
570   EXPECT_EQ(tensor::CompressTensorProtoInPlace(kMinSize, kMinCompressionRatio,
571                                                &tensor_proto),
572             will_compress);
573   if (will_compress) {
574     if (size_as_tensor_content < size_as_field) {
575       EXPECT_EQ(TensorProtoHelper::NumValues(tensor_proto), 0);
576       EXPECT_FALSE(tensor_proto.tensor_content().empty());
577     } else {
578       EXPECT_LE(TensorProtoHelper::NumValues(tensor_proto),
579                 (length - tail_length + 1));
580       EXPECT_TRUE(tensor_proto.tensor_content().empty());
581     }
582   }
583   CompareTensorValues<T>(tensor_proto, original_tensor_proto);
584 }
585 
// Runs ConstantTailTest over every supported numeric element type, both
// serialization forms, and tail lengths from 0 (no constant tail) up to
// 64 (the whole tensor is constant).
TEST(TensorProtoUtil, CompressTensorProtoConstantTail) {
  const int kLength = 64;
  const int kTailLengths[] = {0, 1, 2, 32, 33, 63, 64};
  for (const bool as_field : {true, false}) {
    for (const int tail : kTailLengths) {
      ConstantTailTest<float>(kLength, tail, as_field);
      ConstantTailTest<double>(kLength, tail, as_field);
      ConstantTailTest<complex64>(kLength, tail, as_field);
      ConstantTailTest<complex128>(kLength, tail, as_field);
      ConstantTailTest<int32>(kLength, tail, as_field);
      ConstantTailTest<uint32>(kLength, tail, as_field);
      ConstantTailTest<int64>(kLength, tail, as_field);
      ConstantTailTest<uint64>(kLength, tail, as_field);
      ConstantTailTest<int8>(kLength, tail, as_field);
      ConstantTailTest<uint8>(kLength, tail, as_field);
      ConstantTailTest<int16>(kLength, tail, as_field);
      ConstantTailTest<uint16>(kLength, tail, as_field);
      ConstantTailTest<Eigen::half>(kLength, tail, as_field);
      ConstantTailTest<bfloat16>(kLength, tail, as_field);
    }
  }
}
607 
608 }  // namespace
609 }  // namespace tensorflow
610