1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/tensor_util.h"
17
18 #include <vector>
19
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/framework/tensor_testutil.h"
22 #include "tensorflow/core/framework/types.h"
23 #include "tensorflow/core/framework/variant.h"
24 #include "tensorflow/core/framework/variant_encode_decode.h"
25 #include "tensorflow/core/framework/variant_tensor_data.h"
26 #include "tensorflow/core/lib/core/status_test_util.h"
27 #include "tensorflow/core/platform/test.h"
28
29 namespace tensorflow {
30 namespace {
31
TEST(TensorUtil, DeepCopy0d) {
  Tensor original(DT_FLOAT, TensorShape({}));
  original.scalar<float>()() = 10.0;

  // Writing to a deep copy must leave the source untouched...
  Tensor copy = tensor::DeepCopy(original);
  copy.scalar<float>()() = 20.0;
  EXPECT_EQ(10.0, original.scalar<float>()());

  // ...and writing to the source must leave the copy untouched.
  original.scalar<float>()() = 30.0;
  EXPECT_EQ(20.0, copy.scalar<float>()());

  // A copy of a copy is likewise independent of its source.
  Tensor copy_of_copy = tensor::DeepCopy(copy);
  copy.scalar<float>()() = 40.0;

  // All three tensors now hold distinct values.
  EXPECT_EQ(20.0, copy_of_copy.scalar<float>()());
  EXPECT_EQ(30.0, original.scalar<float>()());
  EXPECT_EQ(40.0, copy.scalar<float>()());

  // Copies keep the scalar shape...
  for (const Tensor* t : {&original, &copy, &copy_of_copy}) {
    EXPECT_EQ(TensorShape({}), t->shape());
  }
  // ...and the float dtype.
  for (const Tensor* t : {&original, &copy, &copy_of_copy}) {
    EXPECT_EQ(DT_FLOAT, t->dtype());
  }
}
68
TEST(TensorUtil, DeepCopyZeroElements) {
  // Deep-copying a default-constructed Tensor yields an empty float vector
  // of shape {0}.
  Tensor empty;
  const Tensor copied = tensor::DeepCopy(empty);
  EXPECT_EQ(TensorShape({0}), copied.shape());
  EXPECT_EQ(DT_FLOAT, copied.dtype());
  EXPECT_EQ(0, copied.NumElements());
}
76
TEST(TensorUtil, DeepCopy) {
  Tensor x(DT_FLOAT, TensorShape({1}));
  x.flat<float>()(0) = 10.0;

  // Make y a deep copy of x and then change it.
  Tensor y = tensor::DeepCopy(x);
  y.flat<float>()(0) = 20.0;

  // x doesn't change.
  EXPECT_EQ(10.0, x.flat<float>()(0));

  // Change x.
  x.flat<float>()(0) = 30.0;

  // y doesn't change.
  EXPECT_EQ(20.0, y.flat<float>()(0));

  Tensor z = tensor::DeepCopy(y);

  // Change y.
  y.flat<float>()(0) = 40.0;

  // The final states should all be different.
  EXPECT_EQ(20.0, z.flat<float>()(0));
  EXPECT_EQ(30.0, x.flat<float>()(0));
  EXPECT_EQ(40.0, y.flat<float>()(0));

  // Should have the same shape and type.
  EXPECT_EQ(TensorShape({1}), x.shape());
  EXPECT_EQ(TensorShape({1}), y.shape());
  EXPECT_EQ(TensorShape({1}), z.shape());

  EXPECT_EQ(DT_FLOAT, x.dtype());
  EXPECT_EQ(DT_FLOAT, y.dtype());
  EXPECT_EQ(DT_FLOAT, z.dtype());

  // String deep copy: the copy must start with the source's contents, and
  // overwriting every element of the copy must leave every element of the
  // source intact.
  Tensor str1(DT_STRING, TensorShape({2}));
  str1.flat<tstring>()(0) = "foo1";
  str1.flat<tstring>()(1) = "foo2";
  Tensor str2 = tensor::DeepCopy(str1);
  EXPECT_EQ(str1.shape(), str2.shape());
  EXPECT_EQ(DT_STRING, str2.dtype());
  EXPECT_EQ("foo1", str2.flat<tstring>()(0));
  EXPECT_EQ("foo2", str2.flat<tstring>()(1));

  str2.flat<tstring>()(0) = "bar1";
  str2.flat<tstring>()(1) = "bar2";
  EXPECT_EQ("foo1", str1.flat<tstring>()(0));
  EXPECT_EQ("foo2", str1.flat<tstring>()(1));
  EXPECT_EQ("bar1", str2.flat<tstring>()(0));
  EXPECT_EQ("bar2", str2.flat<tstring>()(1));
}
122
TEST(TensorUtil, DeepCopySlice) {
  Tensor x(DT_INT32, TensorShape({10}));
  x.flat<int32>().setConstant(1);

  // y is a slice view of x[2:6] that shares x's buffer; z is a deep copy of
  // that slice taken while every element is still 1.
  Tensor y = x.Slice(2, 6);
  Tensor z = tensor::DeepCopy(y);

  // Overwrite the shared buffer: only x and its view y should observe it.
  x.flat<int32>().setConstant(2);

  EXPECT_EQ(TensorShape({10}), x.shape());
  EXPECT_EQ(TensorShape({4}), y.shape());
  EXPECT_EQ(TensorShape({4}), z.shape());
  for (const Tensor* t : {&x, &y, &z}) {
    EXPECT_EQ(DT_INT32, t->dtype());
  }

  auto x_flat = x.flat<int32>();
  for (int idx = 0; idx < 10; ++idx) {
    EXPECT_EQ(2, x_flat(idx));
  }
  // Slices may be unaligned, so read y through unaligned_flat.
  auto y_view = y.unaligned_flat<int32>();
  auto z_flat = z.flat<int32>();
  for (int idx = 0; idx < 4; ++idx) {
    EXPECT_EQ(2, y_view(idx));
    EXPECT_EQ(1, z_flat(idx));
  }
}
152
TEST(TensorUtil, DeepCopySliceString) {
  Tensor x(DT_STRING, TensorShape({10}));
  x.flat<tstring>().setConstant("hello");

  // y views x[3:7] without copying; z deep-copies that view while every
  // element still reads "hello".
  Tensor y = x.Slice(3, 7);
  Tensor z = tensor::DeepCopy(y);

  // Rewrite x; y shares its storage, z must not.
  x.flat<tstring>().setConstant("goodbye");

  EXPECT_EQ(TensorShape({10}), x.shape());
  EXPECT_EQ(TensorShape({4}), y.shape());
  EXPECT_EQ(TensorShape({4}), z.shape());
  for (const Tensor* t : {&x, &y, &z}) {
    EXPECT_EQ(DT_STRING, t->dtype());
  }

  auto x_flat = x.flat<tstring>();
  for (int idx = 0; idx < 10; ++idx) {
    EXPECT_EQ("goodbye", x_flat(idx));
  }
  auto y_view = y.unaligned_flat<tstring>();
  auto z_flat = z.flat<tstring>();
  for (int idx = 0; idx < 4; ++idx) {
    EXPECT_EQ("goodbye", y_view(idx));
    EXPECT_EQ("hello", z_flat(idx));
  }
}
182
TEST(TensorUtil, DeepCopySliceVariant) {
  Tensor x(DT_VARIANT, TensorShape({10}));
  x.flat<Variant>().setConstant(Tensor(42.0f));

  // y aliases x[3:7]; z is a deep copy of that slice taken while every
  // element still wraps a float Tensor holding 42.0.
  Tensor y = x.Slice(3, 7);
  Tensor z = tensor::DeepCopy(y);

  // Replace every element of x (and therefore of y) with a string Tensor.
  x.flat<Variant>().setConstant(Tensor("foo"));

  EXPECT_EQ(TensorShape({10}), x.shape());
  EXPECT_EQ(TensorShape({4}), y.shape());
  EXPECT_EQ(TensorShape({4}), z.shape());
  for (const Tensor* t : {&x, &y, &z}) {
    EXPECT_EQ(DT_VARIANT, t->dtype());
  }

  // x and y now hold DT_STRING "foo" tensors; z still holds DT_FLOAT 42.0.
  auto x_flat = x.flat<Variant>();
  for (int idx = 0; idx < 10; ++idx) {
    const Tensor* element = x_flat(idx).get<Tensor>();
    EXPECT_EQ("foo", element->scalar<tstring>()());
  }
  auto y_view = y.unaligned_flat<Variant>();
  auto z_flat = z.flat<Variant>();
  for (int idx = 0; idx < 4; ++idx) {
    const Tensor* from_view = y_view(idx).get<Tensor>();
    const Tensor* from_copy = z_flat(idx).get<Tensor>();
    EXPECT_EQ("foo", from_view->scalar<tstring>()());
    EXPECT_EQ(42.0, from_copy->scalar<float>()());
  }
}
215
TEST(TensorUtil, Concat) {
  std::vector<int64_t> sizes = {1, 4, 5};
  std::vector<Tensor> to_concat;
  to_concat.reserve(sizes.size());

  // Build pieces of shape {size, 2} such that global row r of the
  // concatenation holds {2r, 2r + 1}. Row offsets are tracked as int64_t to
  // match the int64_t sizes (the previous `int offset` narrowed on +=).
  int64_t total_size = 0;
  for (const int64_t size : sizes) {
    Tensor tensor(DT_INT32, TensorShape({size, 2}));
    auto piece = tensor.matrix<int32>();
    for (int64_t row = 0; row < size; ++row) {
      const int64_t global_row = total_size + row;
      for (int j = 0; j < 2; ++j) {
        piece(row, j) = static_cast<int32>(2 * global_row + j);
      }
    }
    to_concat.push_back(std::move(tensor));
    total_size += size;
  }

  Tensor concated;
  TF_ASSERT_OK(tensor::Concat(to_concat, &concated));
  ASSERT_EQ(TensorShape({total_size, 2}), concated.shape());
  auto result = concated.matrix<int32>();
  for (int64_t i = 0; i < total_size; ++i) {
    for (int j = 0; j < 2; ++j) {
      EXPECT_EQ(2 * i + j, result(i, j));
    }
  }
}
243
TEST(TensorUtil, Split) {
  // Row i of the input holds {2i, 2i + 1}.
  Tensor to_split(DT_INT64, TensorShape({10, 2}));
  auto input = to_split.matrix<int64_t>();
  for (int i = 0; i < 10; ++i) {
    for (int j = 0; j < 2; ++j) {
      input(i, j) = 2 * i + j;
    }
  }

  std::vector<int64_t> sizes = {1, 4, 5};
  std::vector<Tensor> splits;
  TF_ASSERT_OK(tensor::Split(to_split, sizes, &splits));
  ASSERT_EQ(sizes.size(), splits.size());

  // Each piece must contain the consecutive input rows starting at `offset`.
  // The offset is int64_t to match the int64_t sizes it accumulates.
  int64_t offset = 0;
  for (size_t entry = 0; entry < splits.size(); ++entry) {
    const int64_t size = sizes[entry];
    const Tensor& split = splits[entry];

    ASSERT_EQ(TensorShape({size, 2}), split.shape());
    auto piece = split.matrix<int64_t>();
    for (int64_t row = 0; row < size; ++row) {
      const int64_t global_row = offset + row;
      for (int j = 0; j < 2; ++j) {
        EXPECT_EQ(2 * global_row + j, piece(row, j));
      }
    }

    offset += size;
  }
}
272
TEST(TensorUtil, ConcatSplitStrings) {
  const int kNumElements = 4 * 3;
  Tensor x(DT_STRING, TensorShape({4, 3}));
  auto x_flat = x.flat<tstring>();
  for (int k = 0; k < kNumElements; ++k) {
    x_flat(k) = strings::StrCat("foo_", k);
  }

  // Split the 4 rows into pieces of 2, 1 and 1 rows, then stitch them back.
  std::vector<Tensor> pieces;
  TF_ASSERT_OK(tensor::Split(x, {2, 1, 1}, &pieces));
  Tensor x_round_tripped;
  TF_ASSERT_OK(tensor::Concat(pieces, &x_round_tripped));
  ASSERT_EQ(x.shape(), x_round_tripped.shape());

  auto round_tripped_flat = x_round_tripped.flat<tstring>();
  for (int k = 0; k < kNumElements; ++k) {
    EXPECT_EQ(x_flat(k), round_tripped_flat(k));
  }

  // Overwriting the round-tripped tensor must not leak back into x, i.e. the
  // two share no memory.
  for (int k = 0; k < kNumElements; ++k) {
    round_tripped_flat(k) = strings::StrCat("bar_", k);
  }
  for (int k = 0; k < kNumElements; ++k) {
    EXPECT_NE(x_flat(k), round_tripped_flat(k));
  }
}
296
TEST(TensorProtoUtil, CreatesStringTensorProto) {
  std::vector<string> values{"a", "b", "c"};
  std::vector<size_t> shape{1, 3};

  auto proto = tensor::CreateTensorProto(values, shape);

  // Fail loudly if the expected text proto does not parse, instead of
  // comparing against an empty or partially filled proto.
  TensorProto expected_tensor_proto;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
      "dtype: DT_STRING\n"
      "tensor_shape {\n"
      "  dim {\n"
      "    size: 1\n"
      "  }\n"
      "  dim {\n"
      "    size: 3\n"
      "  }\n"
      "}\n"
      "string_val: \"a\"\n"
      "string_val: \"b\"\n"
      "string_val: \"c\"\n",
      &expected_tensor_proto));
  EXPECT_EQ(proto.DebugString(), expected_tensor_proto.DebugString());
}
320
TEST(TensorProtoUtil, CreatesInt32TensorProto) {
  std::vector<int32> values{1, 2};
  std::vector<size_t> shape{2};

  auto proto = tensor::CreateTensorProto(values, shape);

  // Fail loudly if the expected text proto does not parse, instead of
  // comparing against an empty or partially filled proto.
  TensorProto expected_tensor_proto;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
      "dtype: DT_INT32\n"
      "tensor_shape {\n"
      "  dim {\n"
      "    size: 2\n"
      "  }\n"
      "}\n"
      "int_val: 1\n"
      "int_val: 2\n",
      &expected_tensor_proto));
  EXPECT_EQ(proto.DebugString(), expected_tensor_proto.DebugString());
}
340
TEST(TensorProtoUtil, CreatesInt64TensorProto) {
  std::vector<int64_t> values{1, 2};
  std::vector<size_t> shape{2};

  auto proto = tensor::CreateTensorProto(values, shape);

  // Fail loudly if the expected text proto does not parse, instead of
  // comparing against an empty or partially filled proto.
  TensorProto expected_tensor_proto;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
      "dtype: DT_INT64\n"
      "tensor_shape {\n"
      "  dim {\n"
      "    size: 2\n"
      "  }\n"
      "}\n"
      "int64_val: 1\n"
      "int64_val: 2\n",
      &expected_tensor_proto));
  EXPECT_EQ(proto.DebugString(), expected_tensor_proto.DebugString());
}
360
TEST(TensorProtoUtil, CreatesUInt32TensorProto) {
  std::vector<uint32> values{1, 2};
  std::vector<size_t> shape{2};

  auto proto = tensor::CreateTensorProto(values, shape);

  // Fail loudly if the expected text proto does not parse, instead of
  // comparing against an empty or partially filled proto.
  TensorProto expected_tensor_proto;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
      "dtype: DT_UINT32\n"
      "tensor_shape {\n"
      "  dim {\n"
      "    size: 2\n"
      "  }\n"
      "}\n"
      "uint32_val: 1\n"
      "uint32_val: 2\n",
      &expected_tensor_proto));
  EXPECT_EQ(proto.DebugString(), expected_tensor_proto.DebugString());
}
380
TEST(TensorProtoUtil, CreatesUInt64TensorProto) {
  std::vector<uint64> values{1, 2};
  std::vector<size_t> shape{2};

  auto proto = tensor::CreateTensorProto(values, shape);

  // Fail loudly if the expected text proto does not parse, instead of
  // comparing against an empty or partially filled proto.
  TensorProto expected_tensor_proto;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
      "dtype: DT_UINT64\n"
      "tensor_shape {\n"
      "  dim {\n"
      "    size: 2\n"
      "  }\n"
      "}\n"
      "uint64_val: 1\n"
      "uint64_val: 2\n",
      &expected_tensor_proto));
  EXPECT_EQ(proto.DebugString(), expected_tensor_proto.DebugString());
}
400
TEST(TensorProtoUtil, CreatesFloatTensorProto) {
  // Use float literals so the values are not silently converted from double.
  std::vector<float> values{1.1f, 2.2f};
  std::vector<size_t> shape{2};

  auto proto = tensor::CreateTensorProto(values, shape);

  // Fail loudly if the expected text proto does not parse, instead of
  // comparing against an empty or partially filled proto.
  TensorProto expected_tensor_proto;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
      "dtype: DT_FLOAT\n"
      "tensor_shape {\n"
      "  dim {\n"
      "    size: 2\n"
      "  }\n"
      "}\n"
      "float_val: 1.1\n"
      "float_val: 2.2\n",
      &expected_tensor_proto));
  EXPECT_EQ(proto.DebugString(), expected_tensor_proto.DebugString());
}
420
TEST(TensorProtoUtil, CreatesDoubleTensorProto) {
  std::vector<double> values{1.1, 2.2};
  std::vector<size_t> shape{2};

  auto proto = tensor::CreateTensorProto(values, shape);

  // Fail loudly if the expected text proto does not parse, instead of
  // comparing against an empty or partially filled proto.
  TensorProto expected_tensor_proto;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
      "dtype: DT_DOUBLE\n"
      "tensor_shape {\n"
      "  dim {\n"
      "    size: 2\n"
      "  }\n"
      "}\n"
      "double_val: 1.1\n"
      "double_val: 2.2\n",
      &expected_tensor_proto));
  EXPECT_EQ(proto.DebugString(), expected_tensor_proto.DebugString());
}
440
TEST(TensorProtoUtil, CreatesBoolTensorProto) {
  std::vector<bool> values{true, false};
  std::vector<size_t> shape{2};

  auto proto = tensor::CreateTensorProto(values, shape);

  // Fail loudly if the expected text proto does not parse, instead of
  // comparing against an empty or partially filled proto.
  TensorProto expected_tensor_proto;
  ASSERT_TRUE(protobuf::TextFormat::ParseFromString(
      "dtype: DT_BOOL\n"
      "tensor_shape {\n"
      "  dim {\n"
      "    size: 2\n"
      "  }\n"
      "}\n"
      "bool_val: true\n"
      "bool_val: false\n",
      &expected_tensor_proto));
  EXPECT_EQ(proto.DebugString(), expected_tensor_proto.DebugString());
}
460
TEST(TensorProtoUtil, CompressTensorProtoInPlaceTooSmall) {
  // 63 zero-valued elements: presumably below the default minimum size for
  // compression (the 64-element all-equal test does compress), so the
  // default-argument overload must decline for every dtype.
  const int kLength = 63;
  auto expect_not_compressed = [&](auto zero) {
    using T = decltype(zero);
    TensorProto tensor_proto = tensor::CreateTensorProto(
        std::vector<T>(kLength), {static_cast<size_t>(kLength)});
    EXPECT_FALSE(tensor::CompressTensorProtoInPlace(&tensor_proto));
  };
  expect_not_compressed(float{});
  expect_not_compressed(int{});
  expect_not_compressed(uint8{});
  expect_not_compressed(bool{});
  expect_not_compressed(Eigen::half{});
  expect_not_compressed(std::complex<float>{});
}
482
TEST(TensorProtoUtil, CompressTensorProtoInPlaceAllEqual) {
  // 64 identical (zero) elements must compress for every dtype, leaving no
  // explicit values in the dtype-specific repeated field.
  const int kLength = 64;
  auto expect_fully_compressed = [&](auto zero) {
    using T = decltype(zero);
    TensorProto tensor_proto = tensor::CreateTensorProto(
        std::vector<T>(kLength), {static_cast<size_t>(kLength)});
    EXPECT_TRUE(tensor::CompressTensorProtoInPlace(&tensor_proto));
    EXPECT_EQ(tensor::internal::TensorProtoHelper<T>::NumValues(tensor_proto),
              0);
  };
  expect_fully_compressed(float{});
  expect_fully_compressed(int{});
  expect_fully_compressed(uint8{});
  expect_fully_compressed(bool{});
  expect_fully_compressed(Eigen::half{});
  expect_fully_compressed(std::complex<float>{});
}
522
523 template <typename T>
VectorWithConstantTail(int size,int tail_length,std::vector<T> * v)524 void VectorWithConstantTail(int size, int tail_length, std::vector<T>* v) {
525 CHECK_LE(tail_length, size);
526 v->clear();
527 for (int i = 0; i < size; ++i) {
528 T vi = (i >= size - tail_length) ? T() : T(i);
529 v->push_back(vi);
530 }
531 }
532
533 template <>
VectorWithConstantTail(int size,int tail_length,std::vector<std::complex<float>> * v)534 void VectorWithConstantTail(int size, int tail_length,
535 std::vector<std::complex<float>>* v) {
536 CHECK_LE(tail_length, size);
537 v->clear();
538 for (int i = 0; i < size; ++i) {
539 std::complex<float> vi(
540 0.0f, (i >= (size - tail_length)) ? 0.f : static_cast<float>(i));
541 v->push_back(vi);
542 }
543 }
544
545 template <typename T>
CreateAsProtoTensorContent(int size,int tail_length)546 TensorProto CreateAsProtoTensorContent(int size, int tail_length) {
547 std::vector<T> values;
548 VectorWithConstantTail<T>(size, tail_length, &values);
549 Tensor tensor(DataTypeToEnum<T>::value, TensorShape({size}));
550 std::copy(values.begin(), values.end(), tensor.flat<T>().data());
551 TensorProto tensor_proto;
552 tensor.AsProtoTensorContent(&tensor_proto);
553 return tensor_proto;
554 }
555
556 template <typename T>
CreateAsProtoField(int size,int tail_length)557 TensorProto CreateAsProtoField(int size, int tail_length) {
558 std::vector<T> values;
559 VectorWithConstantTail<T>(size, tail_length, &values);
560 Tensor tensor(DataTypeToEnum<T>::value, TensorShape({size}));
561 std::copy(values.begin(), values.end(), tensor.flat<T>().data());
562 TensorProto tensor_proto;
563 tensor.AsProtoField(&tensor_proto);
564 return tensor_proto;
565 }
566
567 template <typename T>
CompareTensorValues(const TensorProto & x,const TensorProto & y)568 void CompareTensorValues(const TensorProto& x, const TensorProto& y) {
569 Tensor x_t;
570 EXPECT_TRUE(x_t.FromProto(x));
571 Tensor y_t;
572 EXPECT_TRUE(y_t.FromProto(y));
573 test::ExpectTensorEqual<T>(x_t, y_t);
574 }
575
576 template <typename T>
ConstantTailTest(int64_t length,int64_t tail_length,bool as_field)577 void ConstantTailTest(int64_t length, int64_t tail_length, bool as_field) {
578 using TensorProtoHelper = tensor::internal::TensorProtoHelper<T>;
579 using FieldType = typename TensorProtoHelper::FieldType;
580 const float kMinCompressionRatio = 2.0;
581 const int64_t kMinSize = 64;
582 TensorProto tensor_proto =
583 as_field ? CreateAsProtoField<T>(length, tail_length)
584 : CreateAsProtoTensorContent<T>(length, tail_length);
585 TensorProto original_tensor_proto = tensor_proto;
586 int64_t original_size =
587 length * (as_field ? (is_complex<T>::value ? 2 : 1) * sizeof(FieldType)
588 : sizeof(T));
589 int64_t size_as_tensor_content = length * sizeof(T);
590 int64_t size_as_field = std::min(length, (length - tail_length + 1)) *
591 (is_complex<T>::value ? 2 : 1) * sizeof(FieldType);
592 bool will_compress =
593 std::min(size_as_tensor_content, size_as_field) <=
594 static_cast<int64_t>(original_size / kMinCompressionRatio);
595
596 EXPECT_EQ(tensor::CompressTensorProtoInPlace(kMinSize, kMinCompressionRatio,
597 &tensor_proto),
598 will_compress);
599 if (will_compress) {
600 if (size_as_tensor_content < size_as_field) {
601 EXPECT_EQ(TensorProtoHelper::NumValues(tensor_proto), 0);
602 EXPECT_FALSE(tensor_proto.tensor_content().empty());
603 } else {
604 EXPECT_LE(TensorProtoHelper::NumValues(tensor_proto),
605 (length - tail_length + 1));
606 EXPECT_TRUE(tensor_proto.tensor_content().empty());
607 }
608 }
609 CompareTensorValues<T>(tensor_proto, original_tensor_proto);
610 }
611
TEST(TensorProtoUtil, CompressTensorProtoConstantTail) {
  const int kLength = 64;
  // Tail lengths range from no tail, through tiny and roughly-half tails, to
  // a tail spanning the whole vector.
  const int kTailLengths[] = {0, 1, 2, 32, 33, 63, 64};
  for (const bool as_field : {true, false}) {
    for (const int tail : kTailLengths) {
      ConstantTailTest<float>(kLength, tail, as_field);
      ConstantTailTest<double>(kLength, tail, as_field);
      ConstantTailTest<complex64>(kLength, tail, as_field);
      ConstantTailTest<complex128>(kLength, tail, as_field);
      ConstantTailTest<int32>(kLength, tail, as_field);
      ConstantTailTest<uint32>(kLength, tail, as_field);
      ConstantTailTest<int64_t>(kLength, tail, as_field);
      ConstantTailTest<uint64>(kLength, tail, as_field);
      ConstantTailTest<int8>(kLength, tail, as_field);
      ConstantTailTest<uint8>(kLength, tail, as_field);
      ConstantTailTest<int16>(kLength, tail, as_field);
      ConstantTailTest<uint16>(kLength, tail, as_field);
      ConstantTailTest<Eigen::half>(kLength, tail, as_field);
      ConstantTailTest<bfloat16>(kLength, tail, as_field);
    }
  }
}
633
TEST(TensorProtoUtil, CompressTensorProtoNegatizeZero) {
  // Compression must preserve the sign of negative zero. Each case first
  // verifies the sign survived AsProtoField, so a sign loss is attributed to
  // the right step (== alone cannot tell -0.0 from +0.0).
  TensorProto tensor_proto;
  {
    // Double
    Tensor tensor(-0.0);
    tensor.AsProtoField(&tensor_proto);
    ASSERT_EQ(tensor_proto.double_val(0), -0.0);
    ASSERT_TRUE(std::signbit(tensor_proto.double_val(0)));
    tensor::CompressTensorProtoInPlace(1, 1.0, &tensor_proto);
    ASSERT_EQ(tensor_proto.double_val(0), -0.0);
    ASSERT_TRUE(std::signbit(tensor_proto.double_val(0)));
  }
  {
    // Float
    Tensor tensor(-0.0f);
    tensor.AsProtoField(&tensor_proto);
    ASSERT_EQ(tensor_proto.float_val(0), -0.0f);
    ASSERT_TRUE(std::signbit(tensor_proto.float_val(0)));
    tensor::CompressTensorProtoInPlace(1, 1.0, &tensor_proto);
    ASSERT_EQ(tensor_proto.float_val(0), -0.0f);
    ASSERT_TRUE(std::signbit(tensor_proto.float_val(0)));
  }
  {
    // Half: -0.0 is the bit pattern 0x8000.
    Tensor tensor(Eigen::half(-0.0f));
    tensor.AsProtoField(&tensor_proto);
    ASSERT_EQ(tensor_proto.half_val(0), 0x8000);
    tensor::CompressTensorProtoInPlace(1, 1.0, &tensor_proto);
    ASSERT_TRUE(tensor.FromProto(tensor_proto));
    ASSERT_EQ(tensor.scalar<Eigen::half>()(), static_cast<Eigen::half>(0.0f));
    ASSERT_TRUE(
        std::signbit(static_cast<float>(tensor.scalar<Eigen::half>()())));
  }
  {
    // Double Complex -0.0, -0.0
    Tensor tensor(std::complex<double>(-0.0, -0.0));
    tensor.AsProtoField(&tensor_proto);
    ASSERT_EQ(tensor_proto.dcomplex_val(0), -0.0);
    ASSERT_TRUE(std::signbit(tensor_proto.dcomplex_val(0)));
    ASSERT_EQ(tensor_proto.dcomplex_val(1), -0.0);
    ASSERT_TRUE(std::signbit(tensor_proto.dcomplex_val(1)));
    tensor::CompressTensorProtoInPlace(1, 1.0, &tensor_proto);
    ASSERT_TRUE(tensor.FromProto(tensor_proto));
    auto value = tensor.scalar<std::complex<double>>()();
    ASSERT_EQ(value.real(), -0.0);
    ASSERT_TRUE(std::signbit(value.real()));
    ASSERT_EQ(value.imag(), -0.0);
    ASSERT_TRUE(std::signbit(value.imag()));
  }
  {
    // Double Complex 0.0, -0.0
    Tensor tensor(std::complex<double>(0.0, -0.0));
    tensor.AsProtoField(&tensor_proto);
    ASSERT_EQ(tensor_proto.dcomplex_val(0), 0.0);
    ASSERT_FALSE(std::signbit(tensor_proto.dcomplex_val(0)));
    ASSERT_EQ(tensor_proto.dcomplex_val(1), -0.0);
    ASSERT_TRUE(std::signbit(tensor_proto.dcomplex_val(1)));
    tensor::CompressTensorProtoInPlace(1, 1.0, &tensor_proto);
    ASSERT_TRUE(tensor.FromProto(tensor_proto));
    auto value = tensor.scalar<std::complex<double>>()();
    ASSERT_EQ(value.real(), 0.0);
    ASSERT_FALSE(std::signbit(value.real()));
    ASSERT_EQ(value.imag(), -0.0);
    ASSERT_TRUE(std::signbit(value.imag()));
  }
  {
    // Double Complex -0.0, 0.0
    Tensor tensor(std::complex<double>(-0.0, 0.0));
    tensor.AsProtoField(&tensor_proto);
    ASSERT_EQ(tensor_proto.dcomplex_val(0), -0.0);
    ASSERT_TRUE(std::signbit(tensor_proto.dcomplex_val(0)));
    ASSERT_EQ(tensor_proto.dcomplex_val(1), 0.0);
    ASSERT_FALSE(std::signbit(tensor_proto.dcomplex_val(1)));
    tensor::CompressTensorProtoInPlace(1, 1.0, &tensor_proto);
    ASSERT_TRUE(tensor.FromProto(tensor_proto));
    auto value = tensor.scalar<std::complex<double>>()();
    ASSERT_EQ(value.real(), -0.0);
    ASSERT_TRUE(std::signbit(value.real()));
    ASSERT_EQ(value.imag(), 0.0);
    ASSERT_FALSE(std::signbit(value.imag()));
  }
}
710
711 } // namespace
712 } // namespace tensorflow
713