/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/tensor_util.h"

#include <cmath>
#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace tensor {

Tensor DeepCopy(const Tensor& other) {
  Tensor tmp = Tensor(other.dtype(), other.shape());
  DeepCopy(other, &tmp);
  return tmp;
}

void DeepCopy(const Tensor& input, Tensor* output) {
  if (DataTypeCanUseMemcpy(input.dtype())) {
    if (input.NumElements() > 0) {
      StringPiece input_data = input.tensor_data();

      // We use StringPiece as a convenient map over the tensor buffer,
      // but we cast the type to get to the underlying buffer to do the
      // copy.
      StringPiece output_data = output->tensor_data();
      memcpy(const_cast<char*>(output_data.data()), input_data.data(),
             input_data.size());
    }
  } else if (input.dtype() == DT_STRING) {
    output->unaligned_flat<tstring>() = input.unaligned_flat<tstring>();
  } else {
    CHECK_EQ(DT_VARIANT, input.dtype());
    output->unaligned_flat<Variant>() = input.unaligned_flat<Variant>();
  }
}
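
// Example usage (sketch; the names below are illustrative, not part of the
// library): DeepCopy returns a tensor that owns its own buffer, so writes to
// the copy never alias the source.
//
//   Tensor a(DT_FLOAT, TensorShape({2, 3}));
//   a.flat<float>().setConstant(1.0f);
//   Tensor b = tensor::DeepCopy(a);
//   b.flat<float>()(0) = 5.0f;  // 'a' is unchanged.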

Status Concat(const gtl::ArraySlice<Tensor>& tensors, Tensor* result) {
  if (tensors.empty()) {
    return errors::InvalidArgument("Cannot concatenate zero tensors");
  }
  int64_t total_dim0_size = 0;
  for (const Tensor& tensor : tensors) {
    if (tensor.dims() == 0) {
      return errors::InvalidArgument(
          "Cannot concatenate a zero-dimensional tensor");
    }
    total_dim0_size += tensor.dim_size(0);
  }
  TensorShape shape = tensors[0].shape();
  shape.set_dim(0, total_dim0_size);

  const DataType dtype = tensors[0].dtype();
  for (int i = 1; i < tensors.size(); ++i) {
    if (tensors[i].dtype() != dtype) {
      return errors::InvalidArgument(
          "Cannot concatenate tensors that have different data types.", " Got ",
          DataTypeString(dtype), " and ", DataTypeString(tensors[i].dtype()),
          ".");
    }
  }
  *result = Tensor(dtype, shape);

  // We use StringPiece as a convenient map over the tensor buffer,
  // but we cast the type to get to the underlying buffer to do the
  // copy.
  StringPiece to_data = result->tensor_data();

  if (DataTypeCanUseMemcpy(dtype)) {
    int64_t offset = 0;
    for (const Tensor& tensor : tensors) {
      StringPiece from_data = tensor.tensor_data();
      CHECK_LE(offset + from_data.size(), to_data.size());
      memcpy(const_cast<char*>(to_data.data()) + offset, from_data.data(),
             from_data.size());

      offset += from_data.size();
    }
  } else {
    if (dtype != DT_STRING) {
      return errors::Internal("Unexpected data type");
    }
    tstring* to_strings =
        reinterpret_cast<tstring*>(const_cast<char*>(to_data.data()));

    int64_t offset = 0;
    for (const Tensor& tensor : tensors) {
      auto from_strings = tensor.flat<tstring>();
      CHECK_LE(offset + tensor.NumElements(), result->NumElements());
      for (int i = 0; i < tensor.NumElements(); ++i) {
        to_strings[offset + i] = from_strings(i);
      }

      offset += tensor.NumElements();
    }
  }

  return Status::OK();
}
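
// Example usage (sketch; names are illustrative): Concat joins tensors along
// dimension 0, so a [2, 3] tensor and a [4, 3] tensor produce a [6, 3] tensor.
//
//   Tensor x(DT_FLOAT, TensorShape({2, 3}));
//   Tensor y(DT_FLOAT, TensorShape({4, 3}));
//   Tensor xy;
//   Status s = tensor::Concat({x, y}, &xy);  // xy has shape [6, 3].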

Status Split(const Tensor& tensor, const gtl::ArraySlice<int64>& sizes,
             std::vector<Tensor>* result) {
  if (tensor.dims() == 0) {
    return errors::InvalidArgument("Cannot split a zero-dimensional tensor");
  }
  int64_t total_size = 0;
  for (int64_t size : sizes) {
    total_size += size;
  }
  if (total_size != tensor.dim_size(0)) {
    return errors::InvalidArgument(
        "The values in 'sizes' do not sum to the zeroth-dimension size of "
        "'tensor'");
  }

  StringPiece from_data = tensor.tensor_data();

  if (DataTypeCanUseMemcpy(tensor.dtype())) {
    int64_t offset = 0;
    for (int64_t size : sizes) {
      TensorShape shape = tensor.shape();
      shape.set_dim(0, size);
      result->emplace_back(tensor.dtype(), shape);
      Tensor* split = &(*result)[result->size() - 1];

      // We use StringPiece as a convenient map over the tensor buffer,
      // but we cast the type to get to the underlying buffer to do the
      // copy.
      StringPiece to_data = split->tensor_data();
      CHECK_LE(offset + to_data.size(), from_data.size());
      memcpy(const_cast<char*>(to_data.data()), from_data.data() + offset,
             to_data.size());

      offset += to_data.size();
    }
  } else {
    if (tensor.dtype() != DT_STRING) {
      return errors::Internal("Unexpected data type");
    }
    auto from_strings = tensor.flat<tstring>();

    int64_t offset = 0;
    for (int64_t size : sizes) {
      TensorShape shape = tensor.shape();
      shape.set_dim(0, size);
      result->emplace_back(tensor.dtype(), shape);
      Tensor& split = (*result)[result->size() - 1];
      tstring* to_strings = reinterpret_cast<tstring*>(
          const_cast<char*>(split.tensor_data().data()));

      CHECK_LE(offset + split.NumElements(), tensor.NumElements());
      for (int i = 0; i < split.NumElements(); ++i) {
        to_strings[i] = from_strings(offset + i);
      }

      offset += split.NumElements();
    }
  }

  return Status::OK();
}
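
// Example usage (sketch; names are illustrative): Split is the inverse of
// Concat. Splitting a [6, 3] tensor with sizes {2, 4} yields a [2, 3] tensor
// followed by a [4, 3] tensor; the sizes must sum to dim 0.
//
//   std::vector<Tensor> pieces;
//   Status s = tensor::Split(xy, {2, 4}, &pieces);  // pieces[0]: [2, 3]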

namespace internal {
void SetTensorProtoShape(std::vector<size_t> shape,
                         TensorShapeProto* shape_proto) {
  for (auto dim : shape) {
    shape_proto->mutable_dim()->Add()->set_size(dim);
  }
}

template <typename T>
bool CompressTensorContent(float min_compression_ratio,
                           const TensorShape& shape, TensorProto* tensor) {
  using TypeHelper = internal::TensorProtoHelper<T>;
  using FieldType = typename internal::TensorProtoHelper<T>::FieldType;
  const int64_t num_tensor_values = shape.num_elements();
  const int64_t num_bytes = tensor->tensor_content().size();
  const int64_t num_raw_values = num_bytes / sizeof(T);
  if (num_raw_values != num_tensor_values) {
    // Invalid or too small.
    return false;
  }
  int64_t last_offset = num_bytes - 1;
  int64_t prev_offset = last_offset - sizeof(T);
  // Inspect individual raw bytes sizeof(T) bytes apart in adjacent elements,
  // starting from the end, to find the last pair of elements that are not
  // identical.
  while (prev_offset >= 0) {
    if (tensor->tensor_content()[prev_offset] !=
        tensor->tensor_content()[last_offset]) {
      break;
    }
    --last_offset;
    --prev_offset;
  }
  if (prev_offset == -1) {
    // If this is a splat of value 0, it does not need an explicit value; just
    // erase the content.
    T splat_value;
    port::CopySubrangeToArray(tensor->tensor_content(), 0, sizeof(T),
                              reinterpret_cast<char*>(&splat_value));
    if (splat_value == T(0)) {
      tensor->clear_tensor_content();
      return true;
    }
  }
  // Round up to the next whole number of elements of type T.
  const int64_t new_num_values = last_offset / sizeof(T) + 1;
  if (new_num_values * (is_complex<T>::value ? 2 : 1) * sizeof(FieldType) >
      static_cast<int64>(num_bytes / min_compression_ratio)) {
    return false;
  }
  // Copy values to truncated repeated field.
  if (sizeof(FieldType) == sizeof(T)) {
    FieldType* dst_ptr =
        TypeHelper::AppendUninitialized(new_num_values, tensor);
    port::CopySubrangeToArray(tensor->tensor_content(), 0,
                              new_num_values * sizeof(T),
                              reinterpret_cast<char*>(dst_ptr));
    tensor->clear_tensor_content();
  } else if (sizeof(T) > 1) {
    // Copy raw bytes to temp array first, then cast.
    gtl::InlinedVector<T, 64> tmp(new_num_values);
    port::CopySubrangeToArray(tensor->tensor_content(), 0,
                              new_num_values * sizeof(T),
                              reinterpret_cast<char*>(tmp.data()));
    tensor->clear_tensor_content();
    const T* begin = tmp.begin();
    const T* end = tmp.end();
    TypeHelper::AddValues(begin, end, tensor);
  } else {
    // Copy and cast, one byte at a time.
    for (int64_t i = 0; i < new_num_values; ++i) {
      char c = tensor->tensor_content()[i];
      TypeHelper::AddValue(static_cast<T>(c), tensor);
    }
    tensor->clear_tensor_content();
  }
  return true;
}
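
// Informal sketch of the behavior above: CompressTensorContent replaces a raw
// tensor_content blob whose trailing elements all repeat with a shorter
// repeated value field (e.g. float_val for DT_FLOAT). For example, float data
// {1, 2, 3, 3, 3, 3} keeps only {1, 2, 3}, and readers re-expand the final
// value as a splat; an all-zero blob is cleared entirely. The rewrite is kept
// only if it shrinks the encoding by at least min_compression_ratio.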

template <typename T>
inline bool PackedValuesNotEqual(T a, T b) {
  return a != b;
}
template <>
inline bool PackedValuesNotEqual(float a, float b) {
  return reinterpret_cast<int32_t&>(a) != reinterpret_cast<int32_t&>(b);
}
template <>
inline bool PackedValuesNotEqual(double a, double b) {
  return reinterpret_cast<int64_t&>(a) != reinterpret_cast<int64_t&>(b);
}
template <typename RealType>
inline bool PackedValuesNotEqual(const std::complex<RealType>& a,
                                 const std::complex<RealType>& b) {
  return PackedValuesNotEqual(a.real(), b.real()) ||
         PackedValuesNotEqual(a.imag(), b.imag());
}
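
// Informal note: the float/double overloads above compare raw bit patterns
// rather than using operator==, so runs of bit-identical values (including
// repeated NaNs) are detected as equal and +0.0 is not conflated with -0.0
// when deciding how far a trailing splat extends.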

template <typename T>
bool CompressRepeatedField(float min_compression_ratio,
                           const TensorShape& shape, TensorProto* tensor) {
  using TypeHelper = internal::TensorProtoHelper<T>;
  using FieldType = typename internal::TensorProtoHelper<T>::FieldType;
  const int64_t num_tensor_values = shape.num_elements();
  const int64_t num_proto_values = TypeHelper::NumValues(*tensor);

  // Notice that for complex types the tensor is stored as an array of up to
  // 2 * num_tensor_values real values (real and imaginary parts), possibly
  // truncated. A 0-splat does not need any value present and is maximally
  // compressed.
  if (num_proto_values == 0) return false;

  const T last_value = TypeHelper::GetValue(num_proto_values - 1, *tensor);
  int64_t last_index = 0;
  for (int64_t i = num_proto_values - 2; i >= 0 && last_index == 0; --i) {
    const T cur_value = TypeHelper::GetValue(i, *tensor);
    if (PackedValuesNotEqual(cur_value, last_value)) {
      last_index = i + 1;
    }
  }

  // Detect all-zero tensors: zero is the default value, so the content can be
  // erased entirely.
  if (last_index == 0 && last_value == T(0)) {
    TypeHelper::Truncate(0, tensor);
    return true;
  }

  const int64_t num_truncated_proto_values = last_index + 1;
  const int64_t num_bytes_as_field =
      num_truncated_proto_values * sizeof(FieldType);
  const int64_t num_bytes_as_tensor_content = num_tensor_values * sizeof(T);
  const int64_t num_bytes_before = num_proto_values * sizeof(FieldType);
  if (std::min(num_bytes_as_field, num_bytes_as_tensor_content) >
      static_cast<int64>(num_bytes_before / min_compression_ratio)) {
    return false;
  }
  if (num_bytes_as_field <= num_bytes_as_tensor_content) {
    TypeHelper::Truncate(num_truncated_proto_values, tensor);
  } else {
    gtl::InlinedVector<T, 64> tmp;
    if (num_proto_values == 1) {
      // Splat case.
      tmp.resize(num_tensor_values, last_value);
    } else {
      tmp.resize(num_tensor_values, T(0));
      TypeHelper::CopyValues(tmp.begin(), *tensor);
    }
    TypeHelper::Truncate(0, tensor);
    port::CopyFromArray(tensor->mutable_tensor_content(),
                        reinterpret_cast<const char*>(tmp.data()),
                        num_bytes_as_tensor_content);
  }
  return true;
}
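
// Informal sketch: CompressRepeatedField is the counterpart of
// CompressTensorContent that starts from the repeated value field. It drops a
// trailing run of equal values (keeping one representative to splat), erases
// the field entirely for all-zero tensors, and otherwise switches to a packed
// tensor_content encoding when that representation is smaller, again gated by
// min_compression_ratio.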

template <typename T>
bool CompressTensorProtoInPlaceImpl(int64_t min_num_elements,
                                    float min_compression_ratio,
                                    TensorProto* tensor) {
  const TensorShape shape(tensor->tensor_shape());
  const int64_t num_tensor_values = shape.num_elements();
  if (num_tensor_values < min_num_elements) {
    return false;
  }
  if (tensor->tensor_content().empty()) {
    return CompressRepeatedField<T>(min_compression_ratio, shape, tensor);
  } else {
    return CompressTensorContent<T>(min_compression_ratio, shape, tensor);
  }
  return true;
}

}  // namespace internal

#define HANDLE_COMPRESS_CASE(TF_TYPE)                                  \
  case TF_TYPE:                                                        \
    return internal::CompressTensorProtoInPlaceImpl<                   \
        EnumToDataType<TF_TYPE>::Type>(min_num_elements,               \
                                       min_compression_ratio, tensor); \
    break

bool CompressTensorProtoInPlace(int64_t min_num_elements,
                                float min_compression_ratio,
                                TensorProto* tensor) {
  switch (tensor->dtype()) {
    HANDLE_COMPRESS_CASE(DT_FLOAT);
    HANDLE_COMPRESS_CASE(DT_DOUBLE);
    HANDLE_COMPRESS_CASE(DT_COMPLEX64);
    HANDLE_COMPRESS_CASE(DT_COMPLEX128);
    HANDLE_COMPRESS_CASE(DT_UINT8);
    HANDLE_COMPRESS_CASE(DT_INT8);
    HANDLE_COMPRESS_CASE(DT_UINT16);
    HANDLE_COMPRESS_CASE(DT_INT16);
    HANDLE_COMPRESS_CASE(DT_UINT32);
    HANDLE_COMPRESS_CASE(DT_INT32);
    HANDLE_COMPRESS_CASE(DT_UINT64);
    HANDLE_COMPRESS_CASE(DT_INT64);
    HANDLE_COMPRESS_CASE(DT_BOOL);
    HANDLE_COMPRESS_CASE(DT_QUINT8);
    HANDLE_COMPRESS_CASE(DT_QINT8);
    HANDLE_COMPRESS_CASE(DT_QUINT16);
    HANDLE_COMPRESS_CASE(DT_QINT16);
    HANDLE_COMPRESS_CASE(DT_QINT32);
    HANDLE_COMPRESS_CASE(DT_HALF);
    HANDLE_COMPRESS_CASE(DT_BFLOAT16);
    default:
      return false;
  }
}
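
// Example usage (sketch; 'some_tensor' is a hypothetical Tensor): serialize a
// tensor into a proto, then attempt in-place compression, skipping tensors
// with fewer than 64 elements and keeping the rewrite only if it shrinks the
// encoding by at least 2x.
//
//   TensorProto proto;
//   some_tensor.AsProtoTensorContent(&proto);
//   tensor::CompressTensorProtoInPlace(/*min_num_elements=*/64,
//                                      /*min_compression_ratio=*/2.0f, &proto);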

#undef HANDLE_COMPRESS_CASE

Status MakeShape(const Tensor& shape, TensorShape* out) {
  if (!TensorShapeUtils::IsVector(shape.shape())) {
    return errors::InvalidArgument(
        "shape must be a vector of {int32,int64}, got shape ",
        shape.shape().DebugString());
  }
  if (shape.dtype() == DataType::DT_INT32) {
    auto vec = shape.flat<int32>();
    return TensorShapeUtils::MakeShape(vec.data(), vec.size(), out);
  } else if (shape.dtype() == DataType::DT_INT64) {
    auto vec = shape.flat<int64>();
    return TensorShapeUtils::MakeShape(vec.data(), vec.size(), out);
  } else {
    return errors::InvalidArgument("shape must be a vector of {int32,int64}.");
  }
}
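
// Example usage (sketch; 'shape_tensor' is a hypothetical rank-1 int32 or
// int64 tensor such as [2, 3, 4]): converts the vector into a TensorShape,
// returning an InvalidArgument Status for any other dtype or rank.
//
//   TensorShape out;
//   Status s = tensor::MakeShape(shape_tensor, &out);  // out: [2, 3, 4]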

}  // namespace tensor
}  // namespace tensorflow