• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/core/framework/tensor_util.h"

#include <algorithm>
#include <cmath>
#include <cstring>
#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/platform/types.h"
28 
29 namespace tensorflow {
30 namespace tensor {
31 
// Returns a freshly allocated Tensor with the same dtype and shape as `other`
// whose buffer holds a copy of `other`'s contents (the buffer is not shared).
//
// Memcpy-able dtypes are copied byte-wise. DT_STRING and DT_VARIANT are copied
// element by element through their assignment operators. Any other dtype fails
// the CHECK below.
Tensor DeepCopy(const Tensor& other) {
  const DataType dtype = other.dtype();
  Tensor result(dtype, other.shape());

  if (DataTypeCanUseMemcpy(dtype)) {
    // Nothing to copy for an empty tensor.
    if (other.NumElements() <= 0) {
      return result;
    }
    const StringPiece source_bytes = other.tensor_data();
    // tensor_data() only offers a read-only view. Cast away const to write
    // into the buffer that `result` just allocated.
    char* destination = const_cast<char*>(result.tensor_data().data());
    memcpy(destination, source_bytes.data(), source_bytes.size());
    return result;
  }

  if (dtype == DT_STRING) {
    result.unaligned_flat<string>() = other.unaligned_flat<string>();
    return result;
  }

  CHECK_EQ(DT_VARIANT, dtype);
  result.unaligned_flat<Variant>() = other.unaligned_flat<Variant>();
  return result;
}

Concat(const gtl::ArraySlice<Tensor> & tensors,Tensor * result)54 Status Concat(const gtl::ArraySlice<Tensor>& tensors, Tensor* result) {
55   if (tensors.empty()) {
56     return errors::InvalidArgument("Cannot concatenate zero tensors");
57   }
58   int64 total_dim0_size = 0;
59   for (const Tensor& tensor : tensors) {
60     if (tensor.dims() == 0) {
61       return errors::InvalidArgument(
62           "Cannot concatenate a zero-dimensional tensor");
63     }
64     total_dim0_size += tensor.dim_size(0);
65   }
66   TensorShape shape = tensors[0].shape();
67   shape.set_dim(0, total_dim0_size);
68 
69   const DataType dtype = tensors[0].dtype();
70   for (int i = 1; i < tensors.size(); ++i) {
71     if (tensors[i].dtype() != dtype) {
72       return errors::InvalidArgument(
73           "Cannot concatenate tensors that have different data types");
74     }
75   }
76   *result = Tensor(dtype, shape);
77 
78   // We use StringPiece as a convenient map over the tensor buffer,
79   // but we cast the type to get to the underlying buffer to do the
80   // copy.
81   StringPiece to_data = result->tensor_data();
82 
83   if (DataTypeCanUseMemcpy(dtype)) {
84     int64 offset = 0;
85     for (const Tensor& tensor : tensors) {
86       StringPiece from_data = tensor.tensor_data();
87       CHECK_LE(offset + from_data.size(), to_data.size());
88       memcpy(const_cast<char*>(to_data.data()) + offset, from_data.data(),
89              from_data.size());
90 
91       offset += from_data.size();
92     }
93   } else {
94     if (dtype != DT_STRING) {
95       return errors::Internal("Unexpected data type");
96     }
97     string* to_strings =
98         reinterpret_cast<string*>(const_cast<char*>(to_data.data()));
99 
100     int64 offset = 0;
101     for (const Tensor& tensor : tensors) {
102       auto from_strings = tensor.flat<string>();
103       CHECK_LE(offset + tensor.NumElements(), result->NumElements());
104       for (int i = 0; i < tensor.NumElements(); ++i) {
105         to_strings[offset + i] = from_strings(i);
106       }
107 
108       offset += tensor.NumElements();
109     }
110   }
111 
112   return Status::OK();
113 }
114 
Split(const Tensor & tensor,const gtl::ArraySlice<int64> & sizes,std::vector<Tensor> * result)115 Status Split(const Tensor& tensor, const gtl::ArraySlice<int64>& sizes,
116              std::vector<Tensor>* result) {
117   if (tensor.dims() == 0) {
118     return errors::InvalidArgument("Cannot split a zero-dimensional tensor");
119   }
120   int64 total_size = 0;
121   for (int64 size : sizes) {
122     total_size += size;
123   }
124   if (total_size != tensor.dim_size(0)) {
125     return errors::InvalidArgument(
126         "The values in 'sizes' do not sum to the zeroth-dimension size of "
127         "'tensor'");
128   }
129 
130   StringPiece from_data = tensor.tensor_data();
131 
132   if (DataTypeCanUseMemcpy(tensor.dtype())) {
133     int64 offset = 0;
134     for (int64 size : sizes) {
135       TensorShape shape = tensor.shape();
136       shape.set_dim(0, size);
137       result->emplace_back(tensor.dtype(), shape);
138       Tensor* split = &(*result)[result->size() - 1];
139 
140       // We use StringPiece as a convenient map over the tensor buffer,
141       // but we cast the type to get to the underlying buffer to do the
142       // copy.
143       StringPiece to_data = split->tensor_data();
144       CHECK_LE(offset + to_data.size(), from_data.size());
145       memcpy(const_cast<char*>(to_data.data()), from_data.data() + offset,
146              to_data.size());
147 
148       offset += to_data.size();
149     }
150   } else {
151     if (tensor.dtype() != DT_STRING) {
152       return errors::Internal("Unexpected data type");
153     }
154     auto from_strings = tensor.flat<string>();
155 
156     int64 offset = 0;
157     for (int64 size : sizes) {
158       TensorShape shape = tensor.shape();
159       shape.set_dim(0, size);
160       result->emplace_back(tensor.dtype(), shape);
161       Tensor& split = (*result)[result->size() - 1];
162       string* to_strings = reinterpret_cast<string*>(
163           const_cast<char*>(split.tensor_data().data()));
164 
165       CHECK_LE(offset + split.NumElements(), tensor.NumElements());
166       for (int i = 0; i < split.NumElements(); ++i) {
167         to_strings[i] = from_strings(offset + i);
168       }
169 
170       offset += split.NumElements();
171     }
172   }
173 
174   return Status::OK();
175 }
176 
177 namespace internal {
SetTensorProtoShape(std::vector<size_t> shape,TensorShapeProto * shape_proto)178 void SetTensorProtoShape(std::vector<size_t> shape,
179                          TensorShapeProto* shape_proto) {
180   for (auto dim : shape) {
181     shape_proto->mutable_dim()->Add()->set_size(dim);
182   }
183 }
184 
185 template <typename T>
CompressTensorContent(float min_compression_ratio,const TensorShape & shape,TensorProto * tensor)186 bool CompressTensorContent(float min_compression_ratio,
187                            const TensorShape& shape, TensorProto* tensor) {
188   using TypeHelper = internal::TensorProtoHelper<T>;
189   using FieldType = typename internal::TensorProtoHelper<T>::FieldType;
190   const int64 num_tensor_values = shape.num_elements();
191   const int64 num_bytes = tensor->tensor_content().size();
192   const int64 num_raw_values = num_bytes / sizeof(T);
193   if (num_raw_values != num_tensor_values) {
194     // Invalid or too small.
195     return false;
196   }
197   int64 last_offset = num_bytes - 1;
198   int64 prev_offset = last_offset - sizeof(T);
199   // Inspect individual raw bytes sizeof(T) bytes apart in adjacent elements,
200   // starting from the end, to find the last pair of elements that are not
201   // identical.
202   while (prev_offset >= 0) {
203     if (tensor->tensor_content()[prev_offset] !=
204         tensor->tensor_content()[last_offset]) {
205       break;
206     }
207     --last_offset;
208     --prev_offset;
209   }
210   // Round up to the next whole number of element of type T.
211   const int64 new_num_values = last_offset / sizeof(T) + 1;
212   if (new_num_values * (is_complex<T>::value ? 2 : 1) * sizeof(FieldType) >
213       static_cast<int64>(num_bytes / min_compression_ratio)) {
214     return false;
215   }
216   // Copy values to truncated repeated field.
217   if (sizeof(FieldType) == sizeof(T)) {
218     FieldType* dst_ptr =
219         TypeHelper::AppendUninitialized(new_num_values, tensor);
220     port::CopySubrangeToArray(tensor->tensor_content(), 0,
221                               new_num_values * sizeof(T),
222                               reinterpret_cast<char*>(dst_ptr));
223     tensor->clear_tensor_content();
224   } else if (sizeof(T) > 1) {
225     // Copy raw bytes to temp array first, then cast.
226     gtl::InlinedVector<T, 64> tmp(new_num_values);
227     port::CopySubrangeToArray(tensor->tensor_content(), 0,
228                               new_num_values * sizeof(T),
229                               reinterpret_cast<char*>(tmp.data()));
230     tensor->clear_tensor_content();
231     const T* begin = tmp.begin();
232     const T* end = tmp.end();
233     TypeHelper::AddValues(begin, end, tensor);
234   } else {
235     // Copy and cast, one byte at a time.
236     for (int64 i = 0; i < new_num_values; ++i) {
237       char c = tensor->tensor_content()[i];
238       TypeHelper::AddValue(static_cast<T>(c), tensor);
239     }
240     tensor->clear_tensor_content();
241   }
242   return true;
243 }
244 
// Returns true if `a` and `b` would be stored differently. For non-floating
// types this is plain operator!=.
template <typename T>
inline bool PackedValuesNotEqual(T a, T b) {
  return a != b;
}
// Floating-point values are compared by bit pattern rather than by value.
// Consequences:
//   * 0.0 and -0.0 count as different;
//   * bit-identical NaNs count as equal;
//   * dropping "repeated" values never substitutes a differently-encoded
//     value.
// memcmp reads the object representation without type punning. The previous
// reinterpret_cast<int32_t&>(float) violated strict aliasing (undefined
// behavior).
template <>
inline bool PackedValuesNotEqual(float a, float b) {
  return std::memcmp(&a, &b, sizeof(float)) != 0;
}
template <>
inline bool PackedValuesNotEqual(double a, double b) {
  return std::memcmp(&a, &b, sizeof(double)) != 0;
}
// Complex values differ if either their real or imaginary parts differ,
// using the bitwise comparisons above for each part.
template <typename RealType>
inline bool PackedValuesNotEqual(const std::complex<RealType>& a,
                                 const std::complex<RealType>& b) {
  return PackedValuesNotEqual(a.real(), b.real()) ||
         PackedValuesNotEqual(a.imag(), b.imag());
}

264 template <typename T>
CompressRepeatedField(float min_compression_ratio,const TensorShape & shape,TensorProto * tensor)265 bool CompressRepeatedField(float min_compression_ratio,
266                            const TensorShape& shape, TensorProto* tensor) {
267   using TypeHelper = internal::TensorProtoHelper<T>;
268   using FieldType = typename internal::TensorProtoHelper<T>::FieldType;
269   const int64 num_tensor_values = shape.num_elements();
270   // Notice that for complex types the tensor is stored as an array of up to
271   // 2 * num_tensor_values real values (real and imaginary parts), possibly
272   // truncated.
273   const int64 num_proto_values = TypeHelper::NumValues(*tensor);
274   if (num_proto_values != num_tensor_values) {
275     // Already compressed or invalid.
276     return false;
277   }
278   const T last_value = TypeHelper::GetValue(num_proto_values - 1, *tensor);
279   int64 last_index = 0;
280   for (int64 i = num_proto_values - 2; i >= 0 && last_index == 0; --i) {
281     const T cur_value = TypeHelper::GetValue(i, *tensor);
282     if (PackedValuesNotEqual(cur_value, last_value)) {
283       last_index = i + 1;
284     }
285   }
286   const int64 num_truncated_proto_values = last_index + 1;
287   const int64 num_bytes_as_field =
288       num_truncated_proto_values * sizeof(FieldType);
289   const int64 num_bytes_as_tensor_content = num_tensor_values * sizeof(T);
290   const int64 num_bytes_before = num_proto_values * sizeof(FieldType);
291   if (std::min(num_bytes_as_field, num_bytes_as_tensor_content) >
292       static_cast<int64>(num_bytes_before / min_compression_ratio)) {
293     return false;
294   }
295   if (num_bytes_as_field <= num_bytes_as_tensor_content) {
296     TypeHelper::Truncate(num_truncated_proto_values, tensor);
297   } else {
298     gtl::InlinedVector<T, 64> tmp(num_tensor_values);
299     TypeHelper::CopyValues(tmp.begin(), *tensor);
300     TypeHelper::Truncate(0, tensor);
301     port::CopyFromArray(tensor->mutable_tensor_content(),
302                         reinterpret_cast<const char*>(tmp.data()),
303                         num_bytes_as_tensor_content);
304   }
305   return true;
306 }
307 
308 template <typename T>
CompressTensorProtoInPlaceImpl(int64 min_num_elements,float min_compression_ratio,TensorProto * tensor)309 bool CompressTensorProtoInPlaceImpl(int64 min_num_elements,
310                                     float min_compression_ratio,
311                                     TensorProto* tensor) {
312   const TensorShape shape(tensor->tensor_shape());
313   const int64 num_tensor_values = shape.num_elements();
314   if (num_tensor_values < min_num_elements) {
315     return false;
316   }
317   if (tensor->tensor_content().empty()) {
318     return CompressRepeatedField<T>(min_compression_ratio, shape, tensor);
319   } else {
320     return CompressTensorContent<T>(min_compression_ratio, shape, tensor);
321   }
322   return true;
323 }
324 
325 }  // namespace internal
326 
// Expands to a `case TF_TYPE:` label that forwards to the typed compression
// routine for TF_TYPE. It uses the `min_num_elements`,
// `min_compression_ratio` and `tensor` names from the enclosing function. The
// caller supplies the trailing semicolon. No `break` is emitted because the
// case always returns.
#define HANDLE_COMPRESS_CASE(TF_TYPE)                                  \
  case TF_TYPE:                                                        \
    return internal::CompressTensorProtoInPlaceImpl<                   \
        EnumToDataType<TF_TYPE>::Type>(min_num_elements,               \
                                       min_compression_ratio, tensor)

CompressTensorProtoInPlace(int64 min_num_elements,float min_compression_ratio,TensorProto * tensor)334 bool CompressTensorProtoInPlace(int64 min_num_elements,
335                                 float min_compression_ratio,
336                                 TensorProto* tensor) {
337   switch (tensor->dtype()) {
338     HANDLE_COMPRESS_CASE(DT_FLOAT);
339     HANDLE_COMPRESS_CASE(DT_DOUBLE);
340     HANDLE_COMPRESS_CASE(DT_COMPLEX64);
341     HANDLE_COMPRESS_CASE(DT_COMPLEX128);
342     HANDLE_COMPRESS_CASE(DT_UINT8);
343     HANDLE_COMPRESS_CASE(DT_INT8);
344     HANDLE_COMPRESS_CASE(DT_UINT16);
345     HANDLE_COMPRESS_CASE(DT_INT16);
346     HANDLE_COMPRESS_CASE(DT_UINT32);
347     HANDLE_COMPRESS_CASE(DT_INT32);
348     HANDLE_COMPRESS_CASE(DT_UINT64);
349     HANDLE_COMPRESS_CASE(DT_INT64);
350     HANDLE_COMPRESS_CASE(DT_BOOL);
351     HANDLE_COMPRESS_CASE(DT_QUINT8);
352     HANDLE_COMPRESS_CASE(DT_QINT8);
353     HANDLE_COMPRESS_CASE(DT_QUINT16);
354     HANDLE_COMPRESS_CASE(DT_QINT16);
355     HANDLE_COMPRESS_CASE(DT_QINT32);
356     HANDLE_COMPRESS_CASE(DT_HALF);
357     HANDLE_COMPRESS_CASE(DT_BFLOAT16);
358     default:
359       return false;
360   }
361 }
362 
363 #undef HANDLE_COMPRESS_CASE
364 
365 }  // namespace tensor
366 }  // namespace tensorflow
367