• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_
18 
19 #include <algorithm>
20 #include <vector>
21 
22 #include "tensorflow/core/framework/tensor.h"
23 #include "tensorflow/core/framework/tensor.pb.h"
24 #include "tensorflow/core/framework/tensor_shape.pb.h"
25 #include "tensorflow/core/framework/type_traits.h"
26 #include "tensorflow/core/platform/protobuf.h"
27 #include "tensorflow/core/platform/types.h"
28 
29 namespace tensorflow {
30 namespace tensor {
31 
// DeepCopy returns a tensor whose contents are a deep copy of the
// contents of 'other'.  This function is intended only for
// convenience, not speed.
//
// REQUIRES: 'other' must point to data stored in CPU memory.
// REQUIRES: 'other' must be a Tensor of a copy-able type if
//           'other' is not appropriately memory-aligned.
Tensor DeepCopy(const Tensor& other);

// Deep copies 'input' into '*output'.  This function is similar to the
// overload above, but assumes that the memory for '*output' has already been
// allocated.
// NOTE(review): presumably '*output' must already match 'input' in dtype and
// number of elements -- confirm against the implementation.
void DeepCopy(const Tensor& input, Tensor* output);

// Concatenates 'tensors' into a single tensor, along their 0th dimension, and
// stores it in '*result'.  The returned Status must be checked
// (TF_MUST_USE_RESULT).
//
// REQUIRES: All members of 'tensors' must have the same data type parameter.
// REQUIRES: Each member of 'tensors' must have at least one dimension.
// REQUIRES: Each member of 'tensors' must point to data stored in CPU memory.
// REQUIRES: Each member of 'tensors' must be a Tensor of a copy-able type if it
//           is not appropriately memory-aligned.
Status Concat(const gtl::ArraySlice<Tensor>& tensors,
              Tensor* result) TF_MUST_USE_RESULT;

// Splits 'tensor' into 'sizes.size()' individual tensors, along the 0th
// dimension, and stores them in '*result'.  The ith output tensor has
// 0th-dimension size 'sizes[i]'.  The returned Status must be checked
// (TF_MUST_USE_RESULT).
//
// REQUIRES: 'tensor' must have at least one dimension.
// REQUIRES: 'tensor.dim_size(0)' must equal the sum of the elements of 'sizes'.
// REQUIRES: 'tensor' must point to data stored in CPU memory.
// REQUIRES: 'tensor' must be a Tensor of a copy-able type if it is not
//           appropriately memory-aligned.
//
// Split() and Concat() are inverse operations.
Status Split(const Tensor& tensor, const gtl::ArraySlice<int64>& sizes,
             std::vector<Tensor>* result) TF_MUST_USE_RESULT;
67 
68 namespace internal {
// Fills 'shape_proto' from the dimension sizes listed in 'shape'.
// NOTE(review): presumably one dim per entry of 'shape', in order -- confirm
// against the implementation in tensor_util.cc.
void SetTensorProtoShape(std::vector<size_t> shape,
                         TensorShapeProto* shape_proto);

// Maps a C++ element type to the TensorProto repeated field that stores its
// values.  The primary template derives from std::false_type; each supported
// type gets a std::true_type specialization generated by
// DEFINE_PROTO_FIELD_HELPER.
template <typename Type>
class TensorProtoFieldHelper : public std::false_type {};
74 
75 #define DEFINE_PROTO_FIELD_HELPER(TYPE, FIELDNAME)                            \
76   template <>                                                                 \
77   class TensorProtoFieldHelper<TYPE> : public std::true_type {                \
78    public:                                                                    \
79     typedef decltype(                                                         \
80         std::declval<TensorProto>().FIELDNAME##_val(0)) FieldType;            \
81     typedef decltype(                                                         \
82         std::declval<TensorProto>().FIELDNAME##_val()) RepeatedFieldType;     \
83     typedef decltype(std::declval<TensorProto>().mutable_##FIELDNAME##_val()) \
84         MutableRepeatedFieldType;                                             \
85     static MutableRepeatedFieldType GetMutableField(TensorProto* proto) {     \
86       return proto->mutable_##FIELDNAME##_val();                              \
87     }                                                                         \
88     static RepeatedFieldType& GetField(const TensorProto& proto) {            \
89       return proto.FIELDNAME##_val();                                         \
90     }                                                                         \
91   }
92 
93 // The argument pairs in the following macro instantiations encode the
94 // mapping from C++ type ($1) to repeated field name "$2_val" used for storing
95 // values in TensorProto. See tensorflow/core/framework/tensor.proto.
96 DEFINE_PROTO_FIELD_HELPER(float, float);
97 DEFINE_PROTO_FIELD_HELPER(double, double);
98 DEFINE_PROTO_FIELD_HELPER(int8, int);
99 DEFINE_PROTO_FIELD_HELPER(uint8, int);
100 DEFINE_PROTO_FIELD_HELPER(int16, int);
101 DEFINE_PROTO_FIELD_HELPER(uint16, int);
102 DEFINE_PROTO_FIELD_HELPER(int32, int);
103 DEFINE_PROTO_FIELD_HELPER(uint32, uint32);
104 DEFINE_PROTO_FIELD_HELPER(int64, int64);
105 DEFINE_PROTO_FIELD_HELPER(uint64, uint64);
106 DEFINE_PROTO_FIELD_HELPER(bool, bool);
107 DEFINE_PROTO_FIELD_HELPER(qint8, int);
108 DEFINE_PROTO_FIELD_HELPER(quint8, int);
109 DEFINE_PROTO_FIELD_HELPER(qint16, int);
110 DEFINE_PROTO_FIELD_HELPER(quint16, int);
111 DEFINE_PROTO_FIELD_HELPER(qint32, int);
112 DEFINE_PROTO_FIELD_HELPER(Eigen::half, half);
113 DEFINE_PROTO_FIELD_HELPER(bfloat16, half);
114 DEFINE_PROTO_FIELD_HELPER(complex64, scomplex);
115 DEFINE_PROTO_FIELD_HELPER(complex128, dcomplex);
116 
117 #undef DEFINE_PROTO_HELPER
118 
// Default element conversion used by TensorProtoHelper to move values between
// a TensorProto repeated field and an array of T.
template <typename T>
struct CopyHelper {
  // Writes static_cast<OutValue>(x) to successive positions of 'out' for each
  // element x of [first, last).
  template <typename InIt, typename OutIt>
  static void ToArray(InIt first, InIt last, OutIt out) {
    typedef typename std::iterator_traits<InIt>::value_type InValue;
    typedef typename std::iterator_traits<OutIt>::value_type OutValue;
    for (; first != last; ++first, ++out) {
      const InValue& element = *first;
      *out = static_cast<OutValue>(element);
    }
  }

  // Overload picked when source and destination iterators have the same type:
  // a plain element-by-element copy, no conversion needed.
  template <typename It>
  static void ToArray(It first, It last, It out) {
    std::copy(first, last, out);
  }

  // The reverse direction needs no special handling for plain types, so it
  // reuses ToArray.
  template <typename InIt, typename OutIt>
  static void FromArray(InIt first, InIt last, OutIt out) {
    CopyHelper::ToArray(first, last, out);
  }
};
138 
139 // Overloads for Eigen::half and bfloat16 that are 16 bits in size but are
140 // stored in an int32 field.
141 template <>
142 struct CopyHelper<Eigen::half> {
143   template <typename SrcIter>
144   static void ToArray(SrcIter begin, SrcIter end, Eigen::half* dst) {
145     std::transform(begin, end, dst, [](int x) -> Eigen::half {
146       Eigen::half h;
147       h.x = static_cast<uint16>(x);
148       return h;
149     });
150   }
151   template <typename SrcIter, typename DstIter>
152   static void FromArray(SrcIter begin, SrcIter end, DstIter dst) {
153     std::transform(begin, end, dst,
154                    [](Eigen::half h) -> int { return static_cast<int>(h.x); });
155   }
156 };
157 
158 template <>
159 struct CopyHelper<bfloat16> {
160   template <typename SrcIter>
161   static void ToArray(SrcIter begin, SrcIter end, bfloat16* dst) {
162     std::transform(begin, end, dst, [](int x) -> bfloat16 {
163       bfloat16 bf16;
164       bf16.value = static_cast<uint16>(x);
165       return bf16;
166     });
167   }
168   template <typename SrcIter, typename DstIter>
169   static void FromArray(SrcIter begin, SrcIter end, DstIter dst) {
170     std::transform(begin, end, dst, [](bfloat16 bf16) -> int {
171       return static_cast<int>(bf16.value);
172     });
173   }
174 };
175 
176 // Overloads for complex types that store real and imaginary parts
177 // at indices 2*i and 2*i+1 in float or double field.
178 template <typename RealType>
179 struct CopyHelper<std::complex<RealType>> {
180   template <typename SrcIter>
181   static void ToArray(SrcIter begin, SrcIter end, std::complex<RealType>* dst) {
182     RealType* real_dst = reinterpret_cast<RealType*>(dst);
183     std::copy(begin, end, real_dst);
184   }
185 
186   template <typename SrcIter, typename DstIter>
187   static void FromArray(SrcIter begin, SrcIter end, DstIter dst) {
188     size_t n = std::distance(begin, end);
189     const RealType* real_begin = reinterpret_cast<const RealType*>(&(*begin));
190     std::copy_n(real_begin, 2 * n, dst);
191   }
192 };
193 
194 // Helper class to extract and insert values into TensorProto represented as
195 // repeated fields.
196 template <typename T>
197 class TensorProtoHelper : public std::true_type {
198  public:
199   using FieldHelper = TensorProtoFieldHelper<T>;
200   using FieldType = typename TensorProtoFieldHelper<T>::FieldType;
201 
202   static DataType GetDataType() { return DataTypeToEnum<T>::value; }
203 
204   // Returns the number of values of type T encoded in the proto.
205   static size_t NumValues(const TensorProto& proto) {
206     size_t raw_size = FieldHelper::GetField(proto).size();
207     return is_complex<T>::value ? raw_size / 2 : raw_size;
208   }
209 
210   static void AddValue(const T& value, TensorProto* proto) {
211     const T* val_ptr = &value;
212     AddValues(val_ptr, val_ptr + 1, proto);
213   }
214 
215   static T GetValue(size_t index, const TensorProto& proto) {
216     const size_t stride = is_complex<T>::value ? 2 : 1;
217     T val;
218     CopyHelper<T>::ToArray(
219         FieldHelper::GetField(proto).begin() + stride * index,
220         FieldHelper::GetField(proto).begin() + stride * (index + 1), &val);
221     return val;
222   }
223 
224   template <typename IterType>
225   static void AddValues(IterType begin, IterType end, TensorProto* proto) {
226     size_t n = std::distance(begin, end);
227     FieldType* dst = AppendUninitialized(n, proto);
228     CopyHelper<T>::FromArray(begin, end, dst);
229   }
230 
231   template <typename IterType>
232   static void CopyValues(IterType dst, const TensorProto& proto) {
233     CopyHelper<T>::ToArray(FieldHelper::GetField(proto).begin(),
234                            FieldHelper::GetField(proto).end(), dst);
235   }
236 
237   static void Truncate(size_t new_size, TensorProto* proto) {
238     if (is_complex<T>::value) new_size *= 2;
239     FieldHelper::GetMutableField(proto)->Truncate(new_size);
240   }
241 
242   static FieldType* AppendUninitialized(size_t n, TensorProto* proto) {
243     if (is_complex<T>::value) n *= 2;
244     auto* field = FieldHelper::GetMutableField(proto);
245     field->Reserve(field->size() + n);
246     return reinterpret_cast<FieldType*>(field->AddNAlreadyReserved(n));
247   }
248 };
249 
250 // Specialization for string.
251 template <>
252 class TensorProtoHelper<string> : public std::true_type {
253  public:
254   static DataType GetDataType() { return DataType::DT_STRING; }
255   static void AddValue(const string& value, TensorProto* proto) {
256     *proto->mutable_string_val()->Add() = value;
257   }
258   template <typename IterType>
259   static void AddValues(IterType begin, IterType end, TensorProto* proto) {
260     for (IterType it = begin; it != end; ++it) {
261       AddValue(*it, proto);
262     }
263   }
264   template <typename IterType>
265   static void CopyToTensorContent(IterType begin, IterType end,
266                                   TensorProto* proto) {
267     AddValues(begin, end, proto);
268   }
269 };
270 
271 }  // namespace internal
272 
273 // Creates a 'TensorProto' with specified shape and values.
274 // The dtype and a field to represent data values of the returned 'TensorProto'
275 // are determined based on type of the 'values' parameter.
276 template <typename Type>
277 typename std::enable_if<internal::TensorProtoHelper<Type>::value,
278                         TensorProto>::type
279 CreateTensorProto(const std::vector<Type>& values,
280                   const std::vector<size_t>& shape) {
281   TensorProto tensor;
282   TensorShapeProto tensor_shape_proto;
283   internal::SetTensorProtoShape(shape, &tensor_shape_proto);
284   if (TensorShape(tensor_shape_proto).num_elements() != values.size()) {
285     LOG(ERROR) << "Shape and number of values (" << values.size()
286                << ") are incompatible.";
287     return tensor;
288   }
289   using TypeHelper = internal::TensorProtoHelper<Type>;
290   tensor.set_dtype(TypeHelper::GetDataType());
291   tensor.mutable_tensor_shape()->Swap(&tensor_shape_proto);
292   TypeHelper::AddValues(values.begin(), values.end(), &tensor);
293   return tensor;
294 }
295 
296 // Converts values in tensor to run-length encoded compressed form.
297 //
298 // The elements of a tensor can be stored in a TensorProto in one of the
299 // following two forms:
300 // 1. As a raw byte string in the field `tensor_content` containing the
301 //    serialized in-memory representation of the tensor.
302 // 2. As values of a repeated field depending on the datatype, e.g. that
303 //    values of a DT_FLOAT tensor would be stored in the repeated field
304 //    `float_val`.
305 // Storage scheme 2 may use a simple form of run-length encoding to compress
306 // data: If the values contains a tail of identical values, the repeated field
307 // will be truncated such that the number of values in the repeated field is
308 // less than the number of elements implied by the field`tensor_shape`. The
309 // original tensor can be recovered by repeating the final value in the repeated
310 // field.
311 //
312 // The TensorProto will be compressed if a) the tensor contains at least
313 // min_num_elements elements and b) the compressed tensor proto is would be at
314 // most the size of the original tensor proto divided by min_compression_ratio.
315 //
316 // Returns true if the tensor was compressed.
317 bool CompressTensorProtoInPlace(int64 min_num_elements,
318                                 float min_compression_ratio,
319                                 TensorProto* tensor);
320 
321 inline bool CompressTensorProtoInPlace(TensorProto* tensor) {
322   static const int64 kDefaultMinNumElements = 64;
323   static const float kDefaultMinCompressionRatio = 2.0f;
324   return CompressTensorProtoInPlace(kDefaultMinNumElements,
325                                     kDefaultMinCompressionRatio, tensor);
326 }
327 
// Makes a TensorShape from the contents of 'shape_t' and writes it to '*out'.
// 'shape_t' must be a 1-dimensional tensor of type int32 or int64.
// NOTE(review): presumably returns a non-OK Status when 'shape_t' violates
// these requirements -- confirm against the implementation.
Status MakeShape(const Tensor& shape_t, TensorShape* out);
331 
332 }  // namespace tensor
333 }  // namespace tensorflow
334 
335 #endif  // TENSORFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_
336