/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/tensor_util.h"

#include <cmath>
#include <vector>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace tensor {

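// Deep-copies 'other' into a freshly allocated buffer. Types that support
// memcpy are copied as raw bytes; otherwise the element type must be
// DT_STRING or DT_VARIANT (enforced by the CHECK below) and elements are
// copied one by one.
//
// Illustrative usage (names here are hypothetical):
//   Tensor copy = tensor::DeepCopy(original);  // 'copy' owns its own buffer.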
Tensor DeepCopy(const Tensor& other) {
  Tensor tmp = Tensor(other.dtype(), other.shape());
  if (DataTypeCanUseMemcpy(other.dtype())) {
    if (other.NumElements() > 0) {
      StringPiece other_data = other.tensor_data();

      // We use StringPiece as a convenient map over the tensor buffer,
      // but we cast the type to get to the underlying buffer to do the
      // copy.
      StringPiece tmp_data = tmp.tensor_data();
      memcpy(const_cast<char*>(tmp_data.data()), other_data.data(),
             other_data.size());
    }
  } else if (other.dtype() == DT_STRING) {
    tmp.unaligned_flat<string>() = other.unaligned_flat<string>();
  } else {
    CHECK_EQ(DT_VARIANT, other.dtype());
    tmp.unaligned_flat<Variant>() = other.unaligned_flat<Variant>();
  }
  return tmp;
}

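// Concatenates 'tensors' along dimension 0. Only dtype equality and nonzero
// rank are checked here; callers are expected to pass tensors whose trailing
// dimensions match those of the first tensor.
//
// Illustrative usage (names here are hypothetical):
//   Tensor a(DT_FLOAT, TensorShape({2, 3}));
//   Tensor b(DT_FLOAT, TensorShape({4, 3}));
//   Tensor ab;
//   TF_CHECK_OK(tensor::Concat({a, b}, &ab));  // 'ab' has shape {6, 3}.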
Status Concat(const gtl::ArraySlice<Tensor>& tensors, Tensor* result) {
  if (tensors.empty()) {
    return errors::InvalidArgument("Cannot concatenate zero tensors");
  }
  int64 total_dim0_size = 0;
  for (const Tensor& tensor : tensors) {
    if (tensor.dims() == 0) {
      return errors::InvalidArgument(
          "Cannot concatenate a zero-dimensional tensor");
    }
    total_dim0_size += tensor.dim_size(0);
  }
  TensorShape shape = tensors[0].shape();
  shape.set_dim(0, total_dim0_size);

  const DataType dtype = tensors[0].dtype();
  for (int i = 1; i < tensors.size(); ++i) {
    if (tensors[i].dtype() != dtype) {
      return errors::InvalidArgument(
          "Cannot concatenate tensors that have different data types");
    }
  }
  *result = Tensor(dtype, shape);

  // We use StringPiece as a convenient map over the tensor buffer,
  // but we cast the type to get to the underlying buffer to do the
  // copy.
  StringPiece to_data = result->tensor_data();

  if (DataTypeCanUseMemcpy(dtype)) {
    int64 offset = 0;
    for (const Tensor& tensor : tensors) {
      StringPiece from_data = tensor.tensor_data();
      CHECK_LE(offset + from_data.size(), to_data.size());
      memcpy(const_cast<char*>(to_data.data()) + offset, from_data.data(),
             from_data.size());

      offset += from_data.size();
    }
  } else {
    if (dtype != DT_STRING) {
      return errors::Internal("Unexpected data type");
    }
    string* to_strings =
        reinterpret_cast<string*>(const_cast<char*>(to_data.data()));

    int64 offset = 0;
    for (const Tensor& tensor : tensors) {
      auto from_strings = tensor.flat<string>();
      CHECK_LE(offset + tensor.NumElements(), result->NumElements());
      for (int i = 0; i < tensor.NumElements(); ++i) {
        to_strings[offset + i] = from_strings(i);
      }

      offset += tensor.NumElements();
    }
  }

  return Status::OK();
}

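// Splits 'tensor' along dimension 0 into pieces whose dim-0 sizes are given
// by 'sizes'; the sizes must sum to tensor.dim_size(0).
//
// Illustrative usage, continuing the Concat example above:
//   std::vector<Tensor> parts;
//   TF_CHECK_OK(tensor::Split(ab, {2, 4}, &parts));  // shapes {2,3}, {4,3}.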
Status Split(const Tensor& tensor, const gtl::ArraySlice<int64>& sizes,
             std::vector<Tensor>* result) {
  if (tensor.dims() == 0) {
    return errors::InvalidArgument("Cannot split a zero-dimensional tensor");
  }
  int64 total_size = 0;
  for (int64 size : sizes) {
    total_size += size;
  }
  if (total_size != tensor.dim_size(0)) {
    return errors::InvalidArgument(
        "The values in 'sizes' do not sum to the zeroth-dimension size of "
        "'tensor'");
  }

  StringPiece from_data = tensor.tensor_data();

  if (DataTypeCanUseMemcpy(tensor.dtype())) {
    int64 offset = 0;
    for (int64 size : sizes) {
      TensorShape shape = tensor.shape();
      shape.set_dim(0, size);
      result->emplace_back(tensor.dtype(), shape);
      Tensor* split = &(*result)[result->size() - 1];

      // We use StringPiece as a convenient map over the tensor buffer,
      // but we cast the type to get to the underlying buffer to do the
      // copy.
      StringPiece to_data = split->tensor_data();
      CHECK_LE(offset + to_data.size(), from_data.size());
      memcpy(const_cast<char*>(to_data.data()), from_data.data() + offset,
             to_data.size());

      offset += to_data.size();
    }
  } else {
    if (tensor.dtype() != DT_STRING) {
      return errors::Internal("Unexpected data type");
    }
    auto from_strings = tensor.flat<string>();

    int64 offset = 0;
    for (int64 size : sizes) {
      TensorShape shape = tensor.shape();
      shape.set_dim(0, size);
      result->emplace_back(tensor.dtype(), shape);
      Tensor& split = (*result)[result->size() - 1];
      string* to_strings = reinterpret_cast<string*>(
          const_cast<char*>(split.tensor_data().data()));

      CHECK_LE(offset + split.NumElements(), tensor.NumElements());
      for (int i = 0; i < split.NumElements(); ++i) {
        to_strings[i] = from_strings(offset + i);
      }

      offset += split.NumElements();
    }
  }

  return Status::OK();
}

namespace internal {
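// Copies 'shape' into 'shape_proto', adding one dim entry per element.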
void SetTensorProtoShape(std::vector<size_t> shape,
                         TensorShapeProto* shape_proto) {
  for (auto dim : shape) {
    shape_proto->mutable_dim()->Add()->set_size(dim);
  }
}

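// Compresses a TensorProto whose values live in the packed 'tensor_content'
// byte field. If the tensor ends in a run of byte-identical elements, the run
// is dropped and the remaining prefix is re-encoded into the typed repeated
// field, but only if that shrinks the proto by at least
// 'min_compression_ratio'. Returns true iff 'tensor' was modified.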
template <typename T>
bool CompressTensorContent(float min_compression_ratio,
                           const TensorShape& shape, TensorProto* tensor) {
  using TypeHelper = internal::TensorProtoHelper<T>;
  using FieldType = typename internal::TensorProtoHelper<T>::FieldType;
  const int64 num_tensor_values = shape.num_elements();
  const int64 num_bytes = tensor->tensor_content().size();
  const int64 num_raw_values = num_bytes / sizeof(T);
  if (num_raw_values != num_tensor_values) {
    // Invalid or too small.
    return false;
  }
  int64 last_offset = num_bytes - 1;
  int64 prev_offset = last_offset - sizeof(T);
  // Inspect individual raw bytes sizeof(T) bytes apart in adjacent elements,
  // starting from the end, to find the last pair of elements that are not
  // identical.
  while (prev_offset >= 0) {
    if (tensor->tensor_content()[prev_offset] !=
        tensor->tensor_content()[last_offset]) {
      break;
    }
    --last_offset;
    --prev_offset;
  }
  // Round up to the next whole number of elements of type T.
  const int64 new_num_values = last_offset / sizeof(T) + 1;
  if (new_num_values * (is_complex<T>::value ? 2 : 1) * sizeof(FieldType) >
      static_cast<int64>(num_bytes / min_compression_ratio)) {
    return false;
  }
  // Copy values to truncated repeated field.
  if (sizeof(FieldType) == sizeof(T)) {
    FieldType* dst_ptr =
        TypeHelper::AppendUninitialized(new_num_values, tensor);
    port::CopySubrangeToArray(tensor->tensor_content(), 0,
                              new_num_values * sizeof(T),
                              reinterpret_cast<char*>(dst_ptr));
    tensor->clear_tensor_content();
  } else if (sizeof(T) > 1) {
    // Copy raw bytes to temp array first, then cast.
    gtl::InlinedVector<T, 64> tmp(new_num_values);
    port::CopySubrangeToArray(tensor->tensor_content(), 0,
                              new_num_values * sizeof(T),
                              reinterpret_cast<char*>(tmp.data()));
    tensor->clear_tensor_content();
    const T* begin = tmp.begin();
    const T* end = tmp.end();
    TypeHelper::AddValues(begin, end, tensor);
  } else {
    // Copy and cast, one byte at a time.
    for (int64 i = 0; i < new_num_values; ++i) {
      char c = tensor->tensor_content()[i];
      TypeHelper::AddValue(static_cast<T>(c), tensor);
    }
    tensor->clear_tensor_content();
  }
  return true;
}

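// Compares values by bit pattern rather than by operator!= for floating-point
// and complex types, so that repeated NaNs with identical bits count as a run
// that can be compressed, while -0.0 stays distinct from 0.0.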
template <typename T>
inline bool PackedValuesNotEqual(T a, T b) {
  return a != b;
}
template <>
inline bool PackedValuesNotEqual(float a, float b) {
  return reinterpret_cast<int32_t&>(a) != reinterpret_cast<int32_t&>(b);
}
template <>
inline bool PackedValuesNotEqual(double a, double b) {
  return reinterpret_cast<int64_t&>(a) != reinterpret_cast<int64_t&>(b);
}
template <typename RealType>
inline bool PackedValuesNotEqual(const std::complex<RealType>& a,
                                 const std::complex<RealType>& b) {
  return PackedValuesNotEqual(a.real(), b.real()) ||
         PackedValuesNotEqual(a.imag(), b.imag());
}

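// Compresses a TensorProto whose values live in the typed repeated field.
// Trailing elements bit-equal to the last element are truncated away; if the
// packed 'tensor_content' encoding would be smaller still, the values are
// moved there instead. Only applied when the size drops by at least
// 'min_compression_ratio'. Returns true iff 'tensor' was modified.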
template <typename T>
bool CompressRepeatedField(float min_compression_ratio,
                           const TensorShape& shape, TensorProto* tensor) {
  using TypeHelper = internal::TensorProtoHelper<T>;
  using FieldType = typename internal::TensorProtoHelper<T>::FieldType;
  const int64 num_tensor_values = shape.num_elements();
  // Notice that for complex types the tensor is stored as an array of up to
  // 2 * num_tensor_values real values (real and imaginary parts), possibly
  // truncated.
  const int64 num_proto_values = TypeHelper::NumValues(*tensor);
  if (num_proto_values != num_tensor_values) {
    // Already compressed or invalid.
    return false;
  }
  if (num_proto_values == 0) {
    // Nothing to compress; also avoids reading index -1 below.
    return false;
  }
  const T last_value = TypeHelper::GetValue(num_proto_values - 1, *tensor);
  int64 last_index = 0;
  for (int64 i = num_proto_values - 2; i >= 0 && last_index == 0; --i) {
    const T cur_value = TypeHelper::GetValue(i, *tensor);
    if (PackedValuesNotEqual(cur_value, last_value)) {
      last_index = i + 1;
    }
  }
  const int64 num_truncated_proto_values = last_index + 1;
  const int64 num_bytes_as_field =
      num_truncated_proto_values * sizeof(FieldType);
  const int64 num_bytes_as_tensor_content = num_tensor_values * sizeof(T);
  const int64 num_bytes_before = num_proto_values * sizeof(FieldType);
  if (std::min(num_bytes_as_field, num_bytes_as_tensor_content) >
      static_cast<int64>(num_bytes_before / min_compression_ratio)) {
    return false;
  }
  if (num_bytes_as_field <= num_bytes_as_tensor_content) {
    TypeHelper::Truncate(num_truncated_proto_values, tensor);
  } else {
    gtl::InlinedVector<T, 64> tmp(num_tensor_values);
    TypeHelper::CopyValues(tmp.begin(), *tensor);
    TypeHelper::Truncate(0, tensor);
    port::CopyFromArray(tensor->mutable_tensor_content(),
                        reinterpret_cast<const char*>(tmp.data()),
                        num_bytes_as_tensor_content);
  }
  return true;
}

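// Chooses the compression routine for one element type based on which proto
// representation currently holds the values, skipping tensors with fewer than
// 'min_num_elements' elements.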
template <typename T>
bool CompressTensorProtoInPlaceImpl(int64 min_num_elements,
                                    float min_compression_ratio,
                                    TensorProto* tensor) {
  const TensorShape shape(tensor->tensor_shape());
  const int64 num_tensor_values = shape.num_elements();
  if (num_tensor_values < min_num_elements) {
    return false;
  }
  if (tensor->tensor_content().empty()) {
    return CompressRepeatedField<T>(min_compression_ratio, shape, tensor);
  } else {
    return CompressTensorContent<T>(min_compression_ratio, shape, tensor);
  }
  return true;
}

}  // namespace internal

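// Type-dispatch macro for CompressTensorProtoInPlace below: each case forwards
// to the implementation instantiated for the corresponding C++ type.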
#define HANDLE_COMPRESS_CASE(TF_TYPE)                                  \
  case TF_TYPE:                                                        \
    return internal::CompressTensorProtoInPlaceImpl<                   \
        EnumToDataType<TF_TYPE>::Type>(min_num_elements,               \
                                       min_compression_ratio, tensor); \
    break

bool CompressTensorProtoInPlace(int64 min_num_elements,
                                float min_compression_ratio,
                                TensorProto* tensor) {
  switch (tensor->dtype()) {
    HANDLE_COMPRESS_CASE(DT_FLOAT);
    HANDLE_COMPRESS_CASE(DT_DOUBLE);
    HANDLE_COMPRESS_CASE(DT_COMPLEX64);
    HANDLE_COMPRESS_CASE(DT_COMPLEX128);
    HANDLE_COMPRESS_CASE(DT_UINT8);
    HANDLE_COMPRESS_CASE(DT_INT8);
    HANDLE_COMPRESS_CASE(DT_UINT16);
    HANDLE_COMPRESS_CASE(DT_INT16);
    HANDLE_COMPRESS_CASE(DT_UINT32);
    HANDLE_COMPRESS_CASE(DT_INT32);
    HANDLE_COMPRESS_CASE(DT_UINT64);
    HANDLE_COMPRESS_CASE(DT_INT64);
    HANDLE_COMPRESS_CASE(DT_BOOL);
    HANDLE_COMPRESS_CASE(DT_QUINT8);
    HANDLE_COMPRESS_CASE(DT_QINT8);
    HANDLE_COMPRESS_CASE(DT_QUINT16);
    HANDLE_COMPRESS_CASE(DT_QINT16);
    HANDLE_COMPRESS_CASE(DT_QINT32);
    HANDLE_COMPRESS_CASE(DT_HALF);
    HANDLE_COMPRESS_CASE(DT_BFLOAT16);
    default:
      return false;
  }
}

#undef HANDLE_COMPRESS_CASE

}  // namespace tensor
}  // namespace tensorflow