/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/tensor_util.h"

#include <cmath>
#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace tensor {
Tensor DeepCopy(const Tensor& other) {
  Tensor tmp = Tensor(other.dtype(), other.shape());
  DeepCopy(other, &tmp);
  return tmp;
}

void DeepCopy(const Tensor& input, Tensor* output) {
  if (DataTypeCanUseMemcpy(input.dtype())) {
    if (input.NumElements() > 0) {
      StringPiece input_data = input.tensor_data();

      // We use StringPiece as a convenient map over the tensor buffer,
      // but we cast the type to get to the underlying buffer to do the
      // copy.
      StringPiece output_data = output->tensor_data();
      memcpy(const_cast<char*>(output_data.data()), input_data.data(),
             input_data.size());
    }
  } else if (input.dtype() == DT_STRING) {
    output->unaligned_flat<tstring>() = input.unaligned_flat<tstring>();
  } else {
    CHECK_EQ(DT_VARIANT, input.dtype());
    output->unaligned_flat<Variant>() = input.unaligned_flat<Variant>();
  }
}
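
// Example (illustrative sketch, assuming a caller that already holds a
// Tensor): DeepCopy yields a tensor backed by its own buffer, so later
// writes to the source do not show through:
//
//   Tensor src(DT_FLOAT, TensorShape({2, 2}));
//   src.flat<float>().setConstant(1.0f);
//   Tensor dst = DeepCopy(src);
//   src.flat<float>().setConstant(2.0f);  // dst still holds all 1.0f.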

Status Concat(const gtl::ArraySlice<Tensor>& tensors, Tensor* result) {
  if (tensors.empty()) {
    return errors::InvalidArgument("Cannot concatenate zero tensors");
  }
  int64_t total_dim0_size = 0;
  for (const Tensor& tensor : tensors) {
    if (tensor.dims() == 0) {
      return errors::InvalidArgument(
          "Cannot concatenate a zero-dimensional tensor");
    }
    total_dim0_size += tensor.dim_size(0);
  }
  TensorShape shape = tensors[0].shape();
  shape.set_dim(0, total_dim0_size);

  const DataType dtype = tensors[0].dtype();
  for (int i = 1; i < tensors.size(); ++i) {
    if (tensors[i].dtype() != dtype) {
      return errors::InvalidArgument(
          "Cannot concatenate tensors that have different data types. Got ",
          DataTypeString(dtype), " and ", DataTypeString(tensors[i].dtype()),
          ".");
    }
  }
  *result = Tensor(dtype, shape);

  // We use StringPiece as a convenient map over the tensor buffer,
  // but we cast the type to get to the underlying buffer to do the
  // copy.
  StringPiece to_data = result->tensor_data();

  if (DataTypeCanUseMemcpy(dtype)) {
    int64_t offset = 0;
    for (const Tensor& tensor : tensors) {
      StringPiece from_data = tensor.tensor_data();
      CHECK_LE(offset + from_data.size(), to_data.size());
      memcpy(const_cast<char*>(to_data.data()) + offset, from_data.data(),
             from_data.size());

      offset += from_data.size();
    }
  } else {
    if (dtype != DT_STRING) {
      return errors::Internal("Unexpected data type");
    }
    tstring* to_strings =
        reinterpret_cast<tstring*>(const_cast<char*>(to_data.data()));

    int64_t offset = 0;
    for (const Tensor& tensor : tensors) {
      auto from_strings = tensor.flat<tstring>();
      CHECK_LE(offset + tensor.NumElements(), result->NumElements());
      for (int i = 0; i < tensor.NumElements(); ++i) {
        to_strings[offset + i] = from_strings(i);
      }

      offset += tensor.NumElements();
    }
  }

  return OkStatus();
}
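
// Example (illustrative sketch): concatenating two {2, 3} tensors along the
// zeroth dimension produces a {4, 3} tensor of the same dtype:
//
//   Tensor a(DT_INT32, TensorShape({2, 3}));
//   Tensor b(DT_INT32, TensorShape({2, 3}));
//   Tensor joined;
//   TF_CHECK_OK(Concat({a, b}, &joined));  // joined.shape() == [4, 3]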

Status Split(const Tensor& tensor, const gtl::ArraySlice<int64_t>& sizes,
             std::vector<Tensor>* result) {
  if (tensor.dims() == 0) {
    return errors::InvalidArgument("Cannot split a zero-dimensional tensor");
  }
  int64_t total_size = 0;
  for (int64_t size : sizes) {
    total_size += size;
  }
  if (total_size != tensor.dim_size(0)) {
    return errors::InvalidArgument(
        "The values in 'sizes' do not sum to the zeroth-dimension size of "
        "'tensor'");
  }

  StringPiece from_data = tensor.tensor_data();

  if (DataTypeCanUseMemcpy(tensor.dtype())) {
    int64_t offset = 0;
    for (int64_t size : sizes) {
      TensorShape shape = tensor.shape();
      shape.set_dim(0, size);
      result->emplace_back(tensor.dtype(), shape);
      Tensor* split = &(*result)[result->size() - 1];

      // We use StringPiece as a convenient map over the tensor buffer,
      // but we cast the type to get to the underlying buffer to do the
      // copy.
      StringPiece to_data = split->tensor_data();
      CHECK_LE(offset + to_data.size(), from_data.size());
      memcpy(const_cast<char*>(to_data.data()), from_data.data() + offset,
             to_data.size());

      offset += to_data.size();
    }
  } else {
    if (tensor.dtype() != DT_STRING) {
      return errors::Internal("Unexpected data type");
    }
    auto from_strings = tensor.flat<tstring>();

    int64_t offset = 0;
    for (int64_t size : sizes) {
      TensorShape shape = tensor.shape();
      shape.set_dim(0, size);
      result->emplace_back(tensor.dtype(), shape);
      Tensor& split = (*result)[result->size() - 1];
      tstring* to_strings = reinterpret_cast<tstring*>(
          const_cast<char*>(split.tensor_data().data()));

      CHECK_LE(offset + split.NumElements(), tensor.NumElements());
      for (int i = 0; i < split.NumElements(); ++i) {
        to_strings[i] = from_strings(offset + i);
      }

      offset += split.NumElements();
    }
  }

  return OkStatus();
}
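
// Example (illustrative sketch): splitting a {4, 3} tensor with sizes {1, 3}
// produces two tensors of shape {1, 3} and {3, 3}; the sizes must sum to the
// zeroth dimension:
//
//   Tensor big(DT_INT32, TensorShape({4, 3}));
//   std::vector<Tensor> pieces;
//   TF_CHECK_OK(Split(big, {1, 3}, &pieces));  // shapes: [1, 3] and [3, 3]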

namespace internal {
void SetTensorProtoShape(std::vector<size_t> shape,
                         TensorShapeProto* shape_proto) {
  for (auto dim : shape) {
    shape_proto->mutable_dim()->Add()->set_size(dim);
  }
}

template <typename T>
bool CompressTensorContent(float min_compression_ratio,
                           const TensorShape& shape, TensorProto* tensor) {
  using TypeHelper = internal::TensorProtoHelper<T>;
  using FieldType = typename internal::TensorProtoHelper<T>::FieldType;
  const int64_t num_tensor_values = shape.num_elements();
  const int64_t num_bytes = tensor->tensor_content().size();
  const int64_t num_raw_values = num_bytes / sizeof(T);
  if (num_raw_values != num_tensor_values) {
    // Invalid or too small.
    return false;
  }
  int64_t last_offset = num_bytes - 1;
  int64_t prev_offset = last_offset - sizeof(T);
  // Inspect individual raw bytes sizeof(T) bytes apart in adjacent elements,
  // starting from the end, to find the last pair of elements that are not
  // identical.
  while (prev_offset >= 0) {
    if (tensor->tensor_content()[prev_offset] !=
        tensor->tensor_content()[last_offset]) {
      break;
    }
    --last_offset;
    --prev_offset;
  }
  if (prev_offset == -1) {
    // If this is a splat of value 0, it does not need an explicit value, so
    // just erase the content.
    T splat_value;
    port::CopySubrangeToArray(tensor->tensor_content(), 0, sizeof(T),
                              reinterpret_cast<char*>(&splat_value));
    if (splat_value == T(0)) {
      tensor->clear_tensor_content();
      return true;
    }
  }
  // Round up to the next whole number of elements of type T.
  const int64_t new_num_values = last_offset / sizeof(T) + 1;
  if (new_num_values * (is_complex<T>::value ? 2 : 1) * sizeof(FieldType) >
      static_cast<int64_t>(num_bytes / min_compression_ratio)) {
    return false;
  }
  // Copy values to truncated repeated field.
  if (sizeof(FieldType) == sizeof(T)) {
    FieldType* dst_ptr =
        TypeHelper::AppendUninitialized(new_num_values, tensor);
    port::CopySubrangeToArray(tensor->tensor_content(), 0,
                              new_num_values * sizeof(T),
                              reinterpret_cast<char*>(dst_ptr));
    tensor->clear_tensor_content();
  } else if (sizeof(T) > 1) {
    // Copy raw bytes to temp array first, then cast.
    gtl::InlinedVector<T, 64> tmp(new_num_values);
    port::CopySubrangeToArray(tensor->tensor_content(), 0,
                              new_num_values * sizeof(T),
                              reinterpret_cast<char*>(tmp.data()));
    tensor->clear_tensor_content();
    const T* begin = tmp.begin();
    const T* end = tmp.end();
    TypeHelper::AddValues(begin, end, tensor);
  } else {
    // Copy and cast, one byte at a time.
    for (int64_t i = 0; i < new_num_values; ++i) {
      char c = tensor->tensor_content()[i];
      TypeHelper::AddValue(static_cast<T>(c), tensor);
    }
    tensor->clear_tensor_content();
  }
  return true;
}
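
// For example (illustrative, assuming the usual TensorProto splat semantics
// where missing trailing values repeat the last stored one): a DT_FLOAT
// tensor_content ending in a long run of identical floats keeps only the
// prefix up to the first element of that run, moved into the repeated
// float_val field, provided the result is at least min_compression_ratio
// times smaller than the original bytes.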

template <typename T>
inline bool PackedValuesNotEqual(T a, T b) {
  return a != b;
}
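
// The specializations below compare raw bit patterns rather than values:
// identical NaN payloads then compare equal (so a NaN splat can still be
// truncated), and -0.0 stays distinct from +0.0.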
template <>
inline bool PackedValuesNotEqual(float a, float b) {
  return reinterpret_cast<int32_t&>(a) != reinterpret_cast<int32_t&>(b);
}
template <>
inline bool PackedValuesNotEqual(double a, double b) {
  return reinterpret_cast<int64_t&>(a) != reinterpret_cast<int64_t&>(b);
}
template <typename RealType>
inline bool PackedValuesNotEqual(const std::complex<RealType>& a,
                                 const std::complex<RealType>& b) {
  return PackedValuesNotEqual(a.real(), b.real()) ||
         PackedValuesNotEqual(a.imag(), b.imag());
}

// Integers can't be negative zero.
template <typename T,
          typename std::enable_if<std::is_integral<T>::value>::type* = nullptr>
static bool IsNegativeZero(T value) {
  return false;
}

template <typename T,
          typename std::enable_if<!std::is_integral<T>::value>::type* = nullptr>
static bool IsNegativeZero(T value) {
  return value == T(0) && std::signbit(value);
}

template <typename T>
static bool IsNegativeZero(std::complex<T> value) {
  return IsNegativeZero(value.real()) || IsNegativeZero(value.imag());
}

static bool IsNegativeZero(Eigen::QUInt8 value) { return false; }
static bool IsNegativeZero(Eigen::QInt8 value) { return false; }
static bool IsNegativeZero(Eigen::QUInt16 value) { return false; }
static bool IsNegativeZero(Eigen::QInt16 value) { return false; }
static bool IsNegativeZero(Eigen::QInt32 value) { return false; }
static bool IsNegativeZero(Eigen::half value) {
  return IsNegativeZero<float>(static_cast<float>(value));
}
static bool IsNegativeZero(tensorflow::bfloat16 value) {
  return IsNegativeZero<float>(static_cast<float>(value));
}

template <typename T>
bool CompressRepeatedField(float min_compression_ratio,
                           const TensorShape& shape, TensorProto* tensor) {
  using TypeHelper = internal::TensorProtoHelper<T>;
  using FieldType = typename internal::TensorProtoHelper<T>::FieldType;
  const int64_t num_tensor_values = shape.num_elements();
  const int64_t num_proto_values = TypeHelper::NumValues(*tensor);

  // Notice that for complex types the tensor is stored as an array of up to
  // 2 * num_tensor_values real values (real and imaginary parts), possibly
  // truncated. A 0-splat does not need any value present and is maximally
  // compressed.
  if (num_proto_values == 0) return false;

  const T last_value = TypeHelper::GetValue(num_proto_values - 1, *tensor);
  int64_t last_index = 0;
  for (int64_t i = num_proto_values - 2; i >= 0 && last_index == 0; --i) {
    const T cur_value = TypeHelper::GetValue(i, *tensor);
    if (PackedValuesNotEqual(cur_value, last_value)) {
      last_index = i + 1;
    }
  }

  // Detect all-zero tensors: zero is the default value, so the content can
  // be erased entirely.
  if (last_index == 0 && last_value == T(0) && !IsNegativeZero(last_value)) {
    TypeHelper::Truncate(0, tensor);
    return true;
  }

  const int64_t num_truncated_proto_values = last_index + 1;
  const int64_t num_bytes_as_field =
      num_truncated_proto_values * sizeof(FieldType);
  const int64_t num_bytes_as_tensor_content = num_tensor_values * sizeof(T);
  const int64_t num_bytes_before = num_proto_values * sizeof(FieldType);
  if (std::min(num_bytes_as_field, num_bytes_as_tensor_content) >
      static_cast<int64_t>(num_bytes_before / min_compression_ratio)) {
    return false;
  }
  if (num_bytes_as_field <= num_bytes_as_tensor_content) {
    TypeHelper::Truncate(num_truncated_proto_values, tensor);
  } else {
    gtl::InlinedVector<T, 64> tmp;
    if (num_proto_values == 1) {
      // Splat case.
      tmp.resize(num_tensor_values, last_value);
    } else {
      tmp.resize(num_tensor_values, T(0));
      TypeHelper::CopyValues(tmp.begin(), *tensor);
    }
    TypeHelper::Truncate(0, tensor);
    port::CopyFromArray(tensor->mutable_tensor_content(),
                        reinterpret_cast<const char*>(tmp.data()),
                        num_bytes_as_tensor_content);
  }
  return true;
}

template <typename T>
bool CompressTensorProtoInPlaceImpl(int64_t min_num_elements,
                                    float min_compression_ratio,
                                    TensorProto* tensor) {
  const TensorShape shape(tensor->tensor_shape());
  const int64_t num_tensor_values = shape.num_elements();
  if (num_tensor_values < min_num_elements) {
    return false;
  }
  if (tensor->tensor_content().empty()) {
    return CompressRepeatedField<T>(min_compression_ratio, shape, tensor);
  } else {
    return CompressTensorContent<T>(min_compression_ratio, shape, tensor);
  }
  return true;
}

}  // namespace internal

#define HANDLE_COMPRESS_CASE(TF_TYPE)                                  \
  case TF_TYPE:                                                        \
    return internal::CompressTensorProtoInPlaceImpl<                   \
        EnumToDataType<TF_TYPE>::Type>(min_num_elements,               \
                                       min_compression_ratio, tensor); \
    break

bool CompressTensorProtoInPlace(int64_t min_num_elements,
                                float min_compression_ratio,
                                TensorProto* tensor) {
  switch (tensor->dtype()) {
    HANDLE_COMPRESS_CASE(DT_FLOAT);
    HANDLE_COMPRESS_CASE(DT_DOUBLE);
    HANDLE_COMPRESS_CASE(DT_COMPLEX64);
    HANDLE_COMPRESS_CASE(DT_COMPLEX128);
    HANDLE_COMPRESS_CASE(DT_UINT8);
    HANDLE_COMPRESS_CASE(DT_INT8);
    HANDLE_COMPRESS_CASE(DT_UINT16);
    HANDLE_COMPRESS_CASE(DT_INT16);
    HANDLE_COMPRESS_CASE(DT_UINT32);
    HANDLE_COMPRESS_CASE(DT_INT32);
    HANDLE_COMPRESS_CASE(DT_UINT64);
    HANDLE_COMPRESS_CASE(DT_INT64);
    HANDLE_COMPRESS_CASE(DT_BOOL);
    HANDLE_COMPRESS_CASE(DT_QUINT8);
    HANDLE_COMPRESS_CASE(DT_QINT8);
    HANDLE_COMPRESS_CASE(DT_QUINT16);
    HANDLE_COMPRESS_CASE(DT_QINT16);
    HANDLE_COMPRESS_CASE(DT_QINT32);
    HANDLE_COMPRESS_CASE(DT_HALF);
    HANDLE_COMPRESS_CASE(DT_BFLOAT16);
    default:
      return false;
  }
}

#undef HANDLE_COMPRESS_CASE
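
// Example (illustrative sketch; some_tensor stands in for a Tensor populated
// elsewhere): only attempt compression on protos with at least 64 elements,
// and keep the result only if it is at least 2x smaller:
//
//   TensorProto proto;
//   some_tensor.AsProtoTensorContent(&proto);
//   CompressTensorProtoInPlace(/*min_num_elements=*/64,
//                              /*min_compression_ratio=*/2.0f, &proto);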

Status MakeShape(const Tensor& shape, TensorShape* out) {
  if (!TensorShapeUtils::IsVector(shape.shape())) {
    return errors::InvalidArgument(
        "shape must be a vector of {int32,int64}, got shape ",
        shape.shape().DebugString());
  }
  if (shape.dtype() == DataType::DT_INT32) {
    auto vec = shape.flat<int32>();
    return TensorShapeUtils::MakeShape(vec.data(), vec.size(), out);
  } else if (shape.dtype() == DataType::DT_INT64) {
    auto vec = shape.flat<int64_t>();
    return TensorShapeUtils::MakeShape(vec.data(), vec.size(), out);
  } else {
    return errors::InvalidArgument("shape must be a vector of {int32,int64}.");
  }
}
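
// Example (illustrative sketch): turning a rank-1 int32 tensor holding
// {2, 3} into a TensorShape:
//
//   Tensor dims(DT_INT32, TensorShape({2}));
//   dims.vec<int32>()(0) = 2;
//   dims.vec<int32>()(1) = 3;
//   TensorShape out;
//   TF_CHECK_OK(MakeShape(dims, &out));  // out is [2, 3]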

}  // namespace tensor
}  // namespace tensorflow