/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/tensor_util.h"

#include <cmath>
#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace tensor {

Tensor DeepCopy(const Tensor& other) {
  Tensor tmp = Tensor(other.dtype(), other.shape());
  DeepCopy(other, &tmp);
  return tmp;
}

void DeepCopy(const Tensor& input, Tensor* output) {
  if (DataTypeCanUseMemcpy(input.dtype())) {
    if (input.NumElements() > 0) {
      StringPiece input_data = input.tensor_data();

      // We use StringPiece as a convenient map over the tensor buffer,
      // but we cast away constness to reach the underlying buffer and
      // do the copy.
      StringPiece output_data = output->tensor_data();
      memcpy(const_cast<char*>(output_data.data()), input_data.data(),
             input_data.size());
    }
  } else if (input.dtype() == DT_STRING) {
    output->unaligned_flat<tstring>() = input.unaligned_flat<tstring>();
  } else {
    CHECK_EQ(DT_VARIANT, input.dtype());
    output->unaligned_flat<Variant>() = input.unaligned_flat<Variant>();
  }
}
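
// Illustrative usage sketch (not part of the original file): DeepCopy
// produces a tensor with its own buffer, so writes to the copy do not
// alias the source. The variable names below are hypothetical.
//
//   Tensor a(DT_FLOAT, TensorShape({2, 3}));
//   a.flat<float>().setConstant(1.0f);
//   Tensor b = tensor::DeepCopy(a);   // b owns a separate buffer
//   b.flat<float>()(0) = 42.0f;       // a's element 0 is still 1.0f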

Status Concat(const gtl::ArraySlice<Tensor>& tensors, Tensor* result) {
  if (tensors.empty()) {
    return errors::InvalidArgument("Cannot concatenate zero tensors");
  }
  int64_t total_dim0_size = 0;
  for (const Tensor& tensor : tensors) {
    if (tensor.dims() == 0) {
      return errors::InvalidArgument(
          "Cannot concatenate a zero-dimensional tensor");
    }
    total_dim0_size += tensor.dim_size(0);
  }
  TensorShape shape = tensors[0].shape();
  shape.set_dim(0, total_dim0_size);

  const DataType dtype = tensors[0].dtype();
  for (int i = 1; i < tensors.size(); ++i) {
    if (tensors[i].dtype() != dtype) {
      return errors::InvalidArgument(
          "Cannot concatenate tensors that have different data types.", " Got ",
          DataTypeString(dtype), " and ", DataTypeString(tensors[i].dtype()),
          ".");
    }
  }
  *result = Tensor(dtype, shape);

  // We use StringPiece as a convenient map over the tensor buffer,
  // but we cast away constness to reach the underlying buffer and
  // do the copy.
  StringPiece to_data = result->tensor_data();

  if (DataTypeCanUseMemcpy(dtype)) {
    int64_t offset = 0;
    for (const Tensor& tensor : tensors) {
      StringPiece from_data = tensor.tensor_data();
      CHECK_LE(offset + from_data.size(), to_data.size());
      memcpy(const_cast<char*>(to_data.data()) + offset, from_data.data(),
             from_data.size());

      offset += from_data.size();
    }
  } else {
    if (dtype != DT_STRING) {
      return errors::Internal("Unexpected data type");
    }
    tstring* to_strings =
        reinterpret_cast<tstring*>(const_cast<char*>(to_data.data()));

    int64_t offset = 0;
    for (const Tensor& tensor : tensors) {
      auto from_strings = tensor.flat<tstring>();
      CHECK_LE(offset + tensor.NumElements(), result->NumElements());
      for (int i = 0; i < tensor.NumElements(); ++i) {
        to_strings[offset + i] = from_strings(i);
      }

      offset += tensor.NumElements();
    }
  }

  return OkStatus();
}
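
// Illustrative usage sketch (not part of the original file): Concat stacks
// tensors along dimension 0; inputs are expected to share the dtype and the
// non-zeroth dimensions. The variables below are hypothetical.
//
//   Tensor x(DT_INT32, TensorShape({2, 4}));
//   Tensor y(DT_INT32, TensorShape({3, 4}));
//   Tensor xy;
//   TF_RETURN_IF_ERROR(tensor::Concat({x, y}, &xy));  // xy has shape [5, 4]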

Status Split(const Tensor& tensor, const gtl::ArraySlice<int64_t>& sizes,
             std::vector<Tensor>* result) {
  if (tensor.dims() == 0) {
    return errors::InvalidArgument("Cannot split a zero-dimensional tensor");
  }
  int64_t total_size = 0;
  for (int64_t size : sizes) {
    total_size += size;
  }
  if (total_size != tensor.dim_size(0)) {
    return errors::InvalidArgument(
        "The values in 'sizes' do not sum to the zeroth-dimension size of "
        "'tensor'");
  }

  StringPiece from_data = tensor.tensor_data();

  if (DataTypeCanUseMemcpy(tensor.dtype())) {
    int64_t offset = 0;
    for (int64_t size : sizes) {
      TensorShape shape = tensor.shape();
      shape.set_dim(0, size);
      result->emplace_back(tensor.dtype(), shape);
      Tensor* split = &(*result)[result->size() - 1];

      // We use StringPiece as a convenient map over the tensor buffer,
      // but we cast away constness to reach the underlying buffer and
      // do the copy.
      StringPiece to_data = split->tensor_data();
      CHECK_LE(offset + to_data.size(), from_data.size());
      memcpy(const_cast<char*>(to_data.data()), from_data.data() + offset,
             to_data.size());

      offset += to_data.size();
    }
  } else {
    if (tensor.dtype() != DT_STRING) {
      return errors::Internal("Unexpected data type");
    }
    auto from_strings = tensor.flat<tstring>();

    int64_t offset = 0;
    for (int64_t size : sizes) {
      TensorShape shape = tensor.shape();
      shape.set_dim(0, size);
      result->emplace_back(tensor.dtype(), shape);
      Tensor& split = (*result)[result->size() - 1];
      tstring* to_strings = reinterpret_cast<tstring*>(
          const_cast<char*>(split.tensor_data().data()));

      CHECK_LE(offset + split.NumElements(), tensor.NumElements());
      for (int i = 0; i < split.NumElements(); ++i) {
        to_strings[i] = from_strings(offset + i);
      }

      offset += split.NumElements();
    }
  }

  return OkStatus();
}
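
// Illustrative usage sketch (not part of the original file): Split is the
// inverse of Concat along dimension 0; the entries of 'sizes' must sum to
// dim 0 of the input. The variables below are hypothetical.
//
//   Tensor xy(DT_INT32, TensorShape({5, 4}));
//   std::vector<Tensor> parts;
//   TF_RETURN_IF_ERROR(tensor::Split(xy, {2, 3}, &parts));
//   // parts[0] has shape [2, 4]; parts[1] has shape [3, 4].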

namespace internal {
void SetTensorProtoShape(std::vector<size_t> shape,
                         TensorShapeProto* shape_proto) {
  for (auto dim : shape) {
    shape_proto->mutable_dim()->Add()->set_size(dim);
  }
}

template <typename T>
bool CompressTensorContent(float min_compression_ratio,
                           const TensorShape& shape, TensorProto* tensor) {
  using TypeHelper = internal::TensorProtoHelper<T>;
  using FieldType = typename internal::TensorProtoHelper<T>::FieldType;
  const int64_t num_tensor_values = shape.num_elements();
  const int64_t num_bytes = tensor->tensor_content().size();
  const int64_t num_raw_values = num_bytes / sizeof(T);
  if (num_raw_values != num_tensor_values) {
    // Invalid or too small.
    return false;
  }
  int64_t last_offset = num_bytes - 1;
  int64_t prev_offset = last_offset - sizeof(T);
  // Inspect individual raw bytes sizeof(T) bytes apart in adjacent elements,
  // starting from the end, to find the last pair of elements that are not
  // identical.
  while (prev_offset >= 0) {
    if (tensor->tensor_content()[prev_offset] !=
        tensor->tensor_content()[last_offset]) {
      break;
    }
    --last_offset;
    --prev_offset;
  }
  if (prev_offset == -1) {
    // If this is a splat of value 0, it does not need an explicit value;
    // just erase the content.
    T splat_value;
    port::CopySubrangeToArray(tensor->tensor_content(), 0, sizeof(T),
                              reinterpret_cast<char*>(&splat_value));
    if (splat_value == T(0)) {
      tensor->clear_tensor_content();
      return true;
    }
  }
  // Round up to the next whole number of elements of type T.
  const int64_t new_num_values = last_offset / sizeof(T) + 1;
  if (new_num_values * (is_complex<T>::value ? 2 : 1) * sizeof(FieldType) >
      static_cast<int64_t>(num_bytes / min_compression_ratio)) {
    return false;
  }
  // Copy values to the truncated repeated field.
  if (sizeof(FieldType) == sizeof(T)) {
    FieldType* dst_ptr =
        TypeHelper::AppendUninitialized(new_num_values, tensor);
    port::CopySubrangeToArray(tensor->tensor_content(), 0,
                              new_num_values * sizeof(T),
                              reinterpret_cast<char*>(dst_ptr));
    tensor->clear_tensor_content();
  } else if (sizeof(T) > 1) {
    // Copy raw bytes to a temp array first, then cast.
    gtl::InlinedVector<T, 64> tmp(new_num_values);
    port::CopySubrangeToArray(tensor->tensor_content(), 0,
                              new_num_values * sizeof(T),
                              reinterpret_cast<char*>(tmp.data()));
    tensor->clear_tensor_content();
    const T* begin = tmp.begin();
    const T* end = tmp.end();
    TypeHelper::AddValues(begin, end, tensor);
  } else {
    // Copy and cast, one byte at a time.
    for (int64_t i = 0; i < new_num_values; ++i) {
      char c = tensor->tensor_content()[i];
      TypeHelper::AddValue(static_cast<T>(c), tensor);
    }
    tensor->clear_tensor_content();
  }
  return true;
}
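
// Worked example (illustrative, not from the original file): a DT_FLOAT
// tensor of shape [4] holding {3, 7, 7, 7} in tensor_content (16 raw bytes)
// has a trailing run of identical elements; if the size reduction meets
// min_compression_ratio, it is rewritten as float_val = {3, 7} with
// tensor_content cleared. A pure zero splat {0, 0, 0, 0} is rewritten as an
// empty proto, since zero is the implicit default value when decoding.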

template <typename T>
inline bool PackedValuesNotEqual(T a, T b) {
  return a != b;
}
template <>
inline bool PackedValuesNotEqual(float a, float b) {
  return reinterpret_cast<int32_t&>(a) != reinterpret_cast<int32_t&>(b);
}
template <>
inline bool PackedValuesNotEqual(double a, double b) {
  return reinterpret_cast<int64_t&>(a) != reinterpret_cast<int64_t&>(b);
}
template <typename RealType>
inline bool PackedValuesNotEqual(const std::complex<RealType>& a,
                                 const std::complex<RealType>& b) {
  return PackedValuesNotEqual(a.real(), b.real()) ||
         PackedValuesNotEqual(a.imag(), b.imag());
}
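
// Note (added commentary): the float and double specializations compare bit
// patterns rather than values, so 0.0 and -0.0 count as different (their
// sign bits differ) and a NaN compares equal to a bitwise-identical NaN.
// For example, PackedValuesNotEqual(0.0f, -0.0f) is true even though
// 0.0f != -0.0f evaluates to false.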

// Integers can't be negative zero.
template <typename T,
          typename std::enable_if<std::is_integral<T>::value>::type* = nullptr>
static bool IsNegativeZero(T value) {
  return false;
}

template <typename T,
          typename std::enable_if<!std::is_integral<T>::value>::type* = nullptr>
static bool IsNegativeZero(T value) {
  return value == T(0) && std::signbit(value);
}

template <typename T>
static bool IsNegativeZero(std::complex<T> value) {
  return IsNegativeZero(value.real()) || IsNegativeZero(value.imag());
}

static bool IsNegativeZero(Eigen::QUInt8 value) { return false; }
static bool IsNegativeZero(Eigen::QInt8 value) { return false; }
static bool IsNegativeZero(Eigen::QUInt16 value) { return false; }
static bool IsNegativeZero(Eigen::QInt16 value) { return false; }
static bool IsNegativeZero(Eigen::QInt32 value) { return false; }
static bool IsNegativeZero(Eigen::half value) {
  return IsNegativeZero<float>(value);
}
static bool IsNegativeZero(tensorflow::bfloat16 value) {
  return IsNegativeZero<float>(float(value));
}
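
// Note (added commentary): IsNegativeZero guards the all-zeros shortcut in
// CompressRepeatedField below. A splat of -0.0 cannot be erased to an empty
// proto, because decoding would refill it with the default +0.0 and lose the
// sign bit; integer and quantized types have no negative zero, so their
// overloads always return false.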

template <typename T>
bool CompressRepeatedField(float min_compression_ratio,
                           const TensorShape& shape, TensorProto* tensor) {
  using TypeHelper = internal::TensorProtoHelper<T>;
  using FieldType = typename internal::TensorProtoHelper<T>::FieldType;
  const int64_t num_tensor_values = shape.num_elements();
  const int64_t num_proto_values = TypeHelper::NumValues(*tensor);

  // Notice that for complex types the tensor is stored as an array of up to
  // 2 * num_tensor_values real values (real and imaginary parts), possibly
  // truncated. A 0-splat does not need any value present and is maximally
  // compressed.
  if (num_proto_values == 0) return false;

  const T last_value = TypeHelper::GetValue(num_proto_values - 1, *tensor);
  int64_t last_index = 0;
  for (int64_t i = num_proto_values - 2; i >= 0 && last_index == 0; --i) {
    const T cur_value = TypeHelper::GetValue(i, *tensor);
    if (PackedValuesNotEqual(cur_value, last_value)) {
      last_index = i + 1;
    }
  }

  // Detect all-zeros tensors: zero is the default value, so the content can
  // be erased entirely.
  if (last_index == 0 && last_value == T(0) && !IsNegativeZero(last_value)) {
    TypeHelper::Truncate(0, tensor);
    return true;
  }

  const int64_t num_truncated_proto_values = last_index + 1;
  const int64_t num_bytes_as_field =
      num_truncated_proto_values * sizeof(FieldType);
  const int64_t num_bytes_as_tensor_content = num_tensor_values * sizeof(T);
  const int64_t num_bytes_before = num_proto_values * sizeof(FieldType);
  if (std::min(num_bytes_as_field, num_bytes_as_tensor_content) >
      static_cast<int64_t>(num_bytes_before / min_compression_ratio)) {
    return false;
  }
  if (num_bytes_as_field <= num_bytes_as_tensor_content) {
    TypeHelper::Truncate(num_truncated_proto_values, tensor);
  } else {
    gtl::InlinedVector<T, 64> tmp;
    if (num_proto_values == 1) {
      // Splat case.
      tmp.resize(num_tensor_values, last_value);
    } else {
      tmp.resize(num_tensor_values, T(0));
      TypeHelper::CopyValues(tmp.begin(), *tensor);
    }
    TypeHelper::Truncate(0, tensor);
    port::CopyFromArray(tensor->mutable_tensor_content(),
                        reinterpret_cast<const char*>(tmp.data()),
                        num_bytes_as_tensor_content);
  }
  return true;
}
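
// Worked example (illustrative, not from the original file): a DT_FLOAT
// tensor of shape [1000] stored as 1000 repeated float_val entries of 7.0
// has a trailing run covering the whole field, so it is truncated to a
// single float_val of 7.0; decoding re-expands the splat from that value.
// If instead the truncated repeated field would still be larger than
// re-encoding the full tensor as raw bytes, the values are packed into
// tensor_content rather than truncated.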

template <typename T>
bool CompressTensorProtoInPlaceImpl(int64_t min_num_elements,
                                    float min_compression_ratio,
                                    TensorProto* tensor) {
  const TensorShape shape(tensor->tensor_shape());
  const int64_t num_tensor_values = shape.num_elements();
  if (num_tensor_values < min_num_elements) {
    return false;
  }
  if (tensor->tensor_content().empty()) {
    return CompressRepeatedField<T>(min_compression_ratio, shape, tensor);
  } else {
    return CompressTensorContent<T>(min_compression_ratio, shape, tensor);
  }
  return true;
}

}  // namespace internal

#define HANDLE_COMPRESS_CASE(TF_TYPE)                                  \
  case TF_TYPE:                                                        \
    return internal::CompressTensorProtoInPlaceImpl<                   \
        EnumToDataType<TF_TYPE>::Type>(min_num_elements,               \
                                       min_compression_ratio, tensor); \
    break

bool CompressTensorProtoInPlace(int64_t min_num_elements,
                                float min_compression_ratio,
                                TensorProto* tensor) {
  switch (tensor->dtype()) {
    HANDLE_COMPRESS_CASE(DT_FLOAT);
    HANDLE_COMPRESS_CASE(DT_DOUBLE);
    HANDLE_COMPRESS_CASE(DT_COMPLEX64);
    HANDLE_COMPRESS_CASE(DT_COMPLEX128);
    HANDLE_COMPRESS_CASE(DT_UINT8);
    HANDLE_COMPRESS_CASE(DT_INT8);
    HANDLE_COMPRESS_CASE(DT_UINT16);
    HANDLE_COMPRESS_CASE(DT_INT16);
    HANDLE_COMPRESS_CASE(DT_UINT32);
    HANDLE_COMPRESS_CASE(DT_INT32);
    HANDLE_COMPRESS_CASE(DT_UINT64);
    HANDLE_COMPRESS_CASE(DT_INT64);
    HANDLE_COMPRESS_CASE(DT_BOOL);
    HANDLE_COMPRESS_CASE(DT_QUINT8);
    HANDLE_COMPRESS_CASE(DT_QINT8);
    HANDLE_COMPRESS_CASE(DT_QUINT16);
    HANDLE_COMPRESS_CASE(DT_QINT16);
    HANDLE_COMPRESS_CASE(DT_QINT32);
    HANDLE_COMPRESS_CASE(DT_HALF);
    HANDLE_COMPRESS_CASE(DT_BFLOAT16);
    default:
      return false;
  }
}

#undef HANDLE_COMPRESS_CASE
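
// Illustrative usage sketch (not part of the original file): compress a
// large constant proto before embedding it, e.g. in a GraphDef. The
// threshold values below are hypothetical.
//
//   TensorProto proto;
//   Tensor(DT_FLOAT, TensorShape({1024})).AsProtoField(&proto);
//   // Skip protos under 64 elements; otherwise require a 2x size win.
//   bool compressed = tensor::CompressTensorProtoInPlace(64, 2.0f, &proto);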

Status MakeShape(const Tensor& shape, TensorShape* out) {
  if (!TensorShapeUtils::IsVector(shape.shape())) {
    return errors::InvalidArgument(
        "shape must be a vector of {int32,int64}, got shape ",
        shape.shape().DebugString());
  }
  if (shape.dtype() == DataType::DT_INT32) {
    auto vec = shape.flat<int32>();
    return TensorShapeUtils::MakeShape(vec.data(), vec.size(), out);
  } else if (shape.dtype() == DataType::DT_INT64) {
    auto vec = shape.flat<int64_t>();
    return TensorShapeUtils::MakeShape(vec.data(), vec.size(), out);
  } else {
    return errors::InvalidArgument("shape must be a vector of {int32,int64}.");
  }
}
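
// Illustrative usage sketch (not part of the original file): MakeShape turns
// a rank-1 int32/int64 tensor of dimension sizes into a TensorShape, as ops
// like Reshape do with their 'shape' input. Variables are hypothetical.
//
//   Tensor shape_t(DT_INT32, TensorShape({3}));
//   shape_t.vec<int32>()(0) = 2;
//   shape_t.vec<int32>()(1) = 3;
//   shape_t.vec<int32>()(2) = 4;
//   TensorShape out;
//   TF_RETURN_IF_ERROR(tensor::MakeShape(shape_t, &out));  // out is [2, 3, 4]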

}  // namespace tensor
}  // namespace tensorflow