/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/tensor_util.h"

#include <cmath>
#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace tensor {

Tensor DeepCopy(const Tensor& other) {
  Tensor tmp = Tensor(other.dtype(), other.shape());
  DeepCopy(other, &tmp);
  return tmp;
}

void DeepCopy(const Tensor& input, Tensor* output) {
  if (DataTypeCanUseMemcpy(input.dtype())) {
    if (input.NumElements() > 0) {
      StringPiece input_data = input.tensor_data();

      // We use StringPiece as a convenient map over the tensor buffer,
      // but we cast the type to get to the underlying buffer to do the
      // copy.
      StringPiece output_data = output->tensor_data();
      memcpy(const_cast<char*>(output_data.data()), input_data.data(),
             input_data.size());
    }
  } else if (input.dtype() == DT_STRING) {
    output->unaligned_flat<tstring>() = input.unaligned_flat<tstring>();
  } else {
    CHECK_EQ(DT_VARIANT, input.dtype());
    output->unaligned_flat<Variant>() = input.unaligned_flat<Variant>();
  }
}
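
// Example usage of DeepCopy (illustrative sketch, not part of this file;
// assumes a caller that already holds a float tensor to duplicate):
//
//   Tensor src(DT_FLOAT, TensorShape({2, 3}));
//   src.flat<float>().setZero();
//   Tensor copy = tensor::DeepCopy(src);  // `copy` owns its own buffer.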

Status Concat(const gtl::ArraySlice<Tensor>& tensors, Tensor* result) {
  if (tensors.empty()) {
    return errors::InvalidArgument("Cannot concatenate zero tensors");
  }
  int64_t total_dim0_size = 0;
  for (const Tensor& tensor : tensors) {
    if (tensor.dims() == 0) {
      return errors::InvalidArgument(
          "Cannot concatenate a zero-dimensional tensor");
    }
    total_dim0_size += tensor.dim_size(0);
  }
  TensorShape shape = tensors[0].shape();
  shape.set_dim(0, total_dim0_size);

  const DataType dtype = tensors[0].dtype();
  for (int i = 1; i < tensors.size(); ++i) {
    if (tensors[i].dtype() != dtype) {
      return errors::InvalidArgument(
          "Cannot concatenate tensors that have different data types.", " Got ",
          DataTypeString(dtype), " and ", DataTypeString(tensors[i].dtype()),
          ".");
    }
  }
  *result = Tensor(dtype, shape);

  // We use StringPiece as a convenient map over the tensor buffer,
  // but we cast the type to get to the underlying buffer to do the
  // copy.
  StringPiece to_data = result->tensor_data();

  if (DataTypeCanUseMemcpy(dtype)) {
    int64_t offset = 0;
    for (const Tensor& tensor : tensors) {
      StringPiece from_data = tensor.tensor_data();
      CHECK_LE(offset + from_data.size(), to_data.size());
      memcpy(const_cast<char*>(to_data.data()) + offset, from_data.data(),
             from_data.size());

      offset += from_data.size();
    }
  } else {
    if (dtype != DT_STRING) {
      return errors::Internal("Unexpected data type");
    }
    tstring* to_strings =
        reinterpret_cast<tstring*>(const_cast<char*>(to_data.data()));

    int64_t offset = 0;
    for (const Tensor& tensor : tensors) {
      auto from_strings = tensor.flat<tstring>();
      CHECK_LE(offset + tensor.NumElements(), result->NumElements());
      for (int i = 0; i < tensor.NumElements(); ++i) {
        to_strings[offset + i] = from_strings(i);
      }

      offset += tensor.NumElements();
    }
  }

  return Status::OK();
}
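
// Example usage of Concat (illustrative sketch, not part of this file;
// assumes two tensors whose shapes agree in all but the zeroth dimension):
//
//   Tensor a(DT_INT32, TensorShape({2, 4}));
//   Tensor b(DT_INT32, TensorShape({3, 4}));
//   Tensor concatenated;
//   TF_CHECK_OK(tensor::Concat({a, b}, &concatenated));
//   // concatenated.shape() is [5, 4].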

Status Split(const Tensor& tensor, const gtl::ArraySlice<int64>& sizes,
             std::vector<Tensor>* result) {
  if (tensor.dims() == 0) {
    return errors::InvalidArgument("Cannot split a zero-dimensional tensor");
  }
  int64_t total_size = 0;
  for (int64_t size : sizes) {
    total_size += size;
  }
  if (total_size != tensor.dim_size(0)) {
    return errors::InvalidArgument(
        "The values in 'sizes' do not sum to the zeroth-dimension size of "
        "'tensor'");
  }

  StringPiece from_data = tensor.tensor_data();

  if (DataTypeCanUseMemcpy(tensor.dtype())) {
    int64_t offset = 0;
    for (int64_t size : sizes) {
      TensorShape shape = tensor.shape();
      shape.set_dim(0, size);
      result->emplace_back(tensor.dtype(), shape);
      Tensor* split = &(*result)[result->size() - 1];

      // We use StringPiece as a convenient map over the tensor buffer,
      // but we cast the type to get to the underlying buffer to do the
      // copy.
      StringPiece to_data = split->tensor_data();
      CHECK_LE(offset + to_data.size(), from_data.size());
      memcpy(const_cast<char*>(to_data.data()), from_data.data() + offset,
             to_data.size());

      offset += to_data.size();
    }
  } else {
    if (tensor.dtype() != DT_STRING) {
      return errors::Internal("Unexpected data type");
    }
    auto from_strings = tensor.flat<tstring>();

    int64_t offset = 0;
    for (int64_t size : sizes) {
      TensorShape shape = tensor.shape();
      shape.set_dim(0, size);
      result->emplace_back(tensor.dtype(), shape);
      Tensor& split = (*result)[result->size() - 1];
      tstring* to_strings = reinterpret_cast<tstring*>(
          const_cast<char*>(split.tensor_data().data()));

      CHECK_LE(offset + split.NumElements(), tensor.NumElements());
      for (int i = 0; i < split.NumElements(); ++i) {
        to_strings[i] = from_strings(offset + i);
      }

      offset += split.NumElements();
    }
  }

  return Status::OK();
}
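
// Example usage of Split (illustrative sketch, not part of this file; splits
// a rank-2 tensor along its zeroth dimension into two pieces):
//
//   Tensor t(DT_FLOAT, TensorShape({5, 4}));
//   std::vector<Tensor> pieces;
//   TF_CHECK_OK(tensor::Split(t, {2, 3}, &pieces));
//   // pieces[0].shape() is [2, 4]; pieces[1].shape() is [3, 4].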

namespace internal {
void SetTensorProtoShape(std::vector<size_t> shape,
                         TensorShapeProto* shape_proto) {
  for (auto dim : shape) {
    shape_proto->mutable_dim()->Add()->set_size(dim);
  }
}

template <typename T>
bool CompressTensorContent(float min_compression_ratio,
                           const TensorShape& shape, TensorProto* tensor) {
  using TypeHelper = internal::TensorProtoHelper<T>;
  using FieldType = typename internal::TensorProtoHelper<T>::FieldType;
  const int64_t num_tensor_values = shape.num_elements();
  const int64_t num_bytes = tensor->tensor_content().size();
  const int64_t num_raw_values = num_bytes / sizeof(T);
  if (num_raw_values != num_tensor_values) {
    // Invalid or too small.
    return false;
  }
  int64_t last_offset = num_bytes - 1;
  int64_t prev_offset = last_offset - sizeof(T);
  // Inspect individual raw bytes sizeof(T) bytes apart in adjacent elements,
  // starting from the end, to find the last pair of elements that are not
  // identical.
  while (prev_offset >= 0) {
    if (tensor->tensor_content()[prev_offset] !=
        tensor->tensor_content()[last_offset]) {
      break;
    }
    --last_offset;
    --prev_offset;
  }
  if (prev_offset == -1) {
    // If this is a splat of value 0, it does not need an explicit value;
    // just erase the content.
    T splat_value;
    port::CopySubrangeToArray(tensor->tensor_content(), 0, sizeof(T),
                              reinterpret_cast<char*>(&splat_value));
    if (splat_value == T(0)) {
      tensor->clear_tensor_content();
      return true;
    }
  }
  // Round up to the next whole number of elements of type T.
  const int64_t new_num_values = last_offset / sizeof(T) + 1;
  if (new_num_values * (is_complex<T>::value ? 2 : 1) * sizeof(FieldType) >
      static_cast<int64>(num_bytes / min_compression_ratio)) {
    return false;
  }
  // Copy values to truncated repeated field.
  if (sizeof(FieldType) == sizeof(T)) {
    FieldType* dst_ptr =
        TypeHelper::AppendUninitialized(new_num_values, tensor);
    port::CopySubrangeToArray(tensor->tensor_content(), 0,
                              new_num_values * sizeof(T),
                              reinterpret_cast<char*>(dst_ptr));
    tensor->clear_tensor_content();
  } else if (sizeof(T) > 1) {
    // Copy raw bytes to temp array first, then cast.
    gtl::InlinedVector<T, 64> tmp(new_num_values);
    port::CopySubrangeToArray(tensor->tensor_content(), 0,
                              new_num_values * sizeof(T),
                              reinterpret_cast<char*>(tmp.data()));
    tensor->clear_tensor_content();
    const T* begin = tmp.begin();
    const T* end = tmp.end();
    TypeHelper::AddValues(begin, end, tensor);
  } else {
    // Copy and cast, one byte at a time.
    for (int64_t i = 0; i < new_num_values; ++i) {
      char c = tensor->tensor_content()[i];
      TypeHelper::AddValue(static_cast<T>(c), tensor);
    }
    tensor->clear_tensor_content();
  }
  return true;
}
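
// Illustration of CompressTensorContent (comment-only sketch): for a tensor
// whose tensor_content() bytes decode to all zeros, e.g. a float tensor of
// shape [1024] filled with 0.0f, the content is simply erased, since zero is
// the implicit default when the proto is expanded. Tensors with a long
// constant tail are instead truncated to a shorter repeated field, provided
// the savings meet min_compression_ratio.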

template <typename T>
inline bool PackedValuesNotEqual(T a, T b) {
  return a != b;
}
template <>
inline bool PackedValuesNotEqual(float a, float b) {
  return reinterpret_cast<int32_t&>(a) != reinterpret_cast<int32_t&>(b);
}
template <>
inline bool PackedValuesNotEqual(double a, double b) {
  return reinterpret_cast<int64_t&>(a) != reinterpret_cast<int64_t&>(b);
}
template <typename RealType>
inline bool PackedValuesNotEqual(const std::complex<RealType>& a,
                                 const std::complex<RealType>& b) {
  return PackedValuesNotEqual(a.real(), b.real()) ||
         PackedValuesNotEqual(a.imag(), b.imag());
}
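
// Note (added comment): the float and double specializations above compare
// raw bit patterns rather than values. As a consequence, identical NaN
// payloads compare equal (so runs of NaN can still be detected and
// compressed), while +0.0 and -0.0 compare unequal, e.g.:
//
//   PackedValuesNotEqual(NAN, NAN);    // false when the bit patterns match.
//   PackedValuesNotEqual(0.0f, -0.0f); // true.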

template <typename T>
bool CompressRepeatedField(float min_compression_ratio,
                           const TensorShape& shape, TensorProto* tensor) {
  using TypeHelper = internal::TensorProtoHelper<T>;
  using FieldType = typename internal::TensorProtoHelper<T>::FieldType;
  const int64_t num_tensor_values = shape.num_elements();
  const int64_t num_proto_values = TypeHelper::NumValues(*tensor);

  // Notice that for complex types the tensor is stored as an array of up to
  // 2 * num_tensor_values real values (real and imaginary parts), possibly
  // truncated. A 0-splat does not need any value present and is maximally
  // compressed.
  if (num_proto_values == 0) return false;

  const T last_value = TypeHelper::GetValue(num_proto_values - 1, *tensor);
  int64_t last_index = 0;
  for (int64_t i = num_proto_values - 2; i >= 0 && last_index == 0; --i) {
    const T cur_value = TypeHelper::GetValue(i, *tensor);
    if (PackedValuesNotEqual(cur_value, last_value)) {
      last_index = i + 1;
    }
  }

  // Detect all-zero tensors: zero is the default value, so the content can be
  // erased entirely.
  if (last_index == 0 && last_value == T(0)) {
    TypeHelper::Truncate(0, tensor);
    return true;
  }

  const int64_t num_truncated_proto_values = last_index + 1;
  const int64_t num_bytes_as_field =
      num_truncated_proto_values * sizeof(FieldType);
  const int64_t num_bytes_as_tensor_content = num_tensor_values * sizeof(T);
  const int64_t num_bytes_before = num_proto_values * sizeof(FieldType);
  if (std::min(num_bytes_as_field, num_bytes_as_tensor_content) >
      static_cast<int64>(num_bytes_before / min_compression_ratio)) {
    return false;
  }
  if (num_bytes_as_field <= num_bytes_as_tensor_content) {
    TypeHelper::Truncate(num_truncated_proto_values, tensor);
  } else {
    gtl::InlinedVector<T, 64> tmp;
    if (num_proto_values == 1) {
      // Splat case.
      tmp.resize(num_tensor_values, last_value);
    } else {
      tmp.resize(num_tensor_values, T(0));
      TypeHelper::CopyValues(tmp.begin(), *tensor);
    }
    TypeHelper::Truncate(0, tensor);
    port::CopyFromArray(tensor->mutable_tensor_content(),
                        reinterpret_cast<const char*>(tmp.data()),
                        num_bytes_as_tensor_content);
  }
  return true;
}
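
// Illustration of CompressRepeatedField (comment-only sketch): a DT_FLOAT
// tensor of shape [1000] whose float_val field holds 1000 copies of 3.5 is
// truncated to a single entry {3.5}; readers re-expand the last value to fill
// the shape. Conversely, a varied repeated field may be moved into
// tensor_content() when the raw bytes are smaller, e.g. for types whose proto
// field is wider than the in-memory element (16-bit floats are stored in a
// repeated int32 field).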

template <typename T>
bool CompressTensorProtoInPlaceImpl(int64_t min_num_elements,
                                    float min_compression_ratio,
                                    TensorProto* tensor) {
  const TensorShape shape(tensor->tensor_shape());
  const int64_t num_tensor_values = shape.num_elements();
  if (num_tensor_values < min_num_elements) {
    return false;
  }
  if (tensor->tensor_content().empty()) {
    return CompressRepeatedField<T>(min_compression_ratio, shape, tensor);
  } else {
    return CompressTensorContent<T>(min_compression_ratio, shape, tensor);
  }
  return true;
}

}  // namespace internal

#define HANDLE_COMPRESS_CASE(TF_TYPE)                                  \
  case TF_TYPE:                                                        \
    return internal::CompressTensorProtoInPlaceImpl<                   \
        EnumToDataType<TF_TYPE>::Type>(min_num_elements,               \
                                       min_compression_ratio, tensor); \
    break

bool CompressTensorProtoInPlace(int64_t min_num_elements,
                                float min_compression_ratio,
                                TensorProto* tensor) {
  switch (tensor->dtype()) {
    HANDLE_COMPRESS_CASE(DT_FLOAT);
    HANDLE_COMPRESS_CASE(DT_DOUBLE);
    HANDLE_COMPRESS_CASE(DT_COMPLEX64);
    HANDLE_COMPRESS_CASE(DT_COMPLEX128);
    HANDLE_COMPRESS_CASE(DT_UINT8);
    HANDLE_COMPRESS_CASE(DT_INT8);
    HANDLE_COMPRESS_CASE(DT_UINT16);
    HANDLE_COMPRESS_CASE(DT_INT16);
    HANDLE_COMPRESS_CASE(DT_UINT32);
    HANDLE_COMPRESS_CASE(DT_INT32);
    HANDLE_COMPRESS_CASE(DT_UINT64);
    HANDLE_COMPRESS_CASE(DT_INT64);
    HANDLE_COMPRESS_CASE(DT_BOOL);
    HANDLE_COMPRESS_CASE(DT_QUINT8);
    HANDLE_COMPRESS_CASE(DT_QINT8);
    HANDLE_COMPRESS_CASE(DT_QUINT16);
    HANDLE_COMPRESS_CASE(DT_QINT16);
    HANDLE_COMPRESS_CASE(DT_QINT32);
    HANDLE_COMPRESS_CASE(DT_HALF);
    HANDLE_COMPRESS_CASE(DT_BFLOAT16);
    default:
      return false;
  }
}
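
// Example usage of CompressTensorProtoInPlace (illustrative sketch, not part
// of this file; the thresholds below are arbitrary caller choices):
//
//   Tensor zeros(DT_FLOAT, TensorShape({1024}));
//   zeros.flat<float>().setZero();
//   TensorProto proto;
//   zeros.AsProtoTensorContent(&proto);
//   // Only compress tensors with at least 64 elements, and only if the
//   // representation shrinks by at least 2x.
//   if (tensor::CompressTensorProtoInPlace(64, 2.0f, &proto)) {
//     // For an all-zero tensor the content is dropped entirely.
//   }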

#undef HANDLE_COMPRESS_CASE

Status MakeShape(const Tensor& shape, TensorShape* out) {
  if (!TensorShapeUtils::IsVector(shape.shape())) {
    return errors::InvalidArgument(
        "shape must be a vector of {int32,int64}, got shape ",
        shape.shape().DebugString());
  }
  if (shape.dtype() == DataType::DT_INT32) {
    auto vec = shape.flat<int32>();
    return TensorShapeUtils::MakeShape(vec.data(), vec.size(), out);
  } else if (shape.dtype() == DataType::DT_INT64) {
    auto vec = shape.flat<int64>();
    return TensorShapeUtils::MakeShape(vec.data(), vec.size(), out);
  } else {
    return errors::InvalidArgument("shape must be a vector of {int32,int64}.");
  }
}
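
// Example usage of MakeShape (illustrative sketch, not part of this file;
// converts a rank-1 int32 tensor holding {2, 3, 5} into a TensorShape):
//
//   Tensor shape_tensor(DT_INT32, TensorShape({3}));
//   auto v = shape_tensor.flat<int32>();
//   v(0) = 2; v(1) = 3; v(2) = 5;
//   TensorShape shape;
//   TF_CHECK_OK(tensor::MakeShape(shape_tensor, &shape));
//   // shape is [2, 3, 5].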

}  // namespace tensor
}  // namespace tensorflow