• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_
18 
19 #include <algorithm>
20 #include <vector>
21 
22 #include "tensorflow/core/framework/tensor.h"
23 #include "tensorflow/core/framework/tensor.pb.h"
24 #include "tensorflow/core/framework/tensor_shape.pb.h"
25 #include "tensorflow/core/framework/type_traits.h"
26 #include "tensorflow/core/platform/protobuf.h"
27 #include "tensorflow/core/platform/types.h"
28 
29 namespace tensorflow {
30 namespace tensor {
31 
32 // DeepCopy returns a tensor whose contents are a deep copy of the
33 // contents of 'other'.  This function is intended only for
34 // convenience, not speed.
35 //
36 // REQUIRES: 'other' must point to data stored in CPU memory.
37 // REQUIRES: 'other' must be a Tensor of a copy-able type if
38 //           'other' is not appropriately memory-aligned.
39 Tensor DeepCopy(const Tensor& other);
40 
41 // Concatenates 'tensors' into a single tensor, along their 0th dimension.
42 //
43 // REQUIRES: All members of 'tensors' must have the same data type parameter.
44 // REQUIRES: Each member of 'tensors' must have at least one dimension.
45 // REQUIRES: Each member of 'tensors' must point to data stored in CPU memory.
46 // REQUIRES: Each member of 'tensors' must be a Tensor of a copy-able type if it
47 //           is not appropriately memory-aligned.
48 Status Concat(const gtl::ArraySlice<Tensor>& tensors,
49               Tensor* result) TF_MUST_USE_RESULT;
50 
51 // Splits 'tensor' into 'sizes.size()' individual tensors, along the 0th
52 // dimension. The ith output tensor has 0th-dimension size 'sizes[i]'.
53 //
54 // REQUIRES: 'tensor' must have at least one dimension.
55 // REQUIRES: 'tensor.dim_size(0)' must equal the sum of the elements of 'sizes'.
56 // REQUIRES: 'tensor' must point to data stored in CPU memory.
57 // REQUIRES: 'tensor' must be a Tensor of a copy-able type if it is not
58 //           appropriately memory-aligned.
59 //
60 // Split() and Concat() are inverse operations.
61 Status Split(const Tensor& tensor, const gtl::ArraySlice<int64>& sizes,
62              std::vector<Tensor>* result) TF_MUST_USE_RESULT;
63 
64 namespace internal {
65 void SetTensorProtoShape(std::vector<size_t> shape,
66                          TensorShapeProto* shape_proto);
67 
68 template <typename Type>
69 class TensorProtoFieldHelper : public std::false_type {};
70 
71 #define DEFINE_PROTO_FIELD_HELPER(TYPE, FIELDNAME)                            \
72   template <>                                                                 \
73   class TensorProtoFieldHelper<TYPE> : public std::true_type {                \
74    public:                                                                    \
75     typedef decltype(                                                         \
76         std::declval<TensorProto>().FIELDNAME##_val(0)) FieldType;            \
77     typedef decltype(                                                         \
78         std::declval<TensorProto>().FIELDNAME##_val()) RepeatedFieldType;     \
79     typedef decltype(std::declval<TensorProto>().mutable_##FIELDNAME##_val()) \
80         MutableRepeatedFieldType;                                             \
81     static MutableRepeatedFieldType GetMutableField(TensorProto* proto) {     \
82       return proto->mutable_##FIELDNAME##_val();                              \
83     }                                                                         \
84     static RepeatedFieldType& GetField(const TensorProto& proto) {            \
85       return proto.FIELDNAME##_val();                                         \
86     }                                                                         \
87   }
88 
89 // The argument pairs in the following macro instantiations encode the
90 // mapping from C++ type ($1) to repeated field name "$2_val" used for storing
91 // values in TensorProto. See tensorflow/core/framework/tensor.proto.
92 DEFINE_PROTO_FIELD_HELPER(float, float);
93 DEFINE_PROTO_FIELD_HELPER(double, double);
94 DEFINE_PROTO_FIELD_HELPER(int8, int);
95 DEFINE_PROTO_FIELD_HELPER(uint8, int);
96 DEFINE_PROTO_FIELD_HELPER(int16, int);
97 DEFINE_PROTO_FIELD_HELPER(uint16, int);
98 DEFINE_PROTO_FIELD_HELPER(int32, int);
99 DEFINE_PROTO_FIELD_HELPER(uint32, uint32);
100 DEFINE_PROTO_FIELD_HELPER(int64, int64);
101 DEFINE_PROTO_FIELD_HELPER(uint64, uint64);
102 DEFINE_PROTO_FIELD_HELPER(bool, bool);
103 DEFINE_PROTO_FIELD_HELPER(qint8, int);
104 DEFINE_PROTO_FIELD_HELPER(quint8, int);
105 DEFINE_PROTO_FIELD_HELPER(qint16, int);
106 DEFINE_PROTO_FIELD_HELPER(quint16, int);
107 DEFINE_PROTO_FIELD_HELPER(qint32, int);
108 DEFINE_PROTO_FIELD_HELPER(Eigen::half, half);
109 DEFINE_PROTO_FIELD_HELPER(bfloat16, half);
110 DEFINE_PROTO_FIELD_HELPER(complex64, scomplex);
111 DEFINE_PROTO_FIELD_HELPER(complex128, dcomplex);
112 
113 #undef DEFINE_PROTO_HELPER
114 
// Default element copier between a TensorProto repeated field and a typed
// array. Each element is converted with static_cast to the destination
// iterator's value_type. Types whose proto encoding differs from their
// in-memory form (Eigen::half, bfloat16, std::complex) are specialized below.
template <typename T>
struct CopyHelper {
  // Converts [begin, end) element by element into the range starting at 'dst'.
  template <typename SrcIter, typename DstIter>
  static void ToArray(SrcIter begin, SrcIter end, DstIter dst) {
    using SrcValue = typename std::iterator_traits<SrcIter>::value_type;
    using DstValue = typename std::iterator_traits<DstIter>::value_type;
    for (; begin != end; ++begin, ++dst) {
      const SrcValue& value = *begin;
      *dst = static_cast<DstValue>(value);
    }
  }

  // Picked by overload resolution when source and destination iterators have
  // the same type: a plain element-wise copy with no conversion.
  template <typename SrcIter>
  static void ToArray(SrcIter begin, SrcIter end, SrcIter dst) {
    while (begin != end) {
      *dst++ = *begin++;
    }
  }

  // Copies a typed array into a repeated field. In the default case this is
  // the same conversion that ToArray performs.
  template <typename SrcIter, typename DstIter>
  static void FromArray(SrcIter begin, SrcIter end, DstIter dst) {
    CopyHelper::ToArray(begin, end, dst);
  }
};
134 
135 // Overloads for Eigen::half and bfloat16 that are 16 bits in size but are
136 // stored in an int32 field.
137 template <>
138 struct CopyHelper<Eigen::half> {
139   template <typename SrcIter>
140   static void ToArray(SrcIter begin, SrcIter end, Eigen::half* dst) {
141     std::transform(begin, end, dst, [](int x) -> Eigen::half {
142       Eigen::half h;
143       h.x = static_cast<uint16>(x);
144       return h;
145     });
146   }
147   template <typename SrcIter, typename DstIter>
148   static void FromArray(SrcIter begin, SrcIter end, DstIter dst) {
149     std::transform(begin, end, dst,
150                    [](Eigen::half h) -> int { return static_cast<int>(h.x); });
151   }
152 };
153 
154 template <>
155 struct CopyHelper<bfloat16> {
156   template <typename SrcIter>
157   static void ToArray(SrcIter begin, SrcIter end, bfloat16* dst) {
158     std::transform(begin, end, dst, [](int x) -> bfloat16 {
159       bfloat16 bf16;
160       bf16.value = static_cast<uint16>(x);
161       return bf16;
162     });
163   }
164   template <typename SrcIter, typename DstIter>
165   static void FromArray(SrcIter begin, SrcIter end, DstIter dst) {
166     std::transform(begin, end, dst, [](bfloat16 bf16) -> int {
167       return static_cast<int>(bf16.value);
168     });
169   }
170 };
171 
172 // Overloads for complex types that store real and imaginary parts
173 // at indices 2*i and 2*i+1 in float or double field.
174 template <typename RealType>
175 struct CopyHelper<std::complex<RealType>> {
176   template <typename SrcIter>
177   static void ToArray(SrcIter begin, SrcIter end, std::complex<RealType>* dst) {
178     using SrcType = typename std::iterator_traits<SrcIter>::value_type;
179     RealType* real_dst = reinterpret_cast<RealType*>(dst);
180     std::copy(begin, end, real_dst);
181   }
182 
183   template <typename SrcIter, typename DstIter>
184   static void FromArray(SrcIter begin, SrcIter end, DstIter dst) {
185     using DstType = typename std::iterator_traits<DstIter>::value_type;
186     size_t n = std::distance(begin, end);
187     const RealType* real_begin = reinterpret_cast<const RealType*>(&(*begin));
188     std::copy_n(real_begin, 2 * n, dst);
189   }
190 };
191 
192 // Helper class to extract and insert values into TensorProto represented as
193 // repeated fields.
194 template <typename T>
195 class TensorProtoHelper : public std::true_type {
196  public:
197   using FieldHelper = TensorProtoFieldHelper<T>;
198   using FieldType = typename TensorProtoFieldHelper<T>::FieldType;
199 
200   static DataType GetDataType() { return DataTypeToEnum<T>::value; }
201 
202   // Returns the number of values of type T encoded in the proto.
203   static size_t NumValues(const TensorProto& proto) {
204     size_t raw_size = FieldHelper::GetField(proto).size();
205     return is_complex<T>::value ? raw_size / 2 : raw_size;
206   }
207 
208   static void AddValue(const T& value, TensorProto* proto) {
209     const T* val_ptr = &value;
210     AddValues(val_ptr, val_ptr + 1, proto);
211   }
212 
213   static T GetValue(size_t index, const TensorProto& proto) {
214     T val;
215     if (is_complex<T>::value) index *= 2;
216     CopyHelper<T>::ToArray(FieldHelper::GetField(proto).begin() + index,
217                            FieldHelper::GetField(proto).begin() + index + 1,
218                            &val);
219     return val;
220   }
221 
222   template <typename IterType>
223   static void AddValues(IterType begin, IterType end, TensorProto* proto) {
224     size_t n = std::distance(begin, end);
225     FieldType* dst = AppendUninitialized(n, proto);
226     CopyHelper<T>::FromArray(begin, end, dst);
227   }
228 
229   template <typename IterType>
230   static void CopyValues(IterType dst, const TensorProto& proto) {
231     CopyHelper<T>::ToArray(FieldHelper::GetField(proto).begin(),
232                            FieldHelper::GetField(proto).end(), dst);
233   }
234 
235   static void Truncate(size_t new_size, TensorProto* proto) {
236     if (is_complex<T>::value) new_size *= 2;
237     FieldHelper::GetMutableField(proto)->Truncate(new_size);
238   }
239 
240   static FieldType* AppendUninitialized(size_t n, TensorProto* proto) {
241     if (is_complex<T>::value) n *= 2;
242     auto* field = FieldHelper::GetMutableField(proto);
243     field->Reserve(field->size() + n);
244     return reinterpret_cast<FieldType*>(field->AddNAlreadyReserved(n));
245   }
246 };
247 
248 // Specialization for string.
249 template <>
250 class TensorProtoHelper<string> : public std::true_type {
251  public:
252   static DataType GetDataType() { return DataType::DT_STRING; }
253   static void AddValue(const string& value, TensorProto* proto) {
254     *proto->mutable_string_val()->Add() = value;
255   }
256   template <typename IterType>
257   static void AddValues(IterType begin, IterType end, TensorProto* proto) {
258     for (IterType it = begin; it != end; ++it) {
259       AddValue(*it, proto);
260     }
261   }
262   template <typename IterType>
263   static void CopyToTensorContent(IterType begin, IterType end,
264                                   TensorProto* proto) {
265     AddValues(begin, end, proto);
266   }
267 };
268 
269 }  // namespace internal
270 
271 // Creates a 'TensorProto' with specified shape and values.
272 // The dtype and a field to represent data values of the returned 'TensorProto'
273 // are determined based on type of the 'values' parameter.
274 template <typename Type>
275 typename std::enable_if<internal::TensorProtoHelper<Type>::value,
276                         TensorProto>::type
277 CreateTensorProto(const std::vector<Type>& values,
278                   const std::vector<size_t>& shape) {
279   TensorProto tensor;
280   TensorShapeProto tensor_shape_proto;
281   internal::SetTensorProtoShape(shape, &tensor_shape_proto);
282   if (TensorShape(tensor_shape_proto).num_elements() != values.size()) {
283     LOG(ERROR) << "Shape and number of values (" << values.size()
284                << ") are incompatible.";
285     return tensor;
286   }
287   using TypeHelper = internal::TensorProtoHelper<Type>;
288   tensor.set_dtype(TypeHelper::GetDataType());
289   tensor.mutable_tensor_shape()->Swap(&tensor_shape_proto);
290   TypeHelper::AddValues(values.begin(), values.end(), &tensor);
291   return tensor;
292 }
293 
294 // Converts values in tensor to run-length encoded compressed form.
295 //
296 // The elements of a tensor can be stored in a TensorProto in one of the
297 // following two forms:
298 // 1. As a raw byte string in the field `tensor_content` containing the
299 //    serialized in-memory representation of the tensor.
300 // 2. As values of a repeated field depending on the datatype, e.g. that
301 //    values of a DT_FLOAT tensor would be stored in the repeated field
302 //    `float_val`.
303 // Storage scheme 2 may use a simple form of run-length encoding to compress
304 // data: If the values contains a tail of identical values, the repeated field
305 // will be truncated such that the number of values in the repeated field is
306 // less than the number of elements implied by the field`tensor_shape`. The
307 // original tensor can be recovered by repeating the final value in the repeated
308 // field.
309 //
310 // The TensorProto will be compressed if a) the tensor contains at least
311 // min_num_elements elements and b) the compressed tensor proto is would be at
312 // most the size of the original tensor proto divided by min_compression_ratio.
313 //
314 // Returns true if the tensor was compressed.
315 bool CompressTensorProtoInPlace(int64 min_num_elements,
316                                 float min_compression_ratio,
317                                 TensorProto* tensor);
318 
319 inline bool CompressTensorProtoInPlace(TensorProto* tensor) {
320   static const int64 kDefaultMinNumElements = 64;
321   static const float kDefaultMinCompressionRatio = 2.0f;
322   return CompressTensorProtoInPlace(kDefaultMinNumElements,
323                                     kDefaultMinCompressionRatio, tensor);
324 }
325 
326 }  // namespace tensor
327 }  // namespace tensorflow
328 
329 #endif  // TENSORFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_
330