• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_
18 
19 #include <algorithm>
20 #include <vector>
21 
22 #include "tensorflow/core/framework/tensor.h"
23 #include "tensorflow/core/framework/tensor.pb.h"
24 #include "tensorflow/core/framework/tensor_shape.pb.h"
25 #include "tensorflow/core/framework/type_traits.h"
26 #include "tensorflow/core/platform/protobuf.h"
27 #include "tensorflow/core/platform/types.h"
28 
29 namespace tensorflow {
30 namespace tensor {
31 
// DeepCopy returns a new tensor whose contents are a deep copy of the
// contents of 'other'.  This function is intended only for convenience, not
// speed.
//
// Only the declaration is visible in this header.
//
// REQUIRES: 'other' must point to data stored in CPU memory.
// REQUIRES: 'other' must be a Tensor of a copy-able type if
//           'other' is not appropriately memory-aligned.
Tensor DeepCopy(const Tensor& other);
40 
// Deep copies 'input' into 'output'.  This overload is similar to the
// DeepCopy overload that returns a Tensor, but assumes that the memory for
// 'output' has already been allocated by the caller.
//
// NOTE(review): presumably 'output' must already match 'input' in dtype and
// shape -- confirm against the definition, which is not in this header.
void DeepCopy(const Tensor& input, Tensor* output);
44 
// Concatenates 'tensors' into a single tensor, along their 0th dimension, and
// stores the outcome in '*result'. The returned Status is annotated with
// TF_MUST_USE_RESULT, so callers are expected to inspect it.
//
// REQUIRES: All members of 'tensors' must have the same data type parameter.
// REQUIRES: Each member of 'tensors' must have at least one dimension.
// REQUIRES: Each member of 'tensors' must point to data stored in CPU memory.
// REQUIRES: Each member of 'tensors' must be a Tensor of a copy-able type if it
//           is not appropriately memory-aligned.
Status Concat(const gtl::ArraySlice<Tensor>& tensors,
              Tensor* result) TF_MUST_USE_RESULT;
54 
// Splits 'tensor' into 'sizes.size()' individual tensors, along the 0th
// dimension, and stores them in '*result'. The ith output tensor has
// 0th-dimension size 'sizes[i]'. The returned Status is annotated with
// TF_MUST_USE_RESULT, so callers are expected to inspect it.
//
// REQUIRES: 'tensor' must have at least one dimension.
// REQUIRES: 'tensor.dim_size(0)' must equal the sum of the elements of 'sizes'.
// REQUIRES: 'tensor' must point to data stored in CPU memory.
// REQUIRES: 'tensor' must be a Tensor of a copy-able type if it is not
//           appropriately memory-aligned.
//
// Split() and Concat() are inverse operations.
Status Split(const Tensor& tensor, const gtl::ArraySlice<int64_t>& sizes,
             std::vector<Tensor>* result) TF_MUST_USE_RESULT;
67 
68 namespace internal {
// Fills '*shape_proto' from 'shape'. CreateTensorProto() in this header calls
// it to build the TensorShapeProto for the tensor it returns.
//
// NOTE(review): presumably each entry of 'shape' becomes the size of one
// dimension, in order -- confirm against the definition, which is not in this
// header. 'shape' is taken by value, so the caller's vector is not modified.
void SetTensorProtoShape(std::vector<size_t> shape,
                         TensorShapeProto* shape_proto);
71 
// Primary template: a type with no explicit specialization has no associated
// repeated TensorProto field, so the trait evaluates to false.
template <typename T>
class TensorProtoFieldHelper : public std::false_type {
};
74 
// Generates the TensorProtoFieldHelper<TYPE> specialization that binds the C++
// type TYPE to the repeated TensorProto field named FIELDNAME##_val.
//
// The specialization derives from std::true_type and exposes:
//   FieldType                - type returned by the indexed getter
//                              FIELDNAME##_val(0).
//   RepeatedFieldType        - type returned by the const accessor
//                              FIELDNAME##_val().
//   MutableRepeatedFieldType - type returned by mutable_##FIELDNAME##_val().
//   GetField / GetMutableField - forward to those two accessors.
//
// The expansion deliberately has no trailing ';' so that each use site
// supplies its own.
#define DEFINE_PROTO_FIELD_HELPER(TYPE, FIELDNAME)                           \
  template <>                                                                \
  class TensorProtoFieldHelper<TYPE> : public std::true_type {               \
   public:                                                                   \
    using FieldType =                                                        \
        decltype(std::declval<TensorProto>().FIELDNAME##_val(0));            \
    using RepeatedFieldType =                                                \
        decltype(std::declval<TensorProto>().FIELDNAME##_val());             \
    using MutableRepeatedFieldType =                                         \
        decltype(std::declval<TensorProto>().mutable_##FIELDNAME##_val());   \
    static RepeatedFieldType& GetField(const TensorProto& proto) {           \
      return proto.FIELDNAME##_val();                                        \
    }                                                                        \
    static MutableRepeatedFieldType GetMutableField(TensorProto* proto) {    \
      return proto->mutable_##FIELDNAME##_val();                             \
    }                                                                        \
  }
92 
93 // The argument pairs in the following macro instantiations encode the
94 // mapping from C++ type ($1) to repeated field name "$2_val" used for storing
95 // values in TensorProto. See tensorflow/core/framework/tensor.proto.
96 DEFINE_PROTO_FIELD_HELPER(float, float);
97 DEFINE_PROTO_FIELD_HELPER(double, double);
98 DEFINE_PROTO_FIELD_HELPER(int8, int);
99 DEFINE_PROTO_FIELD_HELPER(uint8, int);
100 DEFINE_PROTO_FIELD_HELPER(int16, int);
101 DEFINE_PROTO_FIELD_HELPER(uint16, int);
102 DEFINE_PROTO_FIELD_HELPER(int32, int);
103 DEFINE_PROTO_FIELD_HELPER(uint32, uint32);
104 DEFINE_PROTO_FIELD_HELPER(int64_t, int64);
105 DEFINE_PROTO_FIELD_HELPER(uint64, uint64);
106 DEFINE_PROTO_FIELD_HELPER(bool, bool);
107 DEFINE_PROTO_FIELD_HELPER(qint8, int);
108 DEFINE_PROTO_FIELD_HELPER(quint8, int);
109 DEFINE_PROTO_FIELD_HELPER(qint16, int);
110 DEFINE_PROTO_FIELD_HELPER(quint16, int);
111 DEFINE_PROTO_FIELD_HELPER(qint32, int);
112 DEFINE_PROTO_FIELD_HELPER(Eigen::half, half);
113 DEFINE_PROTO_FIELD_HELPER(bfloat16, half);
114 DEFINE_PROTO_FIELD_HELPER(complex64, scomplex);
115 DEFINE_PROTO_FIELD_HELPER(complex128, dcomplex);
116 
117 #undef DEFINE_PROTO_HELPER
118 
// Generic element copier used to move values between a repeated proto field
// and an array of T. The template parameter T only selects which CopyHelper
// (this primary template or one of its specializations) is used; the element
// types actually copied are taken from the iterator arguments.
template <typename T>
struct CopyHelper {
  // Copies [begin, end) to 'dst', converting every element with static_cast
  // to the value type of the destination iterator. Chosen when the source and
  // destination iterator types differ.
  template <typename SrcIter, typename DstIter>
  static void ToArray(SrcIter begin, SrcIter end, DstIter dst) {
    using SrcType = typename std::iterator_traits<SrcIter>::value_type;
    using DstType = typename std::iterator_traits<DstIter>::value_type;
    std::transform(begin, end, dst, [](const SrcType& x) -> DstType {
      return static_cast<DstType>(x);
    });
  }
  // Overload for identical source and destination iterator types: no
  // conversion is needed, so the range is copied with std::copy.
  template <typename SrcIter>
  static void ToArray(SrcIter begin, SrcIter end, SrcIter dst) {
    std::copy(begin, end, dst);
  }
  // The reverse direction needs the same element-wise conversion, so it simply
  // forwards to ToArray.
  template <typename SrcIter, typename DstIter>
  static void FromArray(SrcIter begin, SrcIter end, DstIter dst) {
    ToArray(begin, end, dst);
  }
};
138 
139 // Overloads for Eigen::half and bfloat16 that are 16 bits in size but are
140 // stored in an int32 field.
141 template <>
142 struct CopyHelper<Eigen::half> {
143   template <typename SrcIter>
144   static void ToArray(SrcIter begin, SrcIter end, Eigen::half* dst) {
145     std::transform(begin, end, dst, [](int x) -> Eigen::half {
146       return Eigen::numext::bit_cast<Eigen::half>(static_cast<uint16>(x));
147     });
148   }
149   template <typename SrcIter, typename DstIter>
150   static void FromArray(SrcIter begin, SrcIter end, DstIter dst) {
151     std::transform(begin, end, dst, [](Eigen::half h) -> int {
152       return static_cast<int>(Eigen::numext::bit_cast<uint16>(h));
153     });
154   }
155 };
156 
157 template <>
158 struct CopyHelper<bfloat16> {
159   template <typename SrcIter>
160   static void ToArray(SrcIter begin, SrcIter end, bfloat16* dst) {
161     std::transform(begin, end, dst, [](int x) -> bfloat16 {
162       return Eigen::numext::bit_cast<bfloat16>(static_cast<uint16>(x));
163     });
164   }
165   template <typename SrcIter, typename DstIter>
166   static void FromArray(SrcIter begin, SrcIter end, DstIter dst) {
167     std::transform(begin, end, dst, [](bfloat16 bf16) -> int {
168       return static_cast<int>(Eigen::numext::bit_cast<uint16>(bf16));
169     });
170   }
171 };
172 
173 // Overloads for complex types that store real and imaginary parts
174 // at indices 2*i and 2*i+1 in float or double field.
175 template <typename RealType>
176 struct CopyHelper<std::complex<RealType>> {
177   template <typename SrcIter>
178   static void ToArray(SrcIter begin, SrcIter end, std::complex<RealType>* dst) {
179     RealType* real_dst = reinterpret_cast<RealType*>(dst);
180     std::copy(begin, end, real_dst);
181   }
182 
183   template <typename SrcIter, typename DstIter>
184   static void FromArray(SrcIter begin, SrcIter end, DstIter dst) {
185     size_t n = std::distance(begin, end);
186     const RealType* real_begin = reinterpret_cast<const RealType*>(&(*begin));
187     std::copy_n(real_begin, 2 * n, dst);
188   }
189 };
190 
191 // Helper class to extract and insert values into TensorProto represented as
192 // repeated fields.
193 template <typename T>
194 class TensorProtoHelper : public std::true_type {
195  public:
196   using FieldHelper = TensorProtoFieldHelper<T>;
197   using FieldType = typename TensorProtoFieldHelper<T>::FieldType;
198 
199   static DataType GetDataType() { return DataTypeToEnum<T>::value; }
200 
201   // Returns the number of values of type T encoded in the proto.
202   static size_t NumValues(const TensorProto& proto) {
203     size_t raw_size = FieldHelper::GetField(proto).size();
204     return is_complex<T>::value ? raw_size / 2 : raw_size;
205   }
206 
207   static void AddValue(const T& value, TensorProto* proto) {
208     const T* val_ptr = &value;
209     AddValues(val_ptr, val_ptr + 1, proto);
210   }
211 
212   static T GetValue(size_t index, const TensorProto& proto) {
213     const size_t stride = is_complex<T>::value ? 2 : 1;
214     T val;
215     CopyHelper<T>::ToArray(
216         FieldHelper::GetField(proto).begin() + stride * index,
217         FieldHelper::GetField(proto).begin() + stride * (index + 1), &val);
218     return val;
219   }
220 
221   template <typename IterType>
222   static void AddValues(IterType begin, IterType end, TensorProto* proto) {
223     size_t n = std::distance(begin, end);
224     FieldType* dst = AppendUninitialized(n, proto);
225     CopyHelper<T>::FromArray(begin, end, dst);
226   }
227 
228   template <typename IterType>
229   static void CopyValues(IterType dst, const TensorProto& proto) {
230     CopyHelper<T>::ToArray(FieldHelper::GetField(proto).begin(),
231                            FieldHelper::GetField(proto).end(), dst);
232   }
233 
234   static void Truncate(size_t new_size, TensorProto* proto) {
235     if (is_complex<T>::value) new_size *= 2;
236     FieldHelper::GetMutableField(proto)->Truncate(new_size);
237   }
238 
239   static FieldType* AppendUninitialized(size_t n, TensorProto* proto) {
240     if (is_complex<T>::value) n *= 2;
241     auto* field = FieldHelper::GetMutableField(proto);
242     field->Reserve(field->size() + n);
243     return reinterpret_cast<FieldType*>(field->AddNAlreadyReserved(n));
244   }
245 };
246 
247 // Specialization for string.
248 template <>
249 class TensorProtoHelper<string> : public std::true_type {
250  public:
251   static DataType GetDataType() { return DataType::DT_STRING; }
252   static void AddValue(const string& value, TensorProto* proto) {
253     *proto->mutable_string_val()->Add() = value;
254   }
255   template <typename IterType>
256   static void AddValues(IterType begin, IterType end, TensorProto* proto) {
257     for (IterType it = begin; it != end; ++it) {
258       AddValue(*it, proto);
259     }
260   }
261   template <typename IterType>
262   static void CopyToTensorContent(IterType begin, IterType end,
263                                   TensorProto* proto) {
264     AddValues(begin, end, proto);
265   }
266 };
267 
268 }  // namespace internal
269 
270 // Creates a 'TensorProto' with specified shape and values.
271 // The dtype and a field to represent data values of the returned 'TensorProto'
272 // are determined based on type of the 'values' parameter.
273 template <typename Type>
274 typename std::enable_if<internal::TensorProtoHelper<Type>::value,
275                         TensorProto>::type
276 CreateTensorProto(const std::vector<Type>& values,
277                   const std::vector<size_t>& shape) {
278   TensorProto tensor;
279   TensorShapeProto tensor_shape_proto;
280   internal::SetTensorProtoShape(shape, &tensor_shape_proto);
281   if (TensorShape(tensor_shape_proto).num_elements() != values.size()) {
282     LOG(ERROR) << "Shape and number of values (" << values.size()
283                << ") are incompatible.";
284     return tensor;
285   }
286   using TypeHelper = internal::TensorProtoHelper<Type>;
287   tensor.set_dtype(TypeHelper::GetDataType());
288   tensor.mutable_tensor_shape()->Swap(&tensor_shape_proto);
289   TypeHelper::AddValues(values.begin(), values.end(), &tensor);
290   return tensor;
291 }
292 
// Converts values in tensor to run-length encoded compressed form.
//
// The elements of a tensor can be stored in a TensorProto in one of the
// following two forms:
// 1. As a raw byte string in the field `tensor_content` containing the
//    serialized in-memory representation of the tensor.
// 2. As values of a repeated field depending on the datatype, e.g. the
//    values of a DT_FLOAT tensor would be stored in the repeated field
//    `float_val`.
// Storage scheme 2 may use a simple form of run-length encoding to compress
// data: if the values contain a tail of identical values, the repeated field
// will be truncated such that the number of values in the repeated field is
// less than the number of elements implied by the field `tensor_shape`. The
// original tensor can be recovered by repeating the final value in the repeated
// field.
//
// The TensorProto will be compressed if a) the tensor contains at least
// min_num_elements elements and b) the compressed tensor proto would be at
// most the size of the original tensor proto divided by min_compression_ratio.
//
// Returns true if the tensor was compressed.
bool CompressTensorProtoInPlace(int64_t min_num_elements,
                                float min_compression_ratio,
                                TensorProto* tensor);
317 
318 inline bool CompressTensorProtoInPlace(TensorProto* tensor) {
319   static const int64_t kDefaultMinNumElements = 64;
320   static const float kDefaultMinCompressionRatio = 2.0f;
321   return CompressTensorProtoInPlace(kDefaultMinNumElements,
322                                     kDefaultMinCompressionRatio, tensor);
323 }
324 
// Makes a TensorShape from the contents of 'shape_t' and stores it in '*out'.
// 'shape_t' must be a 1-dimensional tensor of type int32 or int64.
//
// NOTE(review): presumably a non-OK Status is returned when 'shape_t' does not
// meet these requirements -- confirm against the definition, which is not in
// this header.
Status MakeShape(const Tensor& shape_t, TensorShape* out);
328 
329 }  // namespace tensor
330 }  // namespace tensorflow
331 
332 #endif  // TENSORFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_
333