• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/tensor_util.h"
17 
18 #include <vector>
19 
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/framework/tensor_testutil.h"
22 #include "tensorflow/core/framework/types.h"
23 #include "tensorflow/core/framework/variant.h"
24 #include "tensorflow/core/framework/variant_encode_decode.h"
25 #include "tensorflow/core/framework/variant_tensor_data.h"
26 #include "tensorflow/core/lib/core/status_test_util.h"
27 #include "tensorflow/core/platform/test.h"
28 
29 namespace tensorflow {
30 namespace {
31 
// Checks that tensor::DeepCopy of a rank-0 float tensor yields a tensor whose
// value can be changed independently of the tensor it was copied from, and
// that shape and dtype carry over.
TEST(TensorUtil, DeepCopy0d) {
  // Reads the single float stored in a scalar tensor.
  const auto scalar_of = [](const Tensor& t) { return t.scalar<float>()(); };

  Tensor original(DT_FLOAT, TensorShape({}));
  original.scalar<float>()() = 10.0;

  // Deep-copy `original`, then overwrite the copy.
  Tensor first_copy = tensor::DeepCopy(original);
  first_copy.scalar<float>()() = 20.0;

  // The write to the copy is not visible through `original`.
  EXPECT_EQ(10.0, scalar_of(original));

  // Conversely, a write to `original` is not visible through the copy.
  original.scalar<float>()() = 30.0;
  EXPECT_EQ(20.0, scalar_of(first_copy));

  // Copy the copy, then overwrite the intermediate tensor.
  Tensor second_copy = tensor::DeepCopy(first_copy);
  first_copy.scalar<float>()() = 40.0;

  // All three tensors now hold distinct values.
  EXPECT_EQ(20.0, scalar_of(second_copy));
  EXPECT_EQ(30.0, scalar_of(original));
  EXPECT_EQ(40.0, scalar_of(first_copy));

  // Every tensor is still a float scalar.
  const TensorShape scalar_shape({});
  EXPECT_EQ(scalar_shape, original.shape());
  EXPECT_EQ(scalar_shape, first_copy.shape());
  EXPECT_EQ(scalar_shape, second_copy.shape());

  EXPECT_EQ(DT_FLOAT, original.dtype());
  EXPECT_EQ(DT_FLOAT, first_copy.dtype());
  EXPECT_EQ(DT_FLOAT, second_copy.dtype());
}
68 
// Deep-copying a default-constructed Tensor must yield an empty float tensor
// of shape {0}.
TEST(TensorUtil, DeepCopyZeroElements) {
  Tensor empty_source;
  Tensor copied = tensor::DeepCopy(empty_source);

  const TensorShape expected_shape({0});
  EXPECT_EQ(expected_shape, copied.shape());
  EXPECT_EQ(DT_FLOAT, copied.dtype());
  EXPECT_EQ(0, copied.NumElements());
}
76 
// Checks that tensor::DeepCopy produces tensors that do not alias their source,
// first for a one-element float tensor and then for a string tensor.
TEST(TensorUtil, DeepCopy) {
  Tensor x(DT_FLOAT, TensorShape({1}));
  x.flat<float>()(0) = 10.0;

  // Make y a deep copy of x and then change it.
  Tensor y = tensor::DeepCopy(x);
  y.flat<float>()(0) = 20.0;

  // x doesn't change.
  EXPECT_EQ(10.0, x.flat<float>()(0));

  // Change x.
  x.flat<float>()(0) = 30.0;

  // y doesn't change.
  EXPECT_EQ(20.0, y.flat<float>()(0));

  Tensor z = tensor::DeepCopy(y);

  // Change y.
  y.flat<float>()(0) = 40.0;

  // The final states should all be different.
  EXPECT_EQ(20.0, z.flat<float>()(0));
  EXPECT_EQ(30.0, x.flat<float>()(0));
  EXPECT_EQ(40.0, y.flat<float>()(0));

  // Should have the same shape and type.
  EXPECT_EQ(TensorShape({1}), x.shape());
  EXPECT_EQ(TensorShape({1}), y.shape());
  EXPECT_EQ(TensorShape({1}), z.shape());

  EXPECT_EQ(DT_FLOAT, x.dtype());
  EXPECT_EQ(DT_FLOAT, y.dtype());
  EXPECT_EQ(DT_FLOAT, z.dtype());

  // Test string deep copy.
  Tensor str1(DT_STRING, TensorShape({2}));
  str1.flat<tstring>()(0) = "foo1";
  str1.flat<tstring>()(1) = "foo2";
  Tensor str2 = tensor::DeepCopy(str1);

  // The copy starts out with the source's shape, type and contents.
  EXPECT_EQ(str1.shape(), str2.shape());
  EXPECT_EQ(DT_STRING, str2.dtype());
  EXPECT_EQ("foo1", str2.flat<tstring>()(0));
  EXPECT_EQ("foo2", str2.flat<tstring>()(1));

  // Overwrite every element of the copy.
  str2.flat<tstring>()(0) = "bar1";
  str2.flat<tstring>()(1) = "bar2";

  // The source keeps its original strings; checking only that the two tensors
  // differ would also pass if the source had been clobbered with other data.
  EXPECT_EQ("foo1", str1.flat<tstring>()(0));
  EXPECT_EQ("foo2", str1.flat<tstring>()(1));
  EXPECT_EQ("bar1", str2.flat<tstring>()(0));
  EXPECT_EQ("bar2", str2.flat<tstring>()(1));
  EXPECT_NE(str2.flat<tstring>()(0), str1.flat<tstring>()(0));
  EXPECT_NE(str2.flat<tstring>()(1), str1.flat<tstring>()(1));
}
122 
// A deep copy taken from a slice must be detached from the buffer that the
// slice shares with its parent tensor.
TEST(TensorUtil, DeepCopySlice) {
  Tensor parent(DT_INT32, TensorShape({10}));
  parent.flat<int32>().setConstant(1);

  // `window` covers indices [2, 6) of `parent` and still refers to the same
  // buffer; `detached` is a deep copy of that window.
  Tensor window = parent.Slice(2, 6);
  Tensor detached = tensor::DeepCopy(window);

  // Overwrite the parent after the copy has been taken.
  parent.flat<int32>().setConstant(2);

  EXPECT_EQ(TensorShape({10}), parent.shape());
  EXPECT_EQ(TensorShape({4}), window.shape());
  EXPECT_EQ(TensorShape({4}), detached.shape());
  EXPECT_EQ(DT_INT32, parent.dtype());
  EXPECT_EQ(DT_INT32, window.dtype());
  EXPECT_EQ(DT_INT32, detached.dtype());

  // The parent and the window observe the new value '2' ...
  for (int idx = 0; idx < 10; ++idx) {
    EXPECT_EQ(2, parent.flat<int32>()(idx));
  }
  // ... while the deep copy still holds the old value '1'.
  for (int idx = 0; idx < 4; ++idx) {
    EXPECT_EQ(2, window.unaligned_flat<int32>()(idx));
    EXPECT_EQ(1, detached.flat<int32>()(idx));
  }
}
152 
// Same scenario as the int32 slice test, but with string elements.
TEST(TensorUtil, DeepCopySliceString) {
  Tensor parent(DT_STRING, TensorShape({10}));
  parent.flat<tstring>().setConstant("hello");

  // `window` covers indices [3, 7) of `parent` and still refers to the same
  // buffer; `detached` is a deep copy of that window.
  Tensor window = parent.Slice(3, 7);
  Tensor detached = tensor::DeepCopy(window);

  // Overwrite the parent after the copy has been taken.
  parent.flat<tstring>().setConstant("goodbye");

  EXPECT_EQ(TensorShape({10}), parent.shape());
  EXPECT_EQ(TensorShape({4}), window.shape());
  EXPECT_EQ(TensorShape({4}), detached.shape());
  EXPECT_EQ(DT_STRING, parent.dtype());
  EXPECT_EQ(DT_STRING, window.dtype());
  EXPECT_EQ(DT_STRING, detached.dtype());

  // The parent and the window observe 'goodbye' ...
  for (int idx = 0; idx < 10; ++idx) {
    EXPECT_EQ("goodbye", parent.flat<tstring>()(idx));
  }
  // ... while the deep copy still holds 'hello'.
  for (int idx = 0; idx < 4; ++idx) {
    EXPECT_EQ("goodbye", window.unaligned_flat<tstring>()(idx));
    EXPECT_EQ("hello", detached.flat<tstring>()(idx));
  }
}
182 
// Same scenario as the other slice tests, but every element is a Variant that
// wraps a Tensor.
TEST(TensorUtil, DeepCopySliceVariant) {
  Tensor parent(DT_VARIANT, TensorShape({10}));
  parent.flat<Variant>().setConstant(Tensor(42.0f));

  // `window` covers indices [3, 7) of `parent` and still refers to the same
  // buffer; `detached` is a deep copy of that window.
  Tensor window = parent.Slice(3, 7);
  Tensor detached = tensor::DeepCopy(window);

  // Replace every element of the parent after the copy has been taken.
  parent.flat<Variant>().setConstant(Tensor("foo"));

  EXPECT_EQ(TensorShape({10}), parent.shape());
  EXPECT_EQ(TensorShape({4}), window.shape());
  EXPECT_EQ(TensorShape({4}), detached.shape());
  EXPECT_EQ(DT_VARIANT, parent.dtype());
  EXPECT_EQ(DT_VARIANT, window.dtype());
  EXPECT_EQ(DT_VARIANT, detached.dtype());

  // Each element of the parent and the window is now a string tensor holding
  // "foo" ...
  for (int idx = 0; idx < 10; ++idx) {
    EXPECT_EQ(
        "foo",
        parent.flat<Variant>()(idx).get<Tensor>()->scalar<tstring>()());
  }
  // ... while each element of the deep copy is still a float tensor holding
  // 42.0.
  for (int idx = 0; idx < 4; ++idx) {
    EXPECT_EQ(
        "foo",
        window.unaligned_flat<Variant>()(idx).get<Tensor>()->scalar<tstring>()());
    EXPECT_EQ(42.0,
              detached.flat<Variant>()(idx).get<Tensor>()->scalar<float>()());
  }
}
215 
// Builds three int32 matrices with 1, 4 and 5 rows whose cells continue one
// running sequence (cell (r, c) of the logical whole is 2 * r + c), then checks
// that tensor::Concat stitches them back into that whole.
TEST(TensorUtil, Concat) {
  std::vector<int64_t> row_counts = {1, 4, 5};
  std::vector<Tensor> pieces;
  int64_t rows_total = 0;
  int first_row = 0;  // Global row index at which the current piece starts.
  for (size_t k = 0; k < row_counts.size(); ++k) {
    const int64_t rows = row_counts[k];
    Tensor piece(DT_INT32, TensorShape({rows, 2}));
    auto cells = piece.matrix<int32>();
    for (int r = first_row; r < first_row + rows; ++r) {
      for (int c = 0; c < 2; ++c) {
        cells(r - first_row, c) = 2 * r + c;
      }
    }
    pieces.push_back(piece);
    rows_total += rows;
    first_row += rows;
  }

  Tensor joined;
  TF_ASSERT_OK(tensor::Concat(pieces, &joined));
  ASSERT_EQ(TensorShape({rows_total, 2}), joined.shape());

  auto joined_cells = joined.matrix<int32>();
  for (int r = 0; r < rows_total; ++r) {
    for (int c = 0; c < 2; ++c) {
      EXPECT_EQ(2 * r + c, joined_cells(r, c));
    }
  }
}
243 
// Fills a 10x2 int64 matrix so that cell (r, c) holds 2 * r + c, splits it
// into pieces of 1, 4 and 5 rows, and checks each piece's shape and contents.
TEST(TensorUtil, Split) {
  Tensor whole(DT_INT64, TensorShape({10, 2}));
  auto whole_cells = whole.matrix<int64_t>();
  for (int r = 0; r < 10; ++r) {
    for (int c = 0; c < 2; ++c) {
      whole_cells(r, c) = 2 * r + c;
    }
  }

  std::vector<int64_t> row_counts = {1, 4, 5};
  std::vector<Tensor> parts;
  TF_ASSERT_OK(tensor::Split(whole, row_counts, &parts));
  ASSERT_EQ(row_counts.size(), parts.size());

  int first_row = 0;  // Row of `whole` at which the current part starts.
  for (size_t k = 0; k < parts.size(); ++k) {
    const int64_t rows = row_counts[k];
    const Tensor& part = parts[k];

    ASSERT_EQ(TensorShape({rows, 2}), part.shape());
    auto part_cells = part.matrix<int64_t>();
    for (int r = first_row; r < first_row + rows; ++r) {
      for (int c = 0; c < 2; ++c) {
        EXPECT_EQ(2 * r + c, part_cells(r - first_row, c));
      }
    }

    first_row += rows;
  }
}
272 
// Round-trips a 4x3 string tensor through tensor::Split and tensor::Concat,
// then checks that the result equals the input without sharing its strings.
TEST(TensorUtil, ConcatSplitStrings) {
  const int kNumStrings = 4 * 3;
  Tensor source(DT_STRING, TensorShape({4, 3}));
  for (int idx = 0; idx < kNumStrings; ++idx) {
    source.flat<tstring>()(idx) = strings::StrCat("foo_", idx);
  }

  std::vector<Tensor> pieces;
  TF_ASSERT_OK(tensor::Split(source, {2, 1, 1}, &pieces));
  Tensor rebuilt;
  TF_ASSERT_OK(tensor::Concat(pieces, &rebuilt));
  ASSERT_EQ(source.shape(), rebuilt.shape());
  for (int idx = 0; idx < kNumStrings; ++idx) {
    EXPECT_EQ(source.flat<tstring>()(idx), rebuilt.flat<tstring>()(idx));
  }

  // Overwrite every string of the rebuilt tensor; if any storage were shared,
  // the corresponding element of `source` would change with it.
  for (int idx = 0; idx < kNumStrings; ++idx) {
    rebuilt.flat<tstring>()(idx) = strings::StrCat("bar_", idx);
  }
  for (int idx = 0; idx < kNumStrings; ++idx) {
    EXPECT_NE(source.flat<tstring>()(idx), rebuilt.flat<tstring>()(idx));
  }
}
296 
// tensor::CreateTensorProto on a vector of strings with shape {1, 3} should
// produce a DT_STRING proto with one string_val entry per element.
TEST(TensorProtoUtil, CreatesStringTensorProto) {
  std::vector<string> values{"a", "b", "c"};
  std::vector<size_t> shape{1, 3};

  auto proto = tensor::CreateTensorProto(values, shape);

  // Stop here if the expectation text itself fails to parse; comparing against
  // a message that was not fully populated would be misleading.
  TensorProto expected_tensor_proto;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
      "dtype: DT_STRING\n"
      "tensor_shape {\n"
      "  dim {\n"
      "    size: 1\n"
      "  }\n"
      "  dim {\n"
      "    size: 3\n"
      "  }\n"
      "}\n"
      "string_val: \"a\"\n"
      "string_val: \"b\"\n"
      "string_val: \"c\"\n",
      &expected_tensor_proto));
  EXPECT_EQ(proto.DebugString(), expected_tensor_proto.DebugString());
}
320 
// tensor::CreateTensorProto on a vector of int32 with shape {2} should produce
// a DT_INT32 proto that stores the values in int_val.
TEST(TensorProtoUtil, CreatesInt32TensorProto) {
  std::vector<int32> values{1, 2};
  std::vector<size_t> shape{2};

  auto proto = tensor::CreateTensorProto(values, shape);

  // Stop here if the expectation text itself fails to parse; comparing against
  // a message that was not fully populated would be misleading.
  TensorProto expected_tensor_proto;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
      "dtype: DT_INT32\n"
      "tensor_shape {\n"
      "  dim {\n"
      "    size: 2\n"
      "  }\n"
      "}\n"
      "int_val: 1\n"
      "int_val: 2\n",
      &expected_tensor_proto));
  EXPECT_EQ(proto.DebugString(), expected_tensor_proto.DebugString());
}
340 
// tensor::CreateTensorProto on a vector of int64 with shape {2} should produce
// a DT_INT64 proto that stores the values in int64_val.
TEST(TensorProtoUtil, CreatesInt64TensorProto) {
  std::vector<int64_t> values{1, 2};
  std::vector<size_t> shape{2};

  auto proto = tensor::CreateTensorProto(values, shape);

  // Stop here if the expectation text itself fails to parse; comparing against
  // a message that was not fully populated would be misleading.
  TensorProto expected_tensor_proto;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
      "dtype: DT_INT64\n"
      "tensor_shape {\n"
      "  dim {\n"
      "    size: 2\n"
      "  }\n"
      "}\n"
      "int64_val: 1\n"
      "int64_val: 2\n",
      &expected_tensor_proto));
  EXPECT_EQ(proto.DebugString(), expected_tensor_proto.DebugString());
}
360 
// tensor::CreateTensorProto on a vector of uint32 with shape {2} should produce
// a DT_UINT32 proto that stores the values in uint32_val.
TEST(TensorProtoUtil, CreatesUInt32TensorProto) {
  std::vector<uint32> values{1, 2};
  std::vector<size_t> shape{2};

  auto proto = tensor::CreateTensorProto(values, shape);

  // Stop here if the expectation text itself fails to parse; comparing against
  // a message that was not fully populated would be misleading.
  TensorProto expected_tensor_proto;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
      "dtype: DT_UINT32\n"
      "tensor_shape {\n"
      "  dim {\n"
      "    size: 2\n"
      "  }\n"
      "}\n"
      "uint32_val: 1\n"
      "uint32_val: 2\n",
      &expected_tensor_proto));
  EXPECT_EQ(proto.DebugString(), expected_tensor_proto.DebugString());
}
380 
// tensor::CreateTensorProto on a vector of uint64 with shape {2} should produce
// a DT_UINT64 proto that stores the values in uint64_val.
TEST(TensorProtoUtil, CreatesUInt64TensorProto) {
  std::vector<uint64> values{1, 2};
  std::vector<size_t> shape{2};

  auto proto = tensor::CreateTensorProto(values, shape);

  // Stop here if the expectation text itself fails to parse; comparing against
  // a message that was not fully populated would be misleading.
  TensorProto expected_tensor_proto;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
      "dtype: DT_UINT64\n"
      "tensor_shape {\n"
      "  dim {\n"
      "    size: 2\n"
      "  }\n"
      "}\n"
      "uint64_val: 1\n"
      "uint64_val: 2\n",
      &expected_tensor_proto));
  EXPECT_EQ(proto.DebugString(), expected_tensor_proto.DebugString());
}
400 
// tensor::CreateTensorProto on a vector of float with shape {2} should produce
// a DT_FLOAT proto that stores the values in float_val.
TEST(TensorProtoUtil, CreatesFloatTensorProto) {
  std::vector<float> values{1.1, 2.2};
  std::vector<size_t> shape{2};

  auto proto = tensor::CreateTensorProto(values, shape);

  // Stop here if the expectation text itself fails to parse; comparing against
  // a message that was not fully populated would be misleading.
  TensorProto expected_tensor_proto;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
      "dtype: DT_FLOAT\n"
      "tensor_shape {\n"
      "  dim {\n"
      "    size: 2\n"
      "  }\n"
      "}\n"
      "float_val: 1.1\n"
      "float_val: 2.2\n",
      &expected_tensor_proto));
  EXPECT_EQ(proto.DebugString(), expected_tensor_proto.DebugString());
}
420 
// tensor::CreateTensorProto on a vector of double with shape {2} should produce
// a DT_DOUBLE proto that stores the values in double_val.
TEST(TensorProtoUtil, CreatesDoubleTensorProto) {
  std::vector<double> values{1.1, 2.2};
  std::vector<size_t> shape{2};

  auto proto = tensor::CreateTensorProto(values, shape);

  // Stop here if the expectation text itself fails to parse; comparing against
  // a message that was not fully populated would be misleading.
  TensorProto expected_tensor_proto;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
      "dtype: DT_DOUBLE\n"
      "tensor_shape {\n"
      "  dim {\n"
      "    size: 2\n"
      "  }\n"
      "}\n"
      "double_val: 1.1\n"
      "double_val: 2.2\n",
      &expected_tensor_proto));
  EXPECT_EQ(proto.DebugString(), expected_tensor_proto.DebugString());
}
440 
// tensor::CreateTensorProto on a vector of bool with shape {2} should produce
// a DT_BOOL proto that stores the values in bool_val.
TEST(TensorProtoUtil, CreatesBoolTensorProto) {
  std::vector<bool> values{true, false};
  std::vector<size_t> shape{2};

  auto proto = tensor::CreateTensorProto(values, shape);

  // Stop here if the expectation text itself fails to parse; comparing against
  // a message that was not fully populated would be misleading.
  TensorProto expected_tensor_proto;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
      "dtype: DT_BOOL\n"
      "tensor_shape {\n"
      "  dim {\n"
      "    size: 2\n"
      "  }\n"
      "}\n"
      "bool_val: true\n"
      "bool_val: false\n",
      &expected_tensor_proto));
  EXPECT_EQ(proto.DebugString(), expected_tensor_proto.DebugString());
}
460 
// For each element type, a proto built from 63 value-initialized elements must
// be left uncompressed by the single-argument CompressTensorProtoInPlace.
// NOTE(review): presumably 63 sits just below the default minimum size; the
// threshold itself is not visible in this file.
TEST(TensorProtoUtil, CompressTensorProtoInPlaceTooSmall) {
  const int kNumElements = 63;
  {
    TensorProto float_proto = tensor::CreateTensorProto(
        std::vector<float>(kNumElements), {kNumElements});
    EXPECT_FALSE(tensor::CompressTensorProtoInPlace(&float_proto));
  }
  {
    TensorProto int_proto = tensor::CreateTensorProto(
        std::vector<int>(kNumElements), {kNumElements});
    EXPECT_FALSE(tensor::CompressTensorProtoInPlace(&int_proto));
  }
  {
    TensorProto uint8_proto = tensor::CreateTensorProto(
        std::vector<uint8>(kNumElements), {kNumElements});
    EXPECT_FALSE(tensor::CompressTensorProtoInPlace(&uint8_proto));
  }
  {
    TensorProto bool_proto = tensor::CreateTensorProto(
        std::vector<bool>(kNumElements), {kNumElements});
    EXPECT_FALSE(tensor::CompressTensorProtoInPlace(&bool_proto));
  }
  {
    TensorProto half_proto = tensor::CreateTensorProto(
        std::vector<Eigen::half>(kNumElements), {kNumElements});
    EXPECT_FALSE(tensor::CompressTensorProtoInPlace(&half_proto));
  }
  {
    TensorProto complex_proto = tensor::CreateTensorProto(
        std::vector<std::complex<float>>(kNumElements), {kNumElements});
    EXPECT_FALSE(tensor::CompressTensorProtoInPlace(&complex_proto));
  }
}
482 
// For each element type, a proto built from 64 value-initialized (hence all
// equal) elements must be compressed, after which the helper reports zero
// stored values.
TEST(TensorProtoUtil, CompressTensorProtoInPlaceAllEqual) {
  const int kNumElements = 64;
  {
    TensorProto float_proto = tensor::CreateTensorProto(
        std::vector<float>(kNumElements), {kNumElements});
    EXPECT_TRUE(tensor::CompressTensorProtoInPlace(&float_proto));
    EXPECT_EQ(
        tensor::internal::TensorProtoHelper<float>::NumValues(float_proto), 0);
  }
  {
    TensorProto int_proto = tensor::CreateTensorProto(
        std::vector<int>(kNumElements), {kNumElements});
    EXPECT_TRUE(tensor::CompressTensorProtoInPlace(&int_proto));
    EXPECT_EQ(tensor::internal::TensorProtoHelper<int>::NumValues(int_proto),
              0);
  }
  {
    TensorProto uint8_proto = tensor::CreateTensorProto(
        std::vector<uint8>(kNumElements), {kNumElements});
    EXPECT_TRUE(tensor::CompressTensorProtoInPlace(&uint8_proto));
    EXPECT_EQ(
        tensor::internal::TensorProtoHelper<uint8>::NumValues(uint8_proto), 0);
  }
  {
    TensorProto bool_proto = tensor::CreateTensorProto(
        std::vector<bool>(kNumElements), {kNumElements});
    EXPECT_TRUE(tensor::CompressTensorProtoInPlace(&bool_proto));
    EXPECT_EQ(tensor::internal::TensorProtoHelper<bool>::NumValues(bool_proto),
              0);
  }
  {
    TensorProto half_proto = tensor::CreateTensorProto(
        std::vector<Eigen::half>(kNumElements), {kNumElements});
    EXPECT_TRUE(tensor::CompressTensorProtoInPlace(&half_proto));
    EXPECT_EQ(
        tensor::internal::TensorProtoHelper<Eigen::half>::NumValues(half_proto),
        0);
  }
  {
    TensorProto complex_proto = tensor::CreateTensorProto(
        std::vector<std::complex<float>>(kNumElements), {kNumElements});
    EXPECT_TRUE(tensor::CompressTensorProtoInPlace(&complex_proto));
    EXPECT_EQ(
        tensor::internal::TensorProtoHelper<std::complex<float>>::NumValues(
            complex_proto),
        0);
  }
}
522 
523 template <typename T>
VectorWithConstantTail(int size,int tail_length,std::vector<T> * v)524 void VectorWithConstantTail(int size, int tail_length, std::vector<T>* v) {
525   CHECK_LE(tail_length, size);
526   v->clear();
527   for (int i = 0; i < size; ++i) {
528     T vi = (i >= size - tail_length) ? T() : T(i);
529     v->push_back(vi);
530   }
531 }
532 
533 template <>
VectorWithConstantTail(int size,int tail_length,std::vector<std::complex<float>> * v)534 void VectorWithConstantTail(int size, int tail_length,
535                             std::vector<std::complex<float>>* v) {
536   CHECK_LE(tail_length, size);
537   v->clear();
538   for (int i = 0; i < size; ++i) {
539     std::complex<float> vi(
540         0.0f, (i >= (size - tail_length)) ? 0.f : static_cast<float>(i));
541     v->push_back(vi);
542   }
543 }
544 
545 template <typename T>
CreateAsProtoTensorContent(int size,int tail_length)546 TensorProto CreateAsProtoTensorContent(int size, int tail_length) {
547   std::vector<T> values;
548   VectorWithConstantTail<T>(size, tail_length, &values);
549   Tensor tensor(DataTypeToEnum<T>::value, TensorShape({size}));
550   std::copy(values.begin(), values.end(), tensor.flat<T>().data());
551   TensorProto tensor_proto;
552   tensor.AsProtoTensorContent(&tensor_proto);
553   return tensor_proto;
554 }
555 
556 template <typename T>
CreateAsProtoField(int size,int tail_length)557 TensorProto CreateAsProtoField(int size, int tail_length) {
558   std::vector<T> values;
559   VectorWithConstantTail<T>(size, tail_length, &values);
560   Tensor tensor(DataTypeToEnum<T>::value, TensorShape({size}));
561   std::copy(values.begin(), values.end(), tensor.flat<T>().data());
562   TensorProto tensor_proto;
563   tensor.AsProtoField(&tensor_proto);
564   return tensor_proto;
565 }
566 
567 template <typename T>
CompareTensorValues(const TensorProto & x,const TensorProto & y)568 void CompareTensorValues(const TensorProto& x, const TensorProto& y) {
569   Tensor x_t;
570   EXPECT_TRUE(x_t.FromProto(x));
571   Tensor y_t;
572   EXPECT_TRUE(y_t.FromProto(y));
573   test::ExpectTensorEqual<T>(x_t, y_t);
574 }
575 
576 template <typename T>
ConstantTailTest(int64_t length,int64_t tail_length,bool as_field)577 void ConstantTailTest(int64_t length, int64_t tail_length, bool as_field) {
578   using TensorProtoHelper = tensor::internal::TensorProtoHelper<T>;
579   using FieldType = typename TensorProtoHelper::FieldType;
580   const float kMinCompressionRatio = 2.0;
581   const int64_t kMinSize = 64;
582   TensorProto tensor_proto =
583       as_field ? CreateAsProtoField<T>(length, tail_length)
584                : CreateAsProtoTensorContent<T>(length, tail_length);
585   TensorProto original_tensor_proto = tensor_proto;
586   int64_t original_size =
587       length * (as_field ? (is_complex<T>::value ? 2 : 1) * sizeof(FieldType)
588                          : sizeof(T));
589   int64_t size_as_tensor_content = length * sizeof(T);
590   int64_t size_as_field = std::min(length, (length - tail_length + 1)) *
591                           (is_complex<T>::value ? 2 : 1) * sizeof(FieldType);
592   bool will_compress =
593       std::min(size_as_tensor_content, size_as_field) <=
594       static_cast<int64_t>(original_size / kMinCompressionRatio);
595 
596   EXPECT_EQ(tensor::CompressTensorProtoInPlace(kMinSize, kMinCompressionRatio,
597                                                &tensor_proto),
598             will_compress);
599   if (will_compress) {
600     if (size_as_tensor_content < size_as_field) {
601       EXPECT_EQ(TensorProtoHelper::NumValues(tensor_proto), 0);
602       EXPECT_FALSE(tensor_proto.tensor_content().empty());
603     } else {
604       EXPECT_LE(TensorProtoHelper::NumValues(tensor_proto),
605                 (length - tail_length + 1));
606       EXPECT_TRUE(tensor_proto.tensor_content().empty());
607     }
608   }
609   CompareTensorValues<T>(tensor_proto, original_tensor_proto);
610 }
611 
// Runs ConstantTailTest for every supported element type, for both input
// encodings and for a range of constant-tail lengths from none to the whole
// tensor.
TEST(TensorProtoUtil, CompressTensorProtoConstantTail) {
  const int kNumElements = 64;
  const bool kEncodings[] = {true, false};  // as_field, then tensor_content.
  const int kTailLengths[] = {0, 1, 2, 32, 33, 63, 64};

  for (const bool as_field : kEncodings) {
    for (const int tail : kTailLengths) {
      // Floating point and complex types.
      ConstantTailTest<float>(kNumElements, tail, as_field);
      ConstantTailTest<double>(kNumElements, tail, as_field);
      ConstantTailTest<complex64>(kNumElements, tail, as_field);
      ConstantTailTest<complex128>(kNumElements, tail, as_field);
      // 32- and 64-bit integers.
      ConstantTailTest<int32>(kNumElements, tail, as_field);
      ConstantTailTest<uint32>(kNumElements, tail, as_field);
      ConstantTailTest<int64_t>(kNumElements, tail, as_field);
      ConstantTailTest<uint64>(kNumElements, tail, as_field);
      // 8- and 16-bit integers.
      ConstantTailTest<int8>(kNumElements, tail, as_field);
      ConstantTailTest<uint8>(kNumElements, tail, as_field);
      ConstantTailTest<int16>(kNumElements, tail, as_field);
      ConstantTailTest<uint16>(kNumElements, tail, as_field);
      // 16-bit floating point types.
      ConstantTailTest<Eigen::half>(kNumElements, tail, as_field);
      ConstantTailTest<bfloat16>(kNumElements, tail, as_field);
    }
  }
}
633 
TEST(TensorProtoUtil,CompressTensorProtoNegatizeZero)634 TEST(TensorProtoUtil, CompressTensorProtoNegatizeZero) {
635   TensorProto tensor_proto;
636   {
637     // Double
638     Tensor tensor(-0.0);
639     tensor.AsProtoField(&tensor_proto);
640     ASSERT_EQ(tensor_proto.double_val(0), -0.0);
641     ASSERT_TRUE(std::signbit(tensor_proto.double_val(0)));
642     tensor::CompressTensorProtoInPlace(1, 1.0, &tensor_proto);
643     ASSERT_EQ(tensor_proto.double_val(0), -0.0);
644     ASSERT_TRUE(std::signbit(tensor_proto.double_val(0)));
645   }
646   {
647     // Float
648     Tensor tensor(-0.0f);
649     tensor.AsProtoField(&tensor_proto);
650     ASSERT_EQ(tensor_proto.float_val(0), -0.0);
651     ASSERT_TRUE(std::signbit(tensor_proto.float_val(0)));
652     tensor::CompressTensorProtoInPlace(1, 1.0, &tensor_proto);
653     ASSERT_EQ(tensor_proto.float_val(0), -0.0);
654     ASSERT_TRUE(std::signbit(tensor_proto.float_val(0)));
655   }
656   {
657     // Half
658     Tensor tensor(Eigen::half(-0.0f));
659     tensor.AsProtoField(&tensor_proto);
660     ASSERT_EQ(tensor_proto.half_val(0), 0x8000);
661     tensor::CompressTensorProtoInPlace(1, 1.0, &tensor_proto);
662     ASSERT_TRUE(tensor.FromProto(tensor_proto));
663     ASSERT_EQ(tensor.scalar<Eigen::half>()(), static_cast<Eigen::half>(0.0f));
664     ASSERT_TRUE(
665         std::signbit(static_cast<float>(tensor.scalar<Eigen::half>()())));
666   }
667   {
668     // Double Complex -0.0, -0.0
669     Tensor tensor(std::complex<double>(-0.0, -0.0));
670     tensor.AsProtoField(&tensor_proto);
671     ASSERT_EQ(tensor_proto.dcomplex_val(0), -0.0);
672     ASSERT_EQ(tensor_proto.dcomplex_val(1), -0.0);
673     tensor::CompressTensorProtoInPlace(1, 1.0, &tensor_proto);
674     ASSERT_TRUE(tensor.FromProto(tensor_proto));
675     auto value = tensor.scalar<std::complex<double>>()();
676     ASSERT_EQ(value.real(), -0.0f);
677     ASSERT_TRUE(std::signbit(value.real()));
678     ASSERT_EQ(value.imag(), -0.0f);
679     ASSERT_TRUE(std::signbit(value.imag()));
680   }
681   {
682     // Double Complex 0.0, -0.0
683     Tensor tensor(std::complex<double>(0.0, -0.0));
684     tensor.AsProtoField(&tensor_proto);
685     ASSERT_EQ(tensor_proto.dcomplex_val(0), 0.0);
686     ASSERT_EQ(tensor_proto.dcomplex_val(1), -0.0);
687     tensor::CompressTensorProtoInPlace(1, 1.0, &tensor_proto);
688     ASSERT_TRUE(tensor.FromProto(tensor_proto));
689     auto value = tensor.scalar<std::complex<double>>()();
690     ASSERT_EQ(value.real(), 0.0f);
691     ASSERT_FALSE(std::signbit(value.real()));
692     ASSERT_EQ(value.imag(), -0.0f);
693     ASSERT_TRUE(std::signbit(value.imag()));
694   }
695   {
696     // Double Complex -0.0, 0.0
697     Tensor tensor(std::complex<double>(-0.0, 0.0));
698     tensor.AsProtoField(&tensor_proto);
699     ASSERT_EQ(tensor_proto.dcomplex_val(0), -0.0);
700     ASSERT_EQ(tensor_proto.dcomplex_val(1), 0.0);
701     tensor::CompressTensorProtoInPlace(1, 1.0, &tensor_proto);
702     ASSERT_TRUE(tensor.FromProto(tensor_proto));
703     auto value = tensor.scalar<std::complex<double>>()();
704     ASSERT_EQ(value.real(), -0.0f);
705     ASSERT_TRUE(std::signbit(value.real()));
706     ASSERT_EQ(value.imag(), 0.0f);
707     ASSERT_FALSE(std::signbit(value.imag()));
708   }
709 }
710 
711 }  // namespace
712 }  // namespace tensorflow
713